Compare commits

..

1 Commits

Author SHA1 Message Date
87bf75e51c add Registration closed note 2023-07-28 12:36:15 -05:00
572 changed files with 10135 additions and 21196 deletions

View File

@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 2023.8.1 current_version = 2023.6.1
tag = True tag = True
commit = True commit = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+) parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)

View File

@ -7,4 +7,3 @@ build/**
build_docs/** build_docs/**
Dockerfile Dockerfile
authentik/enterprise authentik/enterprise
blueprints/local

View File

@ -14,7 +14,7 @@ runs:
run: | run: |
pipx install poetry || true pipx install poetry || true
sudo apt update sudo apt update
sudo apt install -y libpq-dev openssl libxmlsec1-dev pkg-config gettext sudo apt install -y libxmlsec1-dev pkg-config gettext
- name: Setup python and restore poetry - name: Setup python and restore poetry
uses: actions/setup-python@v3 uses: actions/setup-python@v3
with: with:

View File

@ -1,2 +0,0 @@
enabled: true
preservePullRequestTitle: true

View File

@ -8,8 +8,6 @@ updates:
open-pull-requests-limit: 10 open-pull-requests-limit: 10
commit-message: commit-message:
prefix: "ci:" prefix: "ci:"
labels:
- dependencies
- package-ecosystem: gomod - package-ecosystem: gomod
directory: "/" directory: "/"
schedule: schedule:
@ -18,15 +16,11 @@ updates:
open-pull-requests-limit: 10 open-pull-requests-limit: 10
commit-message: commit-message:
prefix: "core:" prefix: "core:"
labels:
- dependencies
- package-ecosystem: npm - package-ecosystem: npm
directory: "/web" directory: "/web"
schedule: schedule:
interval: daily interval: daily
time: "04:00" time: "04:00"
labels:
- dependencies
open-pull-requests-limit: 10 open-pull-requests-limit: 10
commit-message: commit-message:
prefix: "web:" prefix: "web:"
@ -38,18 +32,10 @@ updates:
patterns: patterns:
- "@babel/*" - "@babel/*"
- "babel-*" - "babel-*"
eslint:
patterns:
- "@typescript-eslint/eslint-*"
- "eslint"
- "eslint-*"
storybook: storybook:
patterns: patterns:
- "@storybook/*" - "@storybook/*"
- "*storybook*" - "*storybook*"
esbuild:
patterns:
- "@esbuild/*"
- package-ecosystem: npm - package-ecosystem: npm
directory: "/website" directory: "/website"
schedule: schedule:
@ -58,8 +44,6 @@ updates:
open-pull-requests-limit: 10 open-pull-requests-limit: 10
commit-message: commit-message:
prefix: "website:" prefix: "website:"
labels:
- dependencies
groups: groups:
docusaurus: docusaurus:
patterns: patterns:
@ -72,8 +56,6 @@ updates:
open-pull-requests-limit: 10 open-pull-requests-limit: 10
commit-message: commit-message:
prefix: "core:" prefix: "core:"
labels:
- dependencies
- package-ecosystem: docker - package-ecosystem: docker
directory: "/" directory: "/"
schedule: schedule:
@ -82,5 +64,3 @@ updates:
open-pull-requests-limit: 10 open-pull-requests-limit: 10
commit-message: commit-message:
prefix: "core:" prefix: "core:"
labels:
- dependencies

View File

@ -1,19 +1,23 @@
<!-- <!--
👋 Hi there! Welcome. 👋 Hello there! Welcome.
Please check the Contributing guidelines: https://goauthentik.io/developer-docs/#how-can-i-contribute Please check the [Contributing guidelines](https://goauthentik.io/developer-docs/#how-can-i-contribute).
--> -->
## Details ## Details
<!-- - **Does this resolve an issue?**
Explain what this PR changes, what the rationale behind the change is, if any new requirements are introduced or any breaking changes caused by this PR. Resolves #
Ideally also link an Issue for context that this PR will close using `closes #` ## Changes
-->
REPLACE ME
--- ### New Features
- Adds feature which does x, y, and z.
### Breaking Changes
- Adds breaking change which causes \<issue\>.
## Checklist ## Checklist

View File

@ -88,8 +88,8 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
psql: psql:
- 11-alpine
- 12-alpine - 12-alpine
- 15-alpine
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- name: Setup authentik env - name: Setup authentik env

View File

@ -120,7 +120,7 @@ jobs:
- uses: actions/setup-go@v4 - uses: actions/setup-go@v4
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
- uses: actions/setup-node@v3.8.1 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20" node-version: "20"
cache: "npm" cache: "npm"

View File

@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: actions/setup-node@v3.8.1 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20" node-version: "20"
cache: "npm" cache: "npm"
@ -31,7 +31,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: actions/setup-node@v3.8.1 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20" node-version: "20"
cache: "npm" cache: "npm"
@ -47,7 +47,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: actions/setup-node@v3.8.1 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20" node-version: "20"
cache: "npm" cache: "npm"
@ -63,7 +63,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: actions/setup-node@v3.8.1 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20" node-version: "20"
cache: "npm" cache: "npm"
@ -95,7 +95,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: actions/setup-node@v3.8.1 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20" node-version: "20"
cache: "npm" cache: "npm"

View File

@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: actions/setup-node@v3.8.1 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20" node-version: "20"
cache: "npm" cache: "npm"
@ -29,7 +29,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: actions/setup-node@v3.8.1 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20" node-version: "20"
cache: "npm" cache: "npm"
@ -50,7 +50,7 @@ jobs:
- build-docs-only - build-docs-only
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: actions/setup-node@v3.8.1 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20" node-version: "20"
cache: "npm" cache: "npm"

View File

@ -1,34 +0,0 @@
---
# See https://docs.github.com/en/actions/using-workflows/caching-dependencies-to-speed-up-workflows#force-deleting-cache-entries
name: Cleanup cache after PR is closed
on:
pull_request:
types:
- closed
jobs:
cleanup:
runs-on: ubuntu-latest
steps:
- name: Check out code
uses: actions/checkout@v3
- name: Cleanup
run: |
gh extension install actions/gh-actions-cache
REPO=${{ github.repository }}
BRANCH="refs/pull/${{ github.event.pull_request.number }}/merge"
echo "Fetching list of cache key"
cacheKeysForPR=$(gh actions-cache list -R $REPO -B $BRANCH -L 100 | cut -f 1 )
# Setting this to not fail the workflow while deleting cache keys.
set +e
echo "Deleting caches..."
for cacheKey in $cacheKeysForPR; do
gh actions-cache delete $cacheKey -R $REPO -B $BRANCH --confirm
done
echo "Done"
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@ -1,61 +0,0 @@
---
name: authentik-compress-images
on:
push:
branches:
- main
paths:
- "**.jpg"
- "**.jpeg"
- "**.png"
- "**.webp"
pull_request:
paths:
- "**.jpg"
- "**.jpeg"
- "**.png"
- "**.webp"
workflow_dispatch:
jobs:
compress:
name: compress
runs-on: ubuntu-latest
# Don't run on forks. Token will not be available. Will run on main and open a PR anyway
if: |
github.repository == 'goauthentik/authentik' &&
(github.event_name != 'pull_request' ||
github.event.pull_request.head.repo.full_name == github.repository)
steps:
- id: generate_token
uses: tibdex/github-app-token@v1
with:
app_id: ${{ secrets.GH_APP_ID }}
private_key: ${{ secrets.GH_APP_PRIVATE_KEY }}
- uses: actions/checkout@v3
with:
token: ${{ steps.generate_token.outputs.token }}
- name: Compress images
id: compress
uses: calibreapp/image-actions@main
with:
githubToken: ${{ steps.generate_token.outputs.token }}
compressOnly: ${{ github.event_name != 'pull_request' }}
- uses: peter-evans/create-pull-request@v5
if: "${{ github.event_name != 'pull_request' && steps.compress.outputs.markdown != '' }}"
id: cpr
with:
token: ${{ steps.generate_token.outputs.token }}
title: "*: Auto compress images"
branch-suffix: timestamp
commit-messsage: "*: compress images"
body: ${{ steps.compress.outputs.markdown }}
delete-branch: true
signoff: true
- uses: peter-evans/enable-pull-request-automerge@v3
if: "${{ github.event_name != 'pull_request' && steps.compress.outputs.markdown != '' }}"
with:
token: ${{ steps.generate_token.outputs.token }}
pull-request-number: ${{ steps.cpr.outputs.pull-request-number }}
merge-method: squash

View File

@ -1,31 +0,0 @@
name: authentik-publish-source-docs
on:
push:
branches:
- main
env:
POSTGRES_DB: authentik
POSTGRES_USER: authentik
POSTGRES_PASSWORD: "EK-5jnKfjrGRm<77"
jobs:
publish-source-docs:
runs-on: ubuntu-latest
timeout-minutes: 120
steps:
- uses: actions/checkout@v3
- name: Setup authentik env
uses: ./.github/actions/setup
- name: generate docs
run: |
poetry run make migrate
poetry run ak build_source_docs
- name: Publish
uses: netlify/actions/cli@master
with:
args: deploy --dir=source_docs --prod
env:
NETLIFY_SITE_ID: eb246b7b-1d83-4f69-89f7-01a936b4ca59
NETLIFY_AUTH_TOKEN: ${{ secrets.NETLIFY_AUTH_TOKEN }}

View File

@ -110,7 +110,7 @@ jobs:
- uses: actions/setup-go@v4 - uses: actions/setup-go@v4
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
- uses: actions/setup-node@v3.8.1 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20" node-version: "20"
cache: "npm" cache: "npm"

View File

@ -1,5 +1,4 @@
# Rename transifex pull requests to have a correct naming # Rename transifex pull requests to have a correct naming
# Also enables auto squash-merge
name: authentik-translation-transifex-rename name: authentik-translation-transifex-rename
on: on:
@ -38,8 +37,3 @@ jobs:
-H "X-GitHub-Api-Version: 2022-11-28" \ -H "X-GitHub-Api-Version: 2022-11-28" \
https://api.github.com/repos/${GITHUB_REPOSITORY}/pulls/${{ github.event.pull_request.number }} \ https://api.github.com/repos/${GITHUB_REPOSITORY}/pulls/${{ github.event.pull_request.number }} \
-d "{\"title\":\"translate: ${{ steps.title.outputs.title }}\"}" -d "{\"title\":\"translate: ${{ steps.title.outputs.title }}\"}"
- uses: peter-evans/enable-pull-request-automerge@v3
with:
token: ${{ steps.generate_token.outputs.token }}
pull-request-number: ${{ github.event.pull_request.number }}
merge-method: squash

View File

@ -17,7 +17,7 @@ jobs:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
with: with:
token: ${{ steps.generate_token.outputs.token }} token: ${{ steps.generate_token.outputs.token }}
- uses: actions/setup-node@v3.8.1 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20" node-version: "20"
registry-url: "https://registry.npmjs.org" registry-url: "https://registry.npmjs.org"

1
.gitignore vendored
View File

@ -205,4 +205,3 @@ data/
# Local Netlify folder # Local Netlify folder
.netlify .netlify
.ruff_cache .ruff_cache
source_docs/

View File

@ -31,8 +31,7 @@
"!Format sequence", "!Format sequence",
"!Condition sequence", "!Condition sequence",
"!Env sequence", "!Env sequence",
"!Env scalar", "!Env scalar"
"!If sequence"
], ],
"typescript.preferences.importModuleSpecifier": "non-relative", "typescript.preferences.importModuleSpecifier": "non-relative",
"typescript.preferences.importModuleSpecifierEnding": "index", "typescript.preferences.importModuleSpecifierEnding": "index",

View File

@ -20,7 +20,7 @@ WORKDIR /work/web
RUN npm ci --include=dev && npm run build RUN npm ci --include=dev && npm run build
# Stage 3: Poetry to requirements.txt export # Stage 3: Poetry to requirements.txt export
FROM docker.io/python:3.11.5-slim-bookworm AS poetry-locker FROM docker.io/python:3.11.4-slim-bullseye AS poetry-locker
WORKDIR /work WORKDIR /work
COPY ./pyproject.toml /work COPY ./pyproject.toml /work
@ -31,7 +31,7 @@ RUN pip install --no-cache-dir poetry && \
poetry export -f requirements.txt --dev --output requirements-dev.txt poetry export -f requirements.txt --dev --output requirements-dev.txt
# Stage 4: Build go proxy # Stage 4: Build go proxy
FROM docker.io/golang:1.21.0-bookworm AS go-builder FROM docker.io/golang:1.20.6-bullseye AS go-builder
WORKDIR /work WORKDIR /work
@ -39,13 +39,12 @@ COPY --from=web-builder /work/web/robots.txt /work/web/robots.txt
COPY --from=web-builder /work/web/security.txt /work/web/security.txt COPY --from=web-builder /work/web/security.txt /work/web/security.txt
COPY ./cmd /work/cmd COPY ./cmd /work/cmd
COPY ./authentik/lib /work/authentik/lib
COPY ./web/static.go /work/web/static.go COPY ./web/static.go /work/web/static.go
COPY ./internal /work/internal COPY ./internal /work/internal
COPY ./go.mod /work/go.mod COPY ./go.mod /work/go.mod
COPY ./go.sum /work/go.sum COPY ./go.sum /work/go.sum
RUN go build -o /work/bin/authentik ./cmd/server/ RUN go build -o /work/authentik ./cmd/server/
# Stage 5: MaxMind GeoIP # Stage 5: MaxMind GeoIP
FROM ghcr.io/maxmind/geoipupdate:v6.0 as geoip FROM ghcr.io/maxmind/geoipupdate:v6.0 as geoip
@ -62,7 +61,7 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
/bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0" /bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
# Stage 6: Run # Stage 6: Run
FROM docker.io/python:3.11.5-slim-bookworm AS final-image FROM docker.io/python:3.11.4-slim-bullseye AS final-image
ARG GIT_BUILD_HASH ARG GIT_BUILD_HASH
ARG VERSION ARG VERSION
@ -82,13 +81,13 @@ COPY --from=geoip /usr/share/GeoIP /geoip
RUN apt-get update && \ RUN apt-get update && \
# Required for installing pip packages # Required for installing pip packages
apt-get install -y --no-install-recommends build-essential pkg-config libxmlsec1-dev zlib1g-dev libpq-dev python3-dev && \ apt-get install -y --no-install-recommends build-essential pkg-config libxmlsec1-dev zlib1g-dev && \
# Required for runtime # Required for runtime
apt-get install -y --no-install-recommends libpq5 openssl libxmlsec1-openssl libmaxminddb0 && \ apt-get install -y --no-install-recommends libxmlsec1-openssl libmaxminddb0 && \
# Required for bootstrap & healtcheck # Required for bootstrap & healtcheck
apt-get install -y --no-install-recommends runit && \ apt-get install -y --no-install-recommends runit && \
pip install --no-cache-dir -r /requirements.txt && \ pip install --no-cache-dir -r /requirements.txt && \
apt-get remove --purge -y build-essential pkg-config libxmlsec1-dev libpq-dev python3-dev && \ apt-get remove --purge -y build-essential pkg-config libxmlsec1-dev && \
apt-get autoremove --purge -y && \ apt-get autoremove --purge -y && \
apt-get clean && \ apt-get clean && \
rm -rf /tmp/* /var/lib/apt/lists/* /var/tmp/ && \ rm -rf /tmp/* /var/lib/apt/lists/* /var/tmp/ && \
@ -105,7 +104,7 @@ COPY ./tests /tests
COPY ./manage.py / COPY ./manage.py /
COPY ./blueprints /blueprints COPY ./blueprints /blueprints
COPY ./lifecycle/ /lifecycle COPY ./lifecycle/ /lifecycle
COPY --from=go-builder /work/bin/authentik /bin/authentik COPY --from=go-builder /work/authentik /bin/authentik
COPY --from=web-builder /work/web/dist/ /web/dist/ COPY --from=web-builder /work/web/dist/ /web/dist/
COPY --from=web-builder /work/web/authentik/ /web/authentik/ COPY --from=web-builder /work/web/authentik/ /web/authentik/
COPY --from=website-builder /work/website/help/ /website/help/ COPY --from=website-builder /work/website/help/ /website/help/

View File

@ -140,15 +140,13 @@ web-watch:
touch web/dist/.gitkeep touch web/dist/.gitkeep
cd web && npm run watch cd web && npm run watch
web-storybook-watch:
cd web && npm run storybook
web-lint-fix: web-lint-fix:
cd web && npm run prettier cd web && npm run prettier
web-lint: web-lint:
cd web && npm run lint cd web && npm run lint
cd web && npm run lit-analyse # TODO: The analyzer hasn't run correctly in awhile.
# cd web && npm run lit-analyse
web-check-compile: web-check-compile:
cd web && npm run tsc cd web && npm run tsc

View File

@ -16,8 +16,8 @@ Even if the issue is not a CVE, we still greatly appreciate your help in hardeni
| Version | Supported | | Version | Supported |
| --- | --- | | --- | --- |
| 2023.5.x | ✅ |
| 2023.6.x | ✅ | | 2023.6.x | ✅ |
| 2023.8.x | ✅ |
## Reporting a Vulnerability ## Reporting a Vulnerability
@ -27,8 +27,6 @@ To report a vulnerability, send an email to [security@goauthentik.io](mailto:se
authentik reserves the right to reclassify CVSS as necessary. To determine severity, we will use the CVSS calculator from NVD (https://nvd.nist.gov/vuln-metrics/cvss/v3-calculator). The calculated CVSS score will then be translated into one of the following categories: authentik reserves the right to reclassify CVSS as necessary. To determine severity, we will use the CVSS calculator from NVD (https://nvd.nist.gov/vuln-metrics/cvss/v3-calculator). The calculated CVSS score will then be translated into one of the following categories:
| Score | Severity |
| --- | --- |
| 0.0 | None | | 0.0 | None |
| 0.1 3.9 | Low | | 0.1 3.9 | Low |
| 4.0 6.9 | Medium | | 4.0 6.9 | Medium |

View File

@ -1,8 +1,8 @@
"""authentik root module""" """authentik"""
from os import environ from os import environ
from typing import Optional from typing import Optional
__version__ = "2023.8.1" __version__ = "2023.6.1"
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"

View File

@ -2,43 +2,6 @@
from rest_framework import pagination from rest_framework import pagination
from rest_framework.response import Response from rest_framework.response import Response
PAGINATION_COMPONENT_NAME = "Pagination"
PAGINATION_SCHEMA = {
"type": "object",
"properties": {
"next": {
"type": "number",
},
"previous": {
"type": "number",
},
"count": {
"type": "number",
},
"current": {
"type": "number",
},
"total_pages": {
"type": "number",
},
"start_index": {
"type": "number",
},
"end_index": {
"type": "number",
},
},
"required": [
"next",
"previous",
"count",
"current",
"total_pages",
"start_index",
"end_index",
],
}
class Pagination(pagination.PageNumberPagination): class Pagination(pagination.PageNumberPagination):
"""Pagination which includes total pages and current page""" """Pagination which includes total pages and current page"""
@ -72,7 +35,41 @@ class Pagination(pagination.PageNumberPagination):
return { return {
"type": "object", "type": "object",
"properties": { "properties": {
"pagination": {"$ref": f"#/components/schemas/{PAGINATION_COMPONENT_NAME}"}, "pagination": {
"type": "object",
"properties": {
"next": {
"type": "number",
},
"previous": {
"type": "number",
},
"count": {
"type": "number",
},
"current": {
"type": "number",
},
"total_pages": {
"type": "number",
},
"start_index": {
"type": "number",
},
"end_index": {
"type": "number",
},
},
"required": [
"next",
"previous",
"count",
"current",
"total_pages",
"start_index",
"end_index",
],
},
"results": schema, "results": schema,
}, },
"required": ["pagination", "results"], "required": ["pagination", "results"],

View File

@ -1,6 +1,5 @@
"""Error Response schema, from https://github.com/axnsan12/drf-yasg/issues/224""" """Error Response schema, from https://github.com/axnsan12/drf-yasg/issues/224"""
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from drf_spectacular.generators import SchemaGenerator
from drf_spectacular.plumbing import ( from drf_spectacular.plumbing import (
ResolvedComponent, ResolvedComponent,
build_array_type, build_array_type,
@ -9,9 +8,6 @@ from drf_spectacular.plumbing import (
) )
from drf_spectacular.settings import spectacular_settings from drf_spectacular.settings import spectacular_settings
from drf_spectacular.types import OpenApiTypes from drf_spectacular.types import OpenApiTypes
from rest_framework.settings import api_settings
from authentik.api.pagination import PAGINATION_COMPONENT_NAME, PAGINATION_SCHEMA
def build_standard_type(obj, **kwargs): def build_standard_type(obj, **kwargs):
@ -32,7 +28,7 @@ GENERIC_ERROR = build_object_type(
VALIDATION_ERROR = build_object_type( VALIDATION_ERROR = build_object_type(
description=_("Validation Error"), description=_("Validation Error"),
properties={ properties={
api_settings.NON_FIELD_ERRORS_KEY: build_array_type(build_standard_type(OpenApiTypes.STR)), "non_field_errors": build_array_type(build_standard_type(OpenApiTypes.STR)),
"code": build_standard_type(OpenApiTypes.STR), "code": build_standard_type(OpenApiTypes.STR),
}, },
required=[], required=[],
@ -40,19 +36,7 @@ VALIDATION_ERROR = build_object_type(
) )
def create_component(generator: SchemaGenerator, name, schema, type_=ResolvedComponent.SCHEMA): def postprocess_schema_responses(result, generator, **kwargs): # noqa: W0613
"""Register a component and return a reference to it."""
component = ResolvedComponent(
name=name,
type=type_,
schema=schema,
object=name,
)
generator.registry.register_on_missing(component)
return component
def postprocess_schema_responses(result, generator: SchemaGenerator, **kwargs): # noqa: W0613
"""Workaround to set a default response for endpoints. """Workaround to set a default response for endpoints.
Workaround suggested at Workaround suggested at
<https://github.com/tfranzel/drf-spectacular/issues/119#issuecomment-656970357> <https://github.com/tfranzel/drf-spectacular/issues/119#issuecomment-656970357>
@ -60,10 +44,19 @@ def postprocess_schema_responses(result, generator: SchemaGenerator, **kwargs):
<https://github.com/tfranzel/drf-spectacular/issues/101>. <https://github.com/tfranzel/drf-spectacular/issues/101>.
""" """
create_component(generator, PAGINATION_COMPONENT_NAME, PAGINATION_SCHEMA) def create_component(name, schema, type_=ResolvedComponent.SCHEMA):
"""Register a component and return a reference to it."""
component = ResolvedComponent(
name=name,
type=type_,
schema=schema,
object=name,
)
generator.registry.register_on_missing(component)
return component
generic_error = create_component(generator, "GenericError", GENERIC_ERROR) generic_error = create_component("GenericError", GENERIC_ERROR)
validation_error = create_component(generator, "ValidationError", VALIDATION_ERROR) validation_error = create_component("ValidationError", VALIDATION_ERROR)
for path in result["paths"].values(): for path in result["paths"].values():
for method in path.values(): for method in path.values():

View File

@ -93,10 +93,10 @@ class ConfigView(APIView):
"traces_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.4)), "traces_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.4)),
}, },
"capabilities": self.get_capabilities(), "capabilities": self.get_capabilities(),
"cache_timeout": CONFIG.get_int("redis.cache_timeout"), "cache_timeout": int(CONFIG.get("redis.cache_timeout")),
"cache_timeout_flows": CONFIG.get_int("redis.cache_timeout_flows"), "cache_timeout_flows": int(CONFIG.get("redis.cache_timeout_flows")),
"cache_timeout_policies": CONFIG.get_int("redis.cache_timeout_policies"), "cache_timeout_policies": int(CONFIG.get("redis.cache_timeout_policies")),
"cache_timeout_reputation": CONFIG.get_int("redis.cache_timeout_reputation"), "cache_timeout_reputation": int(CONFIG.get("redis.cache_timeout_reputation")),
} }
) )

View File

@ -45,8 +45,3 @@ entries:
attrs: attrs:
name: "%(uid)s" name: "%(uid)s"
password: "%(uid)s" password: "%(uid)s"
- model: authentik_core.user
identifiers:
username: "%(uid)s-no-password"
attrs:
name: "%(uid)s"

View File

@ -7,5 +7,7 @@ entries:
state: absent state: absent
- identifiers: - identifiers:
name: "%(id)s" name: "%(id)s"
expression: |
return True
model: authentik_policies_expression.expressionpolicy model: authentik_policies_expression.expressionpolicy
state: absent state: absent

View File

@ -9,8 +9,6 @@ context:
mapping: mapping:
key1: value key1: value
key2: 2 key2: 2
context1: context-nested-value
context2: !Context context1
entries: entries:
- model: !Format ["%s", authentik_sources_oauth.oauthsource] - model: !Format ["%s", authentik_sources_oauth.oauthsource]
state: !Format ["%s", present] state: !Format ["%s", present]
@ -36,7 +34,6 @@ entries:
model: authentik_policies_expression.expressionpolicy model: authentik_policies_expression.expressionpolicy
- attrs: - attrs:
attributes: attributes:
env_null: !Env [bar-baz, null]
policy_pk1: policy_pk1:
!Format [ !Format [
"%s-%s", "%s-%s",
@ -100,7 +97,6 @@ entries:
[list, with, items, !Format ["foo-%s", !Context foo]], [list, with, items, !Format ["foo-%s", !Context foo]],
] ]
if_true_simple: !If [!Context foo, true, text] if_true_simple: !If [!Context foo, true, text]
if_short: !If [!Context foo]
if_false_simple: !If [null, false, 2] if_false_simple: !If [null, false, 2]
enumerate_mapping_to_mapping: !Enumerate [ enumerate_mapping_to_mapping: !Enumerate [
!Context mapping, !Context mapping,
@ -145,7 +141,6 @@ entries:
] ]
] ]
] ]
nested_context: !Context context2
identifiers: identifiers:
name: test name: test
conditions: conditions:

View File

@ -155,7 +155,6 @@ class TestBlueprintsV1(TransactionTestCase):
}, },
"if_false_complex": ["list", "with", "items", "foo-bar"], "if_false_complex": ["list", "with", "items", "foo-bar"],
"if_true_simple": True, "if_true_simple": True,
"if_short": True,
"if_false_simple": 2, "if_false_simple": 2,
"enumerate_mapping_to_mapping": { "enumerate_mapping_to_mapping": {
"prefix-key1": "other-prefix-value", "prefix-key1": "other-prefix-value",
@ -212,10 +211,8 @@ class TestBlueprintsV1(TransactionTestCase):
], ],
}, },
}, },
"nested_context": "context-nested-value",
"env_null": None,
} }
).exists() )
) )
self.assertTrue( self.assertTrue(
OAuthSource.objects.filter( OAuthSource.objects.filter(

View File

@ -51,9 +51,3 @@ class TestBlueprintsV1ConditionalFields(TransactionTestCase):
user: User = User.objects.filter(username=self.uid).first() user: User = User.objects.filter(username=self.uid).first()
self.assertIsNotNone(user) self.assertIsNotNone(user)
self.assertTrue(user.check_password(self.uid)) self.assertTrue(user.check_password(self.uid))
def test_user_null(self):
"""Test user"""
user: User = User.objects.filter(username=f"{self.uid}-no-password").first()
self.assertIsNotNone(user)
self.assertFalse(user.has_usable_password())

View File

@ -223,11 +223,11 @@ class Env(YAMLTag):
if isinstance(node, ScalarNode): if isinstance(node, ScalarNode):
self.key = node.value self.key = node.value
if isinstance(node, SequenceNode): if isinstance(node, SequenceNode):
self.key = loader.construct_object(node.value[0]) self.key = node.value[0].value
self.default = loader.construct_object(node.value[1]) self.default = node.value[1].value
def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any: def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
return getenv(self.key) or self.default return getenv(self.key, self.default)
class Context(YAMLTag): class Context(YAMLTag):
@ -242,15 +242,13 @@ class Context(YAMLTag):
if isinstance(node, ScalarNode): if isinstance(node, ScalarNode):
self.key = node.value self.key = node.value
if isinstance(node, SequenceNode): if isinstance(node, SequenceNode):
self.key = loader.construct_object(node.value[0]) self.key = node.value[0].value
self.default = loader.construct_object(node.value[1]) self.default = node.value[1].value
def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any: def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
value = self.default value = self.default
if self.key in blueprint.context: if self.key in blueprint.context:
value = blueprint.context[self.key] value = blueprint.context[self.key]
if isinstance(value, YAMLTag):
return value.resolve(entry, blueprint)
return value return value
@ -262,7 +260,7 @@ class Format(YAMLTag):
def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None: def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None:
super().__init__() super().__init__()
self.format_string = loader.construct_object(node.value[0]) self.format_string = node.value[0].value
self.args = [] self.args = []
for raw_node in node.value[1:]: for raw_node in node.value[1:]:
self.args.append(loader.construct_object(raw_node)) self.args.append(loader.construct_object(raw_node))
@ -341,7 +339,7 @@ class Condition(YAMLTag):
def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None: def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None:
super().__init__() super().__init__()
self.mode = loader.construct_object(node.value[0]) self.mode = node.value[0].value
self.args = [] self.args = []
for raw_node in node.value[1:]: for raw_node in node.value[1:]:
self.args.append(loader.construct_object(raw_node)) self.args.append(loader.construct_object(raw_node))
@ -374,12 +372,8 @@ class If(YAMLTag):
def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None: def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None:
super().__init__() super().__init__()
self.condition = loader.construct_object(node.value[0]) self.condition = loader.construct_object(node.value[0])
if len(node.value) == 1: self.when_true = loader.construct_object(node.value[1])
self.when_true = True self.when_false = loader.construct_object(node.value[2])
self.when_false = False
else:
self.when_true = loader.construct_object(node.value[1])
self.when_false = loader.construct_object(node.value[2])
def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any: def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
if isinstance(self.condition, YAMLTag): if isinstance(self.condition, YAMLTag):
@ -416,7 +410,7 @@ class Enumerate(YAMLTag, YAMLTagContext):
def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None: def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None:
super().__init__() super().__init__()
self.iterable = loader.construct_object(node.value[0]) self.iterable = loader.construct_object(node.value[0])
self.output_body = loader.construct_object(node.value[1]) self.output_body = node.value[1].value
self.item_body = loader.construct_object(node.value[2]) self.item_body = loader.construct_object(node.value[2])
self.__current_context: tuple[Any, Any] = tuple() self.__current_context: tuple[Any, Any] = tuple()

View File

@ -35,7 +35,6 @@ from authentik.core.models import (
Source, Source,
UserSourceConnection, UserSourceConnection,
) )
from authentik.events.utils import cleanse_dict
from authentik.flows.models import FlowToken, Stage from authentik.flows.models import FlowToken, Stage
from authentik.lib.models import SerializerModel from authentik.lib.models import SerializerModel
from authentik.outposts.models import OutpostServiceConnection from authentik.outposts.models import OutpostServiceConnection
@ -200,6 +199,9 @@ class Importer:
serializer_kwargs = {} serializer_kwargs = {}
model_instance = existing_models.first() model_instance = existing_models.first()
if not isinstance(model(), BaseMetaModel) and model_instance: if not isinstance(model(), BaseMetaModel) and model_instance:
if entry.get_state(self.__import) == BlueprintEntryDesiredState.CREATED:
self.logger.debug("instance exists, skipping")
return None
self.logger.debug( self.logger.debug(
"initialise serializer with instance", "initialise serializer with instance",
model=model, model=model,
@ -210,9 +212,7 @@ class Importer:
serializer_kwargs["partial"] = True serializer_kwargs["partial"] = True
else: else:
self.logger.debug( self.logger.debug(
"initialised new serializer instance", "initialised new serializer instance", model=model, **updated_identifiers
model=model,
**cleanse_dict(updated_identifiers),
) )
model_instance = model() model_instance = model()
# pk needs to be set on the model instance otherwise a new one will be generated # pk needs to be set on the model instance otherwise a new one will be generated
@ -268,34 +268,21 @@ class Importer:
try: try:
serializer = self._validate_single(entry) serializer = self._validate_single(entry)
except EntryInvalidError as exc: except EntryInvalidError as exc:
# For deleting objects we don't need the serializer to be valid
if entry.get_state(self.__import) == BlueprintEntryDesiredState.ABSENT:
continue
self.logger.warning(f"entry invalid: {exc}", entry=entry, error=exc) self.logger.warning(f"entry invalid: {exc}", entry=entry, error=exc)
return False return False
if not serializer: if not serializer:
continue continue
state = entry.get_state(self.__import) state = entry.get_state(self.__import)
if state in [BlueprintEntryDesiredState.PRESENT, BlueprintEntryDesiredState.CREATED]: if state in [
instance = serializer.instance BlueprintEntryDesiredState.PRESENT,
if ( BlueprintEntryDesiredState.CREATED,
instance ]:
and not instance._state.adding model = serializer.save()
and state == BlueprintEntryDesiredState.CREATED
):
self.logger.debug(
"instance exists, skipping",
model=model,
instance=instance,
pk=instance.pk,
)
else:
instance = serializer.save()
self.logger.debug("updated model", model=instance)
if "pk" in entry.identifiers: if "pk" in entry.identifiers:
self.__pk_map[entry.identifiers["pk"]] = instance.pk self.__pk_map[entry.identifiers["pk"]] = model.pk
entry._state = BlueprintEntryState(instance) entry._state = BlueprintEntryState(model)
self.logger.debug("updated model", model=model)
elif state == BlueprintEntryDesiredState.ABSENT: elif state == BlueprintEntryDesiredState.ABSENT:
instance: Optional[Model] = serializer.instance instance: Optional[Model] = serializer.instance
if instance.pk: if instance.pk:
@ -322,6 +309,5 @@ class Importer:
self.logger.debug("Blueprint validation failed") self.logger.debug("Blueprint validation failed")
for log in logs: for log in logs:
getattr(self.logger, log.get("log_level"))(**log) getattr(self.logger, log.get("log_level"))(**log)
self.logger.debug("Finished blueprint import validation")
self.__import = orig_import self.__import = orig_import
return successful, logs return successful, logs

View File

@ -31,7 +31,7 @@ class ApplyBlueprintMetaSerializer(PassiveSerializer):
required = attrs["required"] required = attrs["required"]
instance = BlueprintInstance.objects.filter(**identifiers).first() instance = BlueprintInstance.objects.filter(**identifiers).first()
if not instance and required: if not instance and required:
raise ValidationError({"identifiers": "Required blueprint does not exist"}) raise ValidationError("Required blueprint does not exist")
self.blueprint_instance = instance self.blueprint_instance = instance
return super().validate(attrs) return super().validate(attrs)

View File

@ -49,7 +49,7 @@ class GroupSerializer(ModelSerializer):
users_obj = ListSerializer( users_obj = ListSerializer(
child=GroupMemberSerializer(), read_only=True, source="users", required=False child=GroupMemberSerializer(), read_only=True, source="users", required=False
) )
parent_name = CharField(source="parent.name", read_only=True, allow_null=True) parent_name = CharField(source="parent.name", read_only=True)
num_pk = IntegerField(read_only=True) num_pk = IntegerField(read_only=True)

View File

@ -47,7 +47,7 @@ class TokenSerializer(ManagedSerializer, ModelSerializer):
attrs.setdefault("user", request.user) attrs.setdefault("user", request.user)
attrs.setdefault("intent", TokenIntents.INTENT_API) attrs.setdefault("intent", TokenIntents.INTENT_API)
if attrs.get("intent") not in [TokenIntents.INTENT_API, TokenIntents.INTENT_APP_PASSWORD]: if attrs.get("intent") not in [TokenIntents.INTENT_API, TokenIntents.INTENT_APP_PASSWORD]:
raise ValidationError({"intent": f"Invalid intent {attrs.get('intent')}"}) raise ValidationError(f"Invalid intent {attrs.get('intent')}")
return attrs return attrs
class Meta: class Meta:

View File

@ -15,13 +15,7 @@ from django.utils.http import urlencode
from django.utils.text import slugify from django.utils.text import slugify
from django.utils.timezone import now from django.utils.timezone import now
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from django_filters.filters import ( from django_filters.filters import BooleanFilter, CharFilter, ModelMultipleChoiceFilter, UUIDFilter
BooleanFilter,
CharFilter,
ModelMultipleChoiceFilter,
MultipleChoiceFilter,
UUIDFilter,
)
from django_filters.filterset import FilterSet from django_filters.filterset import FilterSet
from drf_spectacular.types import OpenApiTypes from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import ( from drf_spectacular.utils import (
@ -123,35 +117,27 @@ class UserSerializer(ModelSerializer):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if SERIALIZER_CONTEXT_BLUEPRINT in self.context: if SERIALIZER_CONTEXT_BLUEPRINT in self.context:
self.fields["password"] = CharField(required=False, allow_null=True) self.fields["password"] = CharField(required=False)
def create(self, validated_data: dict) -> User: def create(self, validated_data: dict) -> User:
"""If this serializer is used in the blueprint context, we allow for """If this serializer is used in the blueprint context, we allow for
directly setting a password. However should be done via the `set_password` directly setting a password. However should be done via the `set_password`
method instead of directly setting it like rest_framework.""" method instead of directly setting it like rest_framework."""
password = validated_data.pop("password", None)
instance: User = super().create(validated_data) instance: User = super().create(validated_data)
self._set_password(instance, password) if SERIALIZER_CONTEXT_BLUEPRINT in self.context and "password" in validated_data:
instance.set_password(validated_data["password"])
instance.save()
return instance return instance
def update(self, instance: User, validated_data: dict) -> User: def update(self, instance: User, validated_data: dict) -> User:
"""Same as `create` above, set the password directly if we're in a blueprint """Same as `create` above, set the password directly if we're in a blueprint
context""" context"""
password = validated_data.pop("password", None)
instance = super().update(instance, validated_data) instance = super().update(instance, validated_data)
self._set_password(instance, password) if SERIALIZER_CONTEXT_BLUEPRINT in self.context and "password" in validated_data:
instance.set_password(validated_data["password"])
instance.save()
return instance return instance
def _set_password(self, instance: User, password: Optional[str]):
"""Set password of user if we're in a blueprint context, and if it's an empty
string then use an unusable password"""
if SERIALIZER_CONTEXT_BLUEPRINT in self.context and password:
instance.set_password(password)
instance.save()
if len(instance.password) == 0:
instance.set_unusable_password()
instance.save()
def validate_path(self, path: str) -> str: def validate_path(self, path: str) -> str:
"""Validate path""" """Validate path"""
if path[:1] == "/" or path[-1] == "/": if path[:1] == "/" or path[-1] == "/":
@ -215,7 +201,7 @@ class UserSelfSerializer(ModelSerializer):
) )
def get_groups(self, _: User): def get_groups(self, _: User):
"""Return only the group names a user is member of""" """Return only the group names a user is member of"""
for group in self.instance.all_groups().order_by("name"): for group in self.instance.ak_groups.all():
yield { yield {
"name": group.name, "name": group.name,
"pk": group.pk, "pk": group.pk,
@ -314,11 +300,11 @@ class UsersFilter(FilterSet):
is_superuser = BooleanFilter(field_name="ak_groups", lookup_expr="is_superuser") is_superuser = BooleanFilter(field_name="ak_groups", lookup_expr="is_superuser")
uuid = UUIDFilter(field_name="uuid") uuid = UUIDFilter(field_name="uuid")
path = CharFilter(field_name="path") path = CharFilter(
field_name="path",
)
path_startswith = CharFilter(field_name="path", lookup_expr="startswith") path_startswith = CharFilter(field_name="path", lookup_expr="startswith")
type = MultipleChoiceFilter(choices=UserTypes.choices, field_name="type")
groups_by_name = ModelMultipleChoiceFilter( groups_by_name = ModelMultipleChoiceFilter(
field_name="ak_groups__name", field_name="ak_groups__name",
to_field_name="name", to_field_name="name",

View File

@ -1,21 +0,0 @@
"""Build source docs"""
from pathlib import Path
from django.core.management.base import BaseCommand
from pdoc import pdoc
from pdoc.render import configure
class Command(BaseCommand):
"""Build source docs"""
def handle(self, **options):
configure(
docformat="markdown",
mermaid=True,
logo="https://goauthentik.io/img/icon_top_brand_colour.svg",
)
pdoc(
"authentik",
output_directory=Path("./source_docs"),
)

View File

@ -1,11 +1,55 @@
# Generated by Django 3.2.8 on 2021-10-10 16:16 # Generated by Django 3.2.8 on 2021-10-10 16:16
from os import environ
import django.db.models.deletion import django.db.models.deletion
from django.apps.registry import Apps
from django.conf import settings
from django.db import migrations, models from django.db import migrations, models
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
import authentik.core.models import authentik.core.models
def create_default_user(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
from django.contrib.auth.hashers import make_password
User = apps.get_model("authentik_core", "User")
db_alias = schema_editor.connection.alias
akadmin, _ = User.objects.using(db_alias).get_or_create(
username="akadmin",
email=environ.get("AUTHENTIK_BOOTSTRAP_EMAIL", "root@localhost"),
name="authentik Default Admin",
)
password = None
if "TF_BUILD" in environ or settings.TEST:
password = "akadmin" # noqa # nosec
if "AUTHENTIK_BOOTSTRAP_PASSWORD" in environ:
password = environ["AUTHENTIK_BOOTSTRAP_PASSWORD"]
if password:
akadmin.password = make_password(password)
else:
akadmin.password = make_password(None)
akadmin.save()
def create_default_admin_group(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
db_alias = schema_editor.connection.alias
Group = apps.get_model("authentik_core", "Group")
User = apps.get_model("authentik_core", "User")
# Creates a default admin group
group, _ = Group.objects.using(db_alias).get_or_create(
is_superuser=True,
defaults={
"name": "authentik Admins",
},
)
group.users.set(User.objects.filter(username="akadmin"))
group.save()
class Migration(migrations.Migration): class Migration(migrations.Migration):
replaces = [ replaces = [
("authentik_core", "0002_auto_20200523_1133"), ("authentik_core", "0002_auto_20200523_1133"),
@ -75,6 +119,9 @@ class Migration(migrations.Migration):
model_name="user", model_name="user",
name="is_staff", name="is_staff",
), ),
migrations.RunPython(
code=create_default_user,
),
migrations.AddField( migrations.AddField(
model_name="user", model_name="user",
name="is_superuser", name="is_superuser",
@ -154,6 +201,9 @@ class Migration(migrations.Migration):
default=False, help_text="Users added to this group will be superusers." default=False, help_text="Users added to this group will be superusers."
), ),
), ),
migrations.RunPython(
code=create_default_admin_group,
),
migrations.AlterModelManagers( migrations.AlterModelManagers(
name="user", name="user",
managers=[ managers=[

View File

@ -1,6 +1,7 @@
# Generated by Django 3.2.8 on 2021-10-10 16:12 # Generated by Django 3.2.8 on 2021-10-10 16:12
import uuid import uuid
from os import environ
import django.db.models.deletion import django.db.models.deletion
from django.apps.registry import Apps from django.apps.registry import Apps
@ -34,6 +35,29 @@ def fix_duplicates(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
Token.objects.using(db_alias).filter(identifier=ident["identifier"]).delete() Token.objects.using(db_alias).filter(identifier=ident["identifier"]).delete()
def create_default_user_token(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
from authentik.core.models import TokenIntents
User = apps.get_model("authentik_core", "User")
Token = apps.get_model("authentik_core", "Token")
db_alias = schema_editor.connection.alias
akadmin = User.objects.using(db_alias).filter(username="akadmin")
if not akadmin.exists():
return
if "AUTHENTIK_BOOTSTRAP_TOKEN" not in environ:
return
key = environ["AUTHENTIK_BOOTSTRAP_TOKEN"]
Token.objects.using(db_alias).create(
identifier="authentik-bootstrap-token",
user=akadmin.first(),
intent=TokenIntents.INTENT_API,
expiring=False,
key=key,
)
class Migration(migrations.Migration): class Migration(migrations.Migration):
replaces = [ replaces = [
("authentik_core", "0018_auto_20210330_1345"), ("authentik_core", "0018_auto_20210330_1345"),
@ -190,6 +214,9 @@ class Migration(migrations.Migration):
"verbose_name_plural": "Authenticated Sessions", "verbose_name_plural": "Authenticated Sessions",
}, },
), ),
migrations.RunPython(
code=create_default_user_token,
),
migrations.AlterField( migrations.AlterField(
model_name="token", model_name="token",
name="intent", name="intent",

View File

@ -60,7 +60,7 @@ def default_token_key():
"""Default token key""" """Default token key"""
# We use generate_id since the chars in the key should be easy # We use generate_id since the chars in the key should be easy
# to use in Emails (for verification) and URLs (for recovery) # to use in Emails (for verification) and URLs (for recovery)
return generate_id(CONFIG.get_int("default_token_length")) return generate_id(int(CONFIG.get("default_token_length")))
class UserTypes(models.TextChoices): class UserTypes(models.TextChoices):
@ -79,7 +79,7 @@ class UserTypes(models.TextChoices):
class Group(SerializerModel): class Group(SerializerModel):
"""Group model which supports a basic hierarchy and has attributes""" """Custom Group model which supports a basic hierarchy"""
group_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) group_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
@ -113,7 +113,27 @@ class Group(SerializerModel):
def is_member(self, user: "User") -> bool: def is_member(self, user: "User") -> bool:
"""Recursively check if `user` is member of us, or any parent.""" """Recursively check if `user` is member of us, or any parent."""
return user.all_groups().filter(group_uuid=self.group_uuid).exists() query = """
WITH RECURSIVE parents AS (
SELECT authentik_core_group.*, 0 AS relative_depth
FROM authentik_core_group
WHERE authentik_core_group.group_uuid = %s
UNION ALL
SELECT authentik_core_group.*, parents.relative_depth - 1
FROM authentik_core_group,parents
WHERE (
authentik_core_group.parent_id = parents.group_uuid and
parents.relative_depth > -20
)
)
SELECT group_uuid
FROM parents
GROUP BY group_uuid;
"""
groups = Group.objects.raw(query, [self.group_uuid])
return user.ak_groups.filter(pk__in=[group.pk for group in groups]).exists()
def __str__(self): def __str__(self):
return f"Group {self.name}" return f"Group {self.name}"
@ -128,15 +148,15 @@ class Group(SerializerModel):
class UserManager(DjangoUserManager): class UserManager(DjangoUserManager):
"""User manager that doesn't assign is_superuser and is_staff""" """Custom user manager that doesn't assign is_superuser and is_staff"""
def create_user(self, username, email=None, password=None, **extra_fields): def create_user(self, username, email=None, password=None, **extra_fields):
"""User manager that doesn't assign is_superuser and is_staff""" """Custom user manager that doesn't assign is_superuser and is_staff"""
return self._create_user(username, email, password, **extra_fields) return self._create_user(username, email, password, **extra_fields)
class User(SerializerModel, GuardianUserMixin, AbstractUser): class User(SerializerModel, GuardianUserMixin, AbstractUser):
"""authentik User model, based on django's contrib auth user model.""" """Custom User model to allow easier adding of user-based settings"""
uuid = models.UUIDField(default=uuid4, editable=False, unique=True) uuid = models.UUIDField(default=uuid4, editable=False, unique=True)
name = models.TextField(help_text=_("User's display name.")) name = models.TextField(help_text=_("User's display name."))
@ -156,45 +176,13 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser):
"""Get the default user path""" """Get the default user path"""
return User._meta.get_field("path").default return User._meta.get_field("path").default
def all_groups(self) -> QuerySet[Group]:
"""Recursively get all groups this user is a member of.
At least one query is done to get the direct groups of the user, with groups
there are at most 3 queries done"""
direct_groups = list(
x for x in self.ak_groups.all().values_list("pk", flat=True).iterator()
)
if len(direct_groups) < 1:
return Group.objects.none()
query = """
WITH RECURSIVE parents AS (
SELECT authentik_core_group.*, 0 AS relative_depth
FROM authentik_core_group
WHERE authentik_core_group.group_uuid = ANY(%s)
UNION ALL
SELECT authentik_core_group.*, parents.relative_depth + 1
FROM authentik_core_group, parents
WHERE (
authentik_core_group.group_uuid = parents.parent_id and
parents.relative_depth < 20
)
)
SELECT group_uuid
FROM parents
GROUP BY group_uuid, name
ORDER BY name;
"""
group_pks = [group.pk for group in Group.objects.raw(query, [direct_groups]).iterator()]
return Group.objects.filter(pk__in=group_pks)
def group_attributes(self, request: Optional[HttpRequest] = None) -> dict[str, Any]: def group_attributes(self, request: Optional[HttpRequest] = None) -> dict[str, Any]:
"""Get a dictionary containing the attributes from all groups the user belongs to, """Get a dictionary containing the attributes from all groups the user belongs to,
including the users attributes""" including the users attributes"""
final_attributes = {} final_attributes = {}
if request and hasattr(request, "tenant"): if request and hasattr(request, "tenant"):
always_merger.merge(final_attributes, request.tenant.attributes) always_merger.merge(final_attributes, request.tenant.attributes)
for group in self.all_groups().order_by("name"): for group in self.ak_groups.all().order_by("name"):
always_merger.merge(final_attributes, group.attributes) always_merger.merge(final_attributes, group.attributes)
always_merger.merge(final_attributes, self.attributes) always_merger.merge(final_attributes, self.attributes)
return final_attributes return final_attributes
@ -208,7 +196,7 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser):
@cached_property @cached_property
def is_superuser(self) -> bool: def is_superuser(self) -> bool:
"""Get supseruser status based on membership in a group with superuser status""" """Get supseruser status based on membership in a group with superuser status"""
return self.all_groups().filter(is_superuser=True).exists() return self.ak_groups.filter(is_superuser=True).exists()
@property @property
def is_staff(self) -> bool: def is_staff(self) -> bool:

View File

@ -78,6 +78,7 @@
</main> </main>
{% endblock %} {% endblock %}
<footer class="pf-c-login__footer"> <footer class="pf-c-login__footer">
<p></p>
<ul class="pf-c-list pf-m-inline"> <ul class="pf-c-list pf-m-inline">
{% for link in footer_links %} {% for link in footer_links %}
<li> <li>

View File

@ -13,9 +13,7 @@ class TestGroups(TestCase):
user = User.objects.create(username=generate_id()) user = User.objects.create(username=generate_id())
user2 = User.objects.create(username=generate_id()) user2 = User.objects.create(username=generate_id())
group = Group.objects.create(name=generate_id()) group = Group.objects.create(name=generate_id())
other_group = Group.objects.create(name=generate_id())
group.users.add(user) group.users.add(user)
other_group.users.add(user)
self.assertTrue(group.is_member(user)) self.assertTrue(group.is_member(user))
self.assertFalse(group.is_member(user2)) self.assertFalse(group.is_member(user2))
@ -23,26 +21,22 @@ class TestGroups(TestCase):
"""Test parent membership""" """Test parent membership"""
user = User.objects.create(username=generate_id()) user = User.objects.create(username=generate_id())
user2 = User.objects.create(username=generate_id()) user2 = User.objects.create(username=generate_id())
parent = Group.objects.create(name=generate_id()) first = Group.objects.create(name=generate_id())
child = Group.objects.create(name=generate_id(), parent=parent) second = Group.objects.create(name=generate_id(), parent=first)
child.users.add(user) second.users.add(user)
self.assertTrue(child.is_member(user)) self.assertTrue(first.is_member(user))
self.assertTrue(parent.is_member(user)) self.assertFalse(first.is_member(user2))
self.assertFalse(child.is_member(user2))
self.assertFalse(parent.is_member(user2))
def test_group_membership_parent_extra(self): def test_group_membership_parent_extra(self):
"""Test parent membership""" """Test parent membership"""
user = User.objects.create(username=generate_id()) user = User.objects.create(username=generate_id())
user2 = User.objects.create(username=generate_id()) user2 = User.objects.create(username=generate_id())
parent = Group.objects.create(name=generate_id()) first = Group.objects.create(name=generate_id())
second = Group.objects.create(name=generate_id(), parent=parent) second = Group.objects.create(name=generate_id(), parent=first)
third = Group.objects.create(name=generate_id(), parent=second) third = Group.objects.create(name=generate_id(), parent=second)
second.users.add(user) second.users.add(user)
self.assertTrue(parent.is_member(user)) self.assertTrue(first.is_member(user))
self.assertFalse(parent.is_member(user2)) self.assertFalse(first.is_member(user2))
self.assertTrue(second.is_member(user))
self.assertFalse(second.is_member(user2))
self.assertFalse(third.is_member(user)) self.assertFalse(third.is_member(user))
self.assertFalse(third.is_member(user2)) self.assertFalse(third.is_member(user2))

View File

@ -28,19 +28,6 @@ class TestUsersAPI(APITestCase):
self.admin = create_test_admin_user() self.admin = create_test_admin_user()
self.user = User.objects.create(username="test-user") self.user = User.objects.create(username="test-user")
def test_filter_type(self):
"""Test API filtering by type"""
self.client.force_login(self.admin)
user = create_test_admin_user(type=UserTypes.EXTERNAL)
response = self.client.get(
reverse("authentik_api:user-list"),
data={
"type": UserTypes.EXTERNAL,
"username": user.username,
},
)
self.assertEqual(response.status_code, 200)
def test_metrics(self): def test_metrics(self):
"""Test user's metrics""" """Test user's metrics"""
self.client.force_login(self.admin) self.client.force_login(self.admin)

View File

@ -21,7 +21,7 @@ def create_test_flow(
) )
def create_test_admin_user(name: Optional[str] = None, **kwargs) -> User: def create_test_admin_user(name: Optional[str] = None) -> User:
"""Generate a test-admin user""" """Generate a test-admin user"""
uid = generate_id(20) if not name else name uid = generate_id(20) if not name else name
group = Group.objects.create(name=uid, is_superuser=True) group = Group.objects.create(name=uid, is_superuser=True)
@ -29,7 +29,6 @@ def create_test_admin_user(name: Optional[str] = None, **kwargs) -> User:
username=uid, username=uid,
name=uid, name=uid,
email=f"{uid}@goauthentik.io", email=f"{uid}@goauthentik.io",
**kwargs,
) )
user.set_password(uid) user.set_password(uid)
user.save() user.save()
@ -37,12 +36,12 @@ def create_test_admin_user(name: Optional[str] = None, **kwargs) -> User:
return user return user
def create_test_tenant(**kwargs) -> Tenant: def create_test_tenant() -> Tenant:
"""Generate a test tenant, removing all other tenants to make sure this one """Generate a test tenant, removing all other tenants to make sure this one
matches.""" matches."""
uid = generate_id(20) uid = generate_id(20)
Tenant.objects.all().delete() Tenant.objects.all().delete()
return Tenant.objects.create(domain=uid, default=True, **kwargs) return Tenant.objects.create(domain=uid, default=True)
def create_test_cert(use_ec_private_key=False) -> CertificateKeyPair: def create_test_cert(use_ec_private_key=False) -> CertificateKeyPair:

View File

@ -35,13 +35,13 @@ class LicenseSerializer(ModelSerializer):
"name", "name",
"key", "key",
"expiry", "expiry",
"internal_users", "users",
"external_users", "external_users",
] ]
extra_kwargs = { extra_kwargs = {
"name": {"read_only": True}, "name": {"read_only": True},
"expiry": {"read_only": True}, "expiry": {"read_only": True},
"internal_users": {"read_only": True}, "users": {"read_only": True},
"external_users": {"read_only": True}, "external_users": {"read_only": True},
} }
@ -49,7 +49,7 @@ class LicenseSerializer(ModelSerializer):
class LicenseSummary(PassiveSerializer): class LicenseSummary(PassiveSerializer):
"""Serializer for license status""" """Serializer for license status"""
internal_users = IntegerField(required=True) users = IntegerField(required=True)
external_users = IntegerField(required=True) external_users = IntegerField(required=True)
valid = BooleanField() valid = BooleanField()
show_admin_warning = BooleanField() show_admin_warning = BooleanField()
@ -62,9 +62,9 @@ class LicenseSummary(PassiveSerializer):
class LicenseForecastSerializer(PassiveSerializer): class LicenseForecastSerializer(PassiveSerializer):
"""Serializer for license forecast""" """Serializer for license forecast"""
internal_users = IntegerField(required=True) users = IntegerField(required=True)
external_users = IntegerField(required=True) external_users = IntegerField(required=True)
forecasted_internal_users = IntegerField(required=True) forecasted_users = IntegerField(required=True)
forecasted_external_users = IntegerField(required=True) forecasted_external_users = IntegerField(required=True)
@ -111,7 +111,7 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
latest_valid = datetime.fromtimestamp(total.exp) latest_valid = datetime.fromtimestamp(total.exp)
response = LicenseSummary( response = LicenseSummary(
data={ data={
"internal_users": total.internal_users, "users": total.users,
"external_users": total.external_users, "external_users": total.external_users,
"valid": total.is_valid(), "valid": total.is_valid(),
"show_admin_warning": show_admin_warning, "show_admin_warning": show_admin_warning,
@ -135,8 +135,8 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
def forecast(self, request: Request) -> Response: def forecast(self, request: Request) -> Response:
"""Forecast how many users will be required in a year""" """Forecast how many users will be required in a year"""
last_month = now() - timedelta(days=30) last_month = now() - timedelta(days=30)
# Forecast for internal users # Forecast for default users
internal_in_last_month = User.objects.filter( users_in_last_month = User.objects.filter(
type=UserTypes.INTERNAL, date_joined__gte=last_month type=UserTypes.INTERNAL, date_joined__gte=last_month
).count() ).count()
# Forecast for external users # Forecast for external users
@ -144,9 +144,9 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
forecast_for_months = 12 forecast_for_months = 12
response = LicenseForecastSerializer( response = LicenseForecastSerializer(
data={ data={
"internal_users": LicenseKey.get_default_user_count(), "users": LicenseKey.get_default_user_count(),
"external_users": LicenseKey.get_external_user_count(), "external_users": LicenseKey.get_external_user_count(),
"forecasted_internal_users": (internal_in_last_month * forecast_for_months), "forecasted_users": (users_in_last_month * forecast_for_months),
"forecasted_external_users": (external_in_last_month * forecast_for_months), "forecasted_external_users": (external_in_last_month * forecast_for_months),
} }
) )

View File

@ -1,36 +0,0 @@
# Generated by Django 4.2.4 on 2023-08-23 10:06
import django.contrib.postgres.indexes
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_enterprise", "0001_initial"),
]
operations = [
migrations.RenameField(
model_name="license",
old_name="users",
new_name="internal_users",
),
migrations.AlterField(
model_name="license",
name="key",
field=models.TextField(),
),
migrations.AddIndex(
model_name="license",
index=django.contrib.postgres.indexes.HashIndex(
fields=["key"], name="authentik_e_key_523e13_hash"
),
),
migrations.AlterModelOptions(
name="licenseusage",
options={
"verbose_name": "License Usage",
"verbose_name_plural": "License Usage Records",
},
),
]

View File

@ -11,11 +11,9 @@ from uuid import uuid4
from cryptography.exceptions import InvalidSignature from cryptography.exceptions import InvalidSignature
from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate
from dacite import from_dict from dacite import from_dict
from django.contrib.postgres.indexes import HashIndex
from django.db import models from django.db import models
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from django.utils.timezone import now from django.utils.timezone import now
from django.utils.translation import gettext as _
from guardian.shortcuts import get_anonymous_user from guardian.shortcuts import get_anonymous_user
from jwt import PyJWTError, decode, get_unverified_header from jwt import PyJWTError, decode, get_unverified_header
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
@ -48,8 +46,8 @@ class LicenseKey:
exp: int exp: int
name: str name: str
internal_users: int = 0 users: int
external_users: int = 0 external_users: int
flags: list[LicenseFlags] = field(default_factory=list) flags: list[LicenseFlags] = field(default_factory=list)
@staticmethod @staticmethod
@ -89,7 +87,7 @@ class LicenseKey:
active_licenses = License.objects.filter(expiry__gte=now()) active_licenses = License.objects.filter(expiry__gte=now())
total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0) total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
for lic in active_licenses: for lic in active_licenses:
total.internal_users += lic.internal_users total.users += lic.users
total.external_users += lic.external_users total.external_users += lic.external_users
exp_ts = int(mktime(lic.expiry.timetuple())) exp_ts = int(mktime(lic.expiry.timetuple()))
if total.exp == 0: if total.exp == 0:
@ -125,7 +123,7 @@ class LicenseKey:
Only checks the current count, no historical data is checked""" Only checks the current count, no historical data is checked"""
default_users = self.get_default_user_count() default_users = self.get_default_user_count()
if default_users > self.internal_users: if default_users > self.users:
return False return False
active_users = self.get_external_user_count() active_users = self.get_external_user_count()
if active_users > self.external_users: if active_users > self.external_users:
@ -155,11 +153,11 @@ class License(models.Model):
"""An authentik enterprise license""" """An authentik enterprise license"""
license_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) license_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
key = models.TextField() key = models.TextField(unique=True)
name = models.TextField() name = models.TextField()
expiry = models.DateTimeField() expiry = models.DateTimeField()
internal_users = models.BigIntegerField() users = models.BigIntegerField()
external_users = models.BigIntegerField() external_users = models.BigIntegerField()
@property @property
@ -167,9 +165,6 @@ class License(models.Model):
"""Get parsed license status""" """Get parsed license status"""
return LicenseKey.validate(self.key) return LicenseKey.validate(self.key)
class Meta:
indexes = (HashIndex(fields=("key",)),)
def usage_expiry(): def usage_expiry():
"""Keep license usage records for 3 months""" """Keep license usage records for 3 months"""
@ -188,7 +183,3 @@ class LicenseUsage(ExpiringModel):
within_limits = models.BooleanField() within_limits = models.BooleanField()
record_date = models.DateTimeField(auto_now_add=True) record_date = models.DateTimeField(auto_now_add=True)
class Meta:
verbose_name = _("License Usage")
verbose_name_plural = _("License Usage Records")

View File

@ -13,6 +13,6 @@ def pre_save_license(sender: type[License], instance: License, **_):
"""Extract data from license jwt and save it into model""" """Extract data from license jwt and save it into model"""
status = instance.status status = instance.status
instance.name = status.name instance.name = status.name
instance.internal_users = status.internal_users instance.users = status.users
instance.external_users = status.external_users instance.external_users = status.external_users
instance.expiry = datetime.fromtimestamp(status.exp, tz=get_current_timezone()) instance.expiry = datetime.fromtimestamp(status.exp, tz=get_current_timezone())

View File

@ -23,7 +23,7 @@ class TestEnterpriseLicense(TestCase):
aud="", aud="",
exp=_exp, exp=_exp,
name=generate_id(), name=generate_id(),
internal_users=100, users=100,
external_users=100, external_users=100,
) )
), ),
@ -32,7 +32,7 @@ class TestEnterpriseLicense(TestCase):
"""Check license verification""" """Check license verification"""
lic = License.objects.create(key=generate_id()) lic = License.objects.create(key=generate_id())
self.assertTrue(lic.status.is_valid()) self.assertTrue(lic.status.is_valid())
self.assertEqual(lic.internal_users, 100) self.assertEqual(lic.users, 100)
def test_invalid(self): def test_invalid(self):
"""Test invalid license""" """Test invalid license"""
@ -46,7 +46,7 @@ class TestEnterpriseLicense(TestCase):
aud="", aud="",
exp=_exp, exp=_exp,
name=generate_id(), name=generate_id(),
internal_users=100, users=100,
external_users=100, external_users=100,
) )
), ),
@ -58,7 +58,7 @@ class TestEnterpriseLicense(TestCase):
lic2 = License.objects.create(key=generate_id()) lic2 = License.objects.create(key=generate_id())
self.assertTrue(lic2.status.is_valid()) self.assertTrue(lic2.status.is_valid())
total = LicenseKey.get_total() total = LicenseKey.get_total()
self.assertEqual(total.internal_users, 200) self.assertEqual(total.users, 200)
self.assertEqual(total.external_users, 200) self.assertEqual(total.external_users, 200)
self.assertEqual(total.exp, _exp) self.assertEqual(total.exp, _exp)
self.assertTrue(total.is_valid()) self.assertTrue(total.is_valid())

View File

@ -4,7 +4,7 @@ from json import loads
import django_filters import django_filters
from django.db.models.aggregates import Count from django.db.models.aggregates import Count
from django.db.models.fields.json import KeyTextTransform, KeyTransform from django.db.models.fields.json import KeyTextTransform
from django.db.models.functions import ExtractDay from django.db.models.functions import ExtractDay
from drf_spectacular.types import OpenApiTypes from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema from drf_spectacular.utils import OpenApiParameter, extend_schema
@ -134,11 +134,11 @@ class EventViewSet(ModelViewSet):
"""Get the top_n events grouped by user count""" """Get the top_n events grouped by user count"""
filtered_action = request.query_params.get("action", EventAction.LOGIN) filtered_action = request.query_params.get("action", EventAction.LOGIN)
top_n = int(request.query_params.get("top_n", "15")) top_n = int(request.query_params.get("top_n", "15"))
events = ( return Response(
get_objects_for_user(request.user, "authentik_events.view_event") get_objects_for_user(request.user, "authentik_events.view_event")
.filter(action=filtered_action) .filter(action=filtered_action)
.exclude(context__authorized_application=None) .exclude(context__authorized_application=None)
.annotate(application=KeyTransform("authorized_application", "context")) .annotate(application=KeyTextTransform("authorized_application", "context"))
.annotate(user_pk=KeyTextTransform("pk", "user")) .annotate(user_pk=KeyTextTransform("pk", "user"))
.values("application") .values("application")
.annotate(counted_events=Count("application")) .annotate(counted_events=Count("application"))
@ -146,7 +146,6 @@ class EventViewSet(ModelViewSet):
.values("unique_users", "application", "counted_events") .values("unique_users", "application", "counted_events")
.order_by("-counted_events")[:top_n] .order_by("-counted_events")[:top_n]
) )
return Response(EventTopPerUserSerializer(instance=events, many=True).data)
@extend_schema( @extend_schema(
methods=["GET"], methods=["GET"],

View File

@ -39,7 +39,7 @@ class NotificationTransportSerializer(ModelSerializer):
mode = attrs.get("mode") mode = attrs.get("mode")
if mode in [TransportMode.WEBHOOK, TransportMode.WEBHOOK_SLACK]: if mode in [TransportMode.WEBHOOK, TransportMode.WEBHOOK_SLACK]:
if "webhook_url" not in attrs or attrs.get("webhook_url", "") == "": if "webhook_url" not in attrs or attrs.get("webhook_url", "") == "":
raise ValidationError({"webhook_url": "Webhook URL may not be empty."}) raise ValidationError("Webhook URL may not be empty.")
return attrs return attrs
class Meta: class Meta:

View File

@ -33,7 +33,7 @@ PLAN_CONTEXT_SOURCE = "source"
# Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan # Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan
# was restored. # was restored.
PLAN_CONTEXT_IS_RESTORED = "is_restored" PLAN_CONTEXT_IS_RESTORED = "is_restored"
CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_flows") CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_flows"))
CACHE_PREFIX = "goauthentik.io/flows/planner/" CACHE_PREFIX = "goauthentik.io/flows/planner/"

View File

@ -1,10 +0,0 @@
package lib
import _ "embed"
//go:embed default.yml
var defaultConfig []byte
func DefaultConfig() []byte {
return defaultConfig
}

View File

@ -213,14 +213,6 @@ class ConfigLoader:
attr: Attr = get_path_from_dict(root, path, sep=sep, default=Attr(default)) attr: Attr = get_path_from_dict(root, path, sep=sep, default=Attr(default))
return attr.value return attr.value
def get_int(self, path: str, default=0) -> int:
"""Wrapper for get that converts value into int"""
try:
return int(self.get(path, default))
except ValueError as exc:
self.log("warning", "Failed to parse config as int", path=path, exc=str(exc))
return default
def get_bool(self, path: str, default=False) -> bool: def get_bool(self, path: str, default=False) -> bool:
"""Wrapper for get that converts value into boolean""" """Wrapper for get that converts value into boolean"""
return str(self.get(path, default)).lower() == "true" return str(self.get(path, default)).lower() == "true"

View File

@ -11,11 +11,7 @@ postgresql:
listen: listen:
listen_http: 0.0.0.0:9000 listen_http: 0.0.0.0:9000
listen_https: 0.0.0.0:9443 listen_https: 0.0.0.0:9443
listen_ldap: 0.0.0.0:3389
listen_ldaps: 0.0.0.0:6636
listen_radius: 0.0.0.0:1812
listen_metrics: 0.0.0.0:9300 listen_metrics: 0.0.0.0:9300
listen_debug: 0.0.0.0:9900
trusted_proxy_cidrs: trusted_proxy_cidrs:
- 127.0.0.0/8 - 127.0.0.0/8
- 10.0.0.0/8 - 10.0.0.0/8
@ -36,9 +32,6 @@ redis:
cache_timeout_policies: 300 cache_timeout_policies: 300
cache_timeout_reputation: 300 cache_timeout_reputation: 300
paths:
media: ./media
debug: false debug: false
remote_debug: false remote_debug: false
@ -84,9 +77,6 @@ ldap:
tls: tls:
ciphers: null ciphers: null
reputation:
expiry: 86400
cookie_domain: null cookie_domain: null
disable_update_check: false disable_update_check: false
disable_startup_analytics: false disable_startup_analytics: false

View File

@ -112,7 +112,7 @@ class BaseEvaluator:
@staticmethod @staticmethod
def expr_is_group_member(user: User, **group_filters) -> bool: def expr_is_group_member(user: User, **group_filters) -> bool:
"""Check if `user` is member of group with name `group_name`""" """Check if `user` is member of group with name `group_name`"""
return user.all_groups().filter(**group_filters).exists() return user.ak_groups.filter(**group_filters).exists()
@staticmethod @staticmethod
def expr_user_by(**filters) -> Optional[User]: def expr_user_by(**filters) -> Optional[User]:

View File

@ -98,7 +98,7 @@ def traces_sampler(sampling_context: dict) -> float:
def before_send(event: dict, hint: dict) -> Optional[dict]: def before_send(event: dict, hint: dict) -> Optional[dict]:
"""Check if error is database error, and ignore if so""" """Check if error is database error, and ignore if so"""
# pylint: disable=no-name-in-module # pylint: disable=no-name-in-module
from psycopg.errors import Error from psycopg2.errors import Error
ignored_classes = ( ignored_classes = (
# Inbuilt types # Inbuilt types

View File

@ -79,15 +79,3 @@ class TestConfig(TestCase):
config.update_from_file(file2_name) config.update_from_file(file2_name)
unlink(file_name) unlink(file_name)
unlink(file2_name) unlink(file2_name)
def test_get_int(self):
"""Test get_int"""
config = ConfigLoader()
config.set("foo", 1234)
self.assertEqual(config.get_int("foo"), 1234)
def test_get_int_invalid(self):
"""Test get_int"""
config = ConfigLoader()
config.set("foo", "bar")
self.assertEqual(config.get_int("foo", 1234), 1234)

View File

@ -19,7 +19,7 @@ from authentik.policies.types import CACHE_PREFIX, PolicyRequest, PolicyResult
LOGGER = get_logger() LOGGER = get_logger()
FORK_CTX = get_context("fork") FORK_CTX = get_context("fork")
CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_policies") CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_policies"))
PROCESS_CLASS = FORK_CTX.Process PROCESS_CLASS = FORK_CTX.Process

View File

@ -1,33 +0,0 @@
# Generated by Django 4.2.4 on 2023-08-31 10:42
from django.db import migrations, models
import authentik.policies.reputation.models
class Migration(migrations.Migration):
dependencies = [
("authentik_policies_reputation", "0004_reputationpolicy_authentik_p_policy__8f0d70_idx"),
]
operations = [
migrations.AddField(
model_name="reputation",
name="expires",
field=models.DateTimeField(
default=authentik.policies.reputation.models.reputation_expiry
),
),
migrations.AddField(
model_name="reputation",
name="expiring",
field=models.BooleanField(default=True),
),
migrations.AlterModelOptions(
name="reputation",
options={
"verbose_name": "Reputation Score",
"verbose_name_plural": "Reputation Scores",
},
),
]

View File

@ -1,17 +1,13 @@
"""authentik reputation request policy""" """authentik reputation request policy"""
from datetime import timedelta
from uuid import uuid4 from uuid import uuid4
from django.db import models from django.db import models
from django.db.models import Sum from django.db.models import Sum
from django.db.models.query_utils import Q from django.db.models.query_utils import Q
from django.utils.timezone import now
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from rest_framework.serializers import BaseSerializer from rest_framework.serializers import BaseSerializer
from structlog import get_logger from structlog import get_logger
from authentik.core.models import ExpiringModel
from authentik.lib.config import CONFIG
from authentik.lib.models import SerializerModel from authentik.lib.models import SerializerModel
from authentik.lib.utils.http import get_client_ip from authentik.lib.utils.http import get_client_ip
from authentik.policies.models import Policy from authentik.policies.models import Policy
@ -21,11 +17,6 @@ LOGGER = get_logger()
CACHE_KEY_PREFIX = "goauthentik.io/policies/reputation/scores/" CACHE_KEY_PREFIX = "goauthentik.io/policies/reputation/scores/"
def reputation_expiry():
"""Reputation expiry"""
return now() + timedelta(seconds=CONFIG.get_int("reputation.expiry"))
class ReputationPolicy(Policy): class ReputationPolicy(Policy):
"""Return true if request IP/target username's score is below a certain threshold""" """Return true if request IP/target username's score is below a certain threshold"""
@ -68,7 +59,7 @@ class ReputationPolicy(Policy):
verbose_name_plural = _("Reputation Policies") verbose_name_plural = _("Reputation Policies")
class Reputation(ExpiringModel, SerializerModel): class Reputation(SerializerModel):
"""Reputation for user and or IP.""" """Reputation for user and or IP."""
reputation_uuid = models.UUIDField(primary_key=True, unique=True, default=uuid4) reputation_uuid = models.UUIDField(primary_key=True, unique=True, default=uuid4)
@ -78,8 +69,6 @@ class Reputation(ExpiringModel, SerializerModel):
ip_geo_data = models.JSONField(default=dict) ip_geo_data = models.JSONField(default=dict)
score = models.BigIntegerField(default=0) score = models.BigIntegerField(default=0)
expires = models.DateTimeField(default=reputation_expiry)
updated = models.DateTimeField(auto_now_add=True) updated = models.DateTimeField(auto_now_add=True)
@property @property
@ -92,6 +81,4 @@ class Reputation(ExpiringModel, SerializerModel):
return f"Reputation {self.identifier}/{self.ip} @ {self.score}" return f"Reputation {self.identifier}/{self.ip} @ {self.score}"
class Meta: class Meta:
verbose_name = _("Reputation Score")
verbose_name_plural = _("Reputation Scores")
unique_together = ("identifier", "ip") unique_together = ("identifier", "ip")

View File

@ -13,7 +13,7 @@ from authentik.policies.reputation.tasks import save_reputation
from authentik.stages.identification.signals import identification_failed from authentik.stages.identification.signals import identification_failed
LOGGER = get_logger() LOGGER = get_logger()
CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_reputation") CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_reputation"))
def update_score(request: HttpRequest, identifier: str, amount: int): def update_score(request: HttpRequest, identifier: str, amount: int):

View File

@ -1,6 +1,6 @@
"""id_token utils""" """id_token utils"""
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional
from django.db import models from django.db import models
from django.http import HttpRequest from django.http import HttpRequest
@ -57,7 +57,7 @@ class IDToken:
# Subject, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.2 # Subject, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.2
sub: Optional[str] = None sub: Optional[str] = None
# Audience, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.3 # Audience, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.3
aud: Optional[Union[str, list[str]]] = None aud: Optional[str] = None
# Expiration time, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.4 # Expiration time, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.4
exp: Optional[int] = None exp: Optional[int] = None
# Issued at, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.6 # Issued at, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.6

View File

@ -2,7 +2,6 @@
import base64 import base64
import binascii import binascii
import json import json
from dataclasses import asdict
from functools import cached_property from functools import cached_property
from hashlib import sha256 from hashlib import sha256
from typing import Any, Optional from typing import Any, Optional
@ -359,7 +358,7 @@ class AccessToken(SerializerModel, ExpiringModel, BaseGrantModel):
@id_token.setter @id_token.setter
def id_token(self, value: IDToken): def id_token(self, value: IDToken):
self.token = value.to_access_token(self.provider) self.token = value.to_access_token(self.provider)
self._id_token = json.dumps(asdict(value)) self._id_token = json.dumps(value.to_dict())
@property @property
def at_hash(self): def at_hash(self):
@ -401,7 +400,7 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel):
@id_token.setter @id_token.setter
def id_token(self, value: IDToken): def id_token(self, value: IDToken):
self._id_token = json.dumps(asdict(value)) self._id_token = json.dumps(value.to_dict())
@property @property
def serializer(self) -> Serializer: def serializer(self) -> Serializer:

View File

@ -151,14 +151,6 @@ class TestTokenClientCredentials(OAuthTestCase):
) )
self.assertEqual(jwt["given_name"], self.user.name) self.assertEqual(jwt["given_name"], self.user.name)
self.assertEqual(jwt["preferred_username"], self.user.username) self.assertEqual(jwt["preferred_username"], self.user.username)
jwt = decode(
body["id_token"],
key=self.provider.signing_key.public_key,
algorithms=[alg],
audience=self.provider.client_id,
)
self.assertEqual(jwt["given_name"], self.user.name)
self.assertEqual(jwt["preferred_username"], self.user.username)
def test_successful_password(self): def test_successful_password(self):
"""test successful (password grant)""" """test successful (password grant)"""

View File

@ -375,9 +375,7 @@ class AuthorizationFlowInitView(PolicyAccessView):
): ):
self.request.session[SESSION_KEY_LAST_LOGIN_UID] = login_uid self.request.session[SESSION_KEY_LAST_LOGIN_UID] = login_uid
return self.handle_no_permission() return self.handle_no_permission()
scope_descriptions = UserInfoView().get_scope_descriptions( scope_descriptions = UserInfoView().get_scope_descriptions(self.params.scope)
self.params.scope, self.params.provider
)
# Regardless, we start the planner and return to it # Regardless, we start the planner and return to it
planner = FlowPlanner(self.provider.authorization_flow) planner = FlowPlanner(self.provider.authorization_flow)
planner.allow_empty_flows = True planner.allow_empty_flows = True

View File

@ -55,7 +55,7 @@ def validate_code(code: int, request: HttpRequest) -> Optional[HttpResponse]:
if not app: if not app:
return None return None
scope_descriptions = UserInfoView().get_scope_descriptions(token.scope, token.provider) scope_descriptions = UserInfoView().get_scope_descriptions(token.scope)
planner = FlowPlanner(token.provider.authorization_flow) planner = FlowPlanner(token.provider.authorization_flow)
planner.allow_empty_flows = True planner.allow_empty_flows = True
try: try:

View File

@ -40,14 +40,10 @@ class UserInfoView(View):
token: Optional[RefreshToken] token: Optional[RefreshToken]
def get_scope_descriptions( def get_scope_descriptions(self, scopes: list[str]) -> list[PermissionDict]:
self, scopes: list[str], provider: OAuth2Provider
) -> list[PermissionDict]:
"""Get a list of all Scopes's descriptions""" """Get a list of all Scopes's descriptions"""
scope_descriptions = [] scope_descriptions = []
for scope in ScopeMapping.objects.filter(scope_name__in=scopes, provider=provider).order_by( for scope in ScopeMapping.objects.filter(scope_name__in=scopes).order_by("scope_name"):
"scope_name"
):
scope_descriptions.append(PermissionDict(id=scope.scope_name, name=scope.description)) scope_descriptions.append(PermissionDict(id=scope.scope_name, name=scope.description))
# GitHub Compatibility Scopes are handled differently, since they required custom paths # GitHub Compatibility Scopes are handled differently, since they required custom paths
# Hence they don't exist as Scope objects # Hence they don't exist as Scope objects

View File

@ -59,9 +59,7 @@ class ProxyProviderSerializer(ProviderSerializer):
attrs.get("mode", ProxyMode.PROXY) == ProxyMode.PROXY attrs.get("mode", ProxyMode.PROXY) == ProxyMode.PROXY
and attrs.get("internal_host", "") == "" and attrs.get("internal_host", "") == ""
): ):
raise ValidationError( raise ValidationError(_("Internal host cannot be empty when forward auth is disabled."))
{"internal_host": _("Internal host cannot be empty when forward auth is disabled.")}
)
return attrs return attrs
def create(self, validated_data: dict): def create(self, validated_data: dict):

View File

@ -69,7 +69,7 @@ class ProxyProviderTests(APITestCase):
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
self.assertJSONEqual( self.assertJSONEqual(
response.content.decode(), response.content.decode(),
{"internal_host": ["Internal host cannot be empty when forward auth is disabled."]}, {"non_field_errors": ["Internal host cannot be empty when forward auth is disabled."]},
) )
def test_create_defaults(self): def test_create_defaults(self):

View File

@ -68,7 +68,7 @@ class SCIMClient(Generic[T, SchemaType]):
"""Get Service provider config""" """Get Service provider config"""
default_config = ServiceProviderConfiguration.default() default_config = ServiceProviderConfiguration.default()
try: try:
return ServiceProviderConfiguration.model_validate( return ServiceProviderConfiguration.parse_obj(
self._request("GET", "/ServiceProviderConfig") self._request("GET", "/ServiceProviderConfig")
) )
except (ValidationError, SCIMRequestException) as exc: except (ValidationError, SCIMRequestException) as exc:

View File

@ -74,7 +74,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
if not raw_scim_group: if not raw_scim_group:
raise StopSync(ValueError("No group mappings configured"), obj) raise StopSync(ValueError("No group mappings configured"), obj)
try: try:
scim_group = SCIMGroupSchema.model_validate(delete_none_values(raw_scim_group)) scim_group = SCIMGroupSchema.parse_obj(delete_none_values(raw_scim_group))
except ValidationError as exc: except ValidationError as exc:
raise StopSync(exc, obj) from exc raise StopSync(exc, obj) from exc
if not scim_group.externalId: if not scim_group.externalId:
@ -99,8 +99,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
response = self._request( response = self._request(
"POST", "POST",
"/Groups", "/Groups",
json=scim_group.model_dump( data=scim_group.json(
mode="json",
exclude_unset=True, exclude_unset=True,
), ),
) )
@ -114,8 +113,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
return self._request( return self._request(
"PUT", "PUT",
f"/Groups/{scim_group.id}", f"/Groups/{scim_group.id}",
json=scim_group.model_dump( data=scim_group.json(
mode="json",
exclude_unset=True, exclude_unset=True,
), ),
) )
@ -162,13 +160,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
*ops: PatchOperation, *ops: PatchOperation,
): ):
req = PatchRequest(Operations=ops) req = PatchRequest(Operations=ops)
self._request( self._request("PATCH", f"/Groups/{group_id}", data=req.json())
"PATCH",
f"/Groups/{group_id}",
json=req.model_dump(
mode="json",
),
)
def _patch_add_users(self, group: Group, users_set: set[int]): def _patch_add_users(self, group: Group, users_set: set[int]):
"""Add users in users_set to group""" """Add users in users_set to group"""

View File

@ -52,7 +52,7 @@ class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
class PatchRequest(BasePatchRequest): class PatchRequest(BasePatchRequest):
"""PatchRequest which correctly sets schemas""" """PatchRequest which correctly sets schemas"""
schemas: tuple[str] = ("urn:ietf:params:scim:api:messages:2.0:PatchOp",) schemas: tuple[str] = ["urn:ietf:params:scim:api:messages:2.0:PatchOp"]
class SCIMError(BaseSCIMError): class SCIMError(BaseSCIMError):

View File

@ -64,7 +64,7 @@ class SCIMUserClient(SCIMClient[User, SCIMUserSchema]):
if not raw_scim_user: if not raw_scim_user:
raise StopSync(ValueError("No user mappings configured"), obj) raise StopSync(ValueError("No user mappings configured"), obj)
try: try:
scim_user = SCIMUserSchema.model_validate(delete_none_values(raw_scim_user)) scim_user = SCIMUserSchema.parse_obj(delete_none_values(raw_scim_user))
except ValidationError as exc: except ValidationError as exc:
raise StopSync(exc, obj) from exc raise StopSync(exc, obj) from exc
if not scim_user.externalId: if not scim_user.externalId:
@ -77,8 +77,7 @@ class SCIMUserClient(SCIMClient[User, SCIMUserSchema]):
response = self._request( response = self._request(
"POST", "POST",
"/Users", "/Users",
json=scim_user.model_dump( data=scim_user.json(
mode="json",
exclude_unset=True, exclude_unset=True,
), ),
) )
@ -91,8 +90,7 @@ class SCIMUserClient(SCIMClient[User, SCIMUserSchema]):
self._request( self._request(
"PUT", "PUT",
f"/Users/{connection.id}", f"/Users/{connection.id}",
json=scim_user.model_dump( data=scim_user.json(
mode="json",
exclude_unset=True, exclude_unset=True,
), ),
) )

View File

@ -47,7 +47,6 @@ class SCIMMembershipTests(TestCase):
def test_member_add(self): def test_member_add(self):
"""Test member add""" """Test member add"""
config = ServiceProviderConfiguration.default() config = ServiceProviderConfiguration.default()
# pylint: disable=assigning-non-slot
config.patch.supported = True config.patch.supported = True
user_scim_id = generate_id() user_scim_id = generate_id()
group_scim_id = generate_id() group_scim_id = generate_id()
@ -61,7 +60,7 @@ class SCIMMembershipTests(TestCase):
with Mocker() as mocker: with Mocker() as mocker:
mocker.get( mocker.get(
"https://localhost/ServiceProviderConfig", "https://localhost/ServiceProviderConfig",
json=config.model_dump(), json=config.dict(),
) )
mocker.post( mocker.post(
"https://localhost/Users", "https://localhost/Users",
@ -105,7 +104,7 @@ class SCIMMembershipTests(TestCase):
with Mocker() as mocker: with Mocker() as mocker:
mocker.get( mocker.get(
"https://localhost/ServiceProviderConfig", "https://localhost/ServiceProviderConfig",
json=config.model_dump(), json=config.dict(),
) )
mocker.patch( mocker.patch(
f"https://localhost/Groups/{group_scim_id}", f"https://localhost/Groups/{group_scim_id}",
@ -132,7 +131,6 @@ class SCIMMembershipTests(TestCase):
def test_member_remove(self): def test_member_remove(self):
"""Test member remove""" """Test member remove"""
config = ServiceProviderConfiguration.default() config = ServiceProviderConfiguration.default()
# pylint: disable=assigning-non-slot
config.patch.supported = True config.patch.supported = True
user_scim_id = generate_id() user_scim_id = generate_id()
group_scim_id = generate_id() group_scim_id = generate_id()
@ -146,7 +144,7 @@ class SCIMMembershipTests(TestCase):
with Mocker() as mocker: with Mocker() as mocker:
mocker.get( mocker.get(
"https://localhost/ServiceProviderConfig", "https://localhost/ServiceProviderConfig",
json=config.model_dump(), json=config.dict(),
) )
mocker.post( mocker.post(
"https://localhost/Users", "https://localhost/Users",
@ -190,7 +188,7 @@ class SCIMMembershipTests(TestCase):
with Mocker() as mocker: with Mocker() as mocker:
mocker.get( mocker.get(
"https://localhost/ServiceProviderConfig", "https://localhost/ServiceProviderConfig",
json=config.model_dump(), json=config.dict(),
) )
mocker.patch( mocker.patch(
f"https://localhost/Groups/{group_scim_id}", f"https://localhost/Groups/{group_scim_id}",
@ -217,7 +215,7 @@ class SCIMMembershipTests(TestCase):
with Mocker() as mocker: with Mocker() as mocker:
mocker.get( mocker.get(
"https://localhost/ServiceProviderConfig", "https://localhost/ServiceProviderConfig",
json=config.model_dump(), json=config.dict(),
) )
mocker.patch( mocker.patch(
f"https://localhost/Groups/{group_scim_id}", f"https://localhost/Groups/{group_scim_id}",

View File

@ -44,11 +44,7 @@ def config_loggers(*args, **kwargs):
def after_task_publish_hook(sender=None, headers=None, body=None, **kwargs): def after_task_publish_hook(sender=None, headers=None, body=None, **kwargs):
"""Log task_id after it was published""" """Log task_id after it was published"""
info = headers if "task" in headers else body info = headers if "task" in headers else body
LOGGER.info( LOGGER.info("Task published", task_id=info.get("id", ""), task_name=info.get("task", ""))
"Task published",
task_id=info.get("id", "").replace("-", ""),
task_name=info.get("task", ""),
)
@task_prerun.connect @task_prerun.connect
@ -63,9 +59,7 @@ def task_prerun_hook(task_id: str, task, *args, **kwargs):
def task_postrun_hook(task_id, task, *args, retval=None, state=None, **kwargs): def task_postrun_hook(task_id, task, *args, retval=None, state=None, **kwargs):
"""Log task_id on worker""" """Log task_id on worker"""
CTX_TASK_ID.set(...) CTX_TASK_ID.set(...)
LOGGER.info( LOGGER.info("Task finished", task_id=task_id, task_name=task.__name__, state=state)
"Task finished", task_id=task_id.replace("-", ""), task_name=task.__name__, state=state
)
@task_failure.connect @task_failure.connect

View File

@ -2,7 +2,7 @@
from functools import lru_cache from functools import lru_cache
from uuid import uuid4 from uuid import uuid4
from psycopg import connect from psycopg2 import connect
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
@ -30,7 +30,7 @@ def get_install_id_raw():
user=CONFIG.get("postgresql.user"), user=CONFIG.get("postgresql.user"),
password=CONFIG.get("postgresql.password"), password=CONFIG.get("postgresql.password"),
host=CONFIG.get("postgresql.host"), host=CONFIG.get("postgresql.host"),
port=CONFIG.get_int("postgresql.port"), port=int(CONFIG.get("postgresql.port")),
sslmode=CONFIG.get("postgresql.sslmode"), sslmode=CONFIG.get("postgresql.sslmode"),
sslrootcert=CONFIG.get("postgresql.sslrootcert"), sslrootcert=CONFIG.get("postgresql.sslrootcert"),
sslcert=CONFIG.get("postgresql.sslcert"), sslcert=CONFIG.get("postgresql.sslcert"),

View File

@ -190,14 +190,14 @@ if CONFIG.get_bool("redis.tls", False):
_redis_url = ( _redis_url = (
f"{_redis_protocol_prefix}:" f"{_redis_protocol_prefix}:"
f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:" f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:"
f"{CONFIG.get_int('redis.port')}" f"{int(CONFIG.get('redis.port'))}"
) )
CACHES = { CACHES = {
"default": { "default": {
"BACKEND": "django_redis.cache.RedisCache", "BACKEND": "django_redis.cache.RedisCache",
"LOCATION": f"{_redis_url}/{CONFIG.get('redis.db')}", "LOCATION": f"{_redis_url}/{CONFIG.get('redis.db')}",
"TIMEOUT": CONFIG.get_int("redis.cache_timeout", 300), "TIMEOUT": int(CONFIG.get("redis.cache_timeout", 300)),
"OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"}, "OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"},
"KEY_PREFIX": "authentik_cache", "KEY_PREFIX": "authentik_cache",
} }
@ -274,7 +274,7 @@ DATABASES = {
"NAME": CONFIG.get("postgresql.name"), "NAME": CONFIG.get("postgresql.name"),
"USER": CONFIG.get("postgresql.user"), "USER": CONFIG.get("postgresql.user"),
"PASSWORD": CONFIG.get("postgresql.password"), "PASSWORD": CONFIG.get("postgresql.password"),
"PORT": CONFIG.get_int("postgresql.port"), "PORT": int(CONFIG.get("postgresql.port")),
"SSLMODE": CONFIG.get("postgresql.sslmode"), "SSLMODE": CONFIG.get("postgresql.sslmode"),
"SSLROOTCERT": CONFIG.get("postgresql.sslrootcert"), "SSLROOTCERT": CONFIG.get("postgresql.sslrootcert"),
"SSLCERT": CONFIG.get("postgresql.sslcert"), "SSLCERT": CONFIG.get("postgresql.sslcert"),
@ -293,12 +293,12 @@ if CONFIG.get_bool("postgresql.use_pgbouncer", False):
# loads the config directly from CONFIG # loads the config directly from CONFIG
# See authentik/stages/email/models.py, line 105 # See authentik/stages/email/models.py, line 105
EMAIL_HOST = CONFIG.get("email.host") EMAIL_HOST = CONFIG.get("email.host")
EMAIL_PORT = CONFIG.get_int("email.port") EMAIL_PORT = int(CONFIG.get("email.port"))
EMAIL_HOST_USER = CONFIG.get("email.username") EMAIL_HOST_USER = CONFIG.get("email.username")
EMAIL_HOST_PASSWORD = CONFIG.get("email.password") EMAIL_HOST_PASSWORD = CONFIG.get("email.password")
EMAIL_USE_TLS = CONFIG.get_bool("email.use_tls", False) EMAIL_USE_TLS = CONFIG.get_bool("email.use_tls", False)
EMAIL_USE_SSL = CONFIG.get_bool("email.use_ssl", False) EMAIL_USE_SSL = CONFIG.get_bool("email.use_ssl", False)
EMAIL_TIMEOUT = CONFIG.get_int("email.timeout") EMAIL_TIMEOUT = int(CONFIG.get("email.timeout"))
DEFAULT_FROM_EMAIL = CONFIG.get("email.from") DEFAULT_FROM_EMAIL = CONFIG.get("email.from")
SERVER_EMAIL = DEFAULT_FROM_EMAIL SERVER_EMAIL = DEFAULT_FROM_EMAIL
EMAIL_SUBJECT_PREFIX = "[authentik] " EMAIL_SUBJECT_PREFIX = "[authentik] "
@ -402,7 +402,6 @@ LOG_PRE_CHAIN = [
structlog.stdlib.add_logger_name, structlog.stdlib.add_logger_name,
structlog.processors.TimeStamper(), structlog.processors.TimeStamper(),
structlog.processors.StackInfoRenderer(), structlog.processors.StackInfoRenderer(),
structlog.processors.format_exc_info,
] ]
LOGGING = { LOGGING = {

View File

@ -44,11 +44,7 @@ class LDAPSourceSerializer(SourceSerializer):
sources = sources.exclude(pk=self.instance.pk) sources = sources.exclude(pk=self.instance.pk)
if sources.exists(): if sources.exists():
raise ValidationError( raise ValidationError(
{ "Only a single LDAP Source with password synchronization is allowed"
"sync_users_password": (
"Only a single LDAP Source with password synchronization is allowed"
)
}
) )
return super().validate(attrs) return super().validate(attrs)

View File

@ -93,7 +93,7 @@ class BaseLDAPSynchronizer:
types_only=False, types_only=False,
get_operational_attributes=False, get_operational_attributes=False,
controls=None, controls=None,
paged_size=CONFIG.get_int("ldap.page_size", 50), paged_size=int(CONFIG.get("ldap.page_size", 50)),
paged_criticality=False, paged_criticality=False,
): ):
"""Search in pages, returns each page""" """Search in pages, returns each page"""

View File

@ -59,7 +59,7 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) ->
signatures = [] signatures = []
for page in sync_inst.get_objects(): for page in sync_inst.get_objects():
page_cache_key = CACHE_KEY_PREFIX + str(uuid4()) page_cache_key = CACHE_KEY_PREFIX + str(uuid4())
cache.set(page_cache_key, page, 60 * 60 * CONFIG.get_int("ldap.task_timeout_hours")) cache.set(page_cache_key, page, 60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")))
page_sync = ldap_sync.si(source.pk, class_to_path(sync), page_cache_key) page_sync = ldap_sync.si(source.pk, class_to_path(sync), page_cache_key)
signatures.append(page_sync) signatures.append(page_sync)
return signatures return signatures
@ -68,12 +68,12 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) ->
@CELERY_APP.task( @CELERY_APP.task(
bind=True, bind=True,
base=MonitoredTask, base=MonitoredTask,
soft_time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"), soft_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")),
task_time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"), task_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")),
) )
def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str): def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str):
"""Synchronization of an LDAP Source""" """Synchronization of an LDAP Source"""
self.result_timeout_hours = CONFIG.get_int("ldap.task_timeout_hours") self.result_timeout_hours = int(CONFIG.get("ldap.task_timeout_hours"))
source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first() source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first()
if not source: if not source:
# Because the source couldn't be found, we don't have a UID # Because the source couldn't be found, we don't have a UID

View File

@ -62,8 +62,7 @@ class OAuthSourceSerializer(SourceSerializer):
well_known_config = session.get(well_known) well_known_config = session.get(well_known)
well_known_config.raise_for_status() well_known_config.raise_for_status()
except RequestException as exc: except RequestException as exc:
text = exc.response.text if exc.response else str(exc) raise ValidationError(exc.response.text)
raise ValidationError({"oidc_well_known_url": text})
config = well_known_config.json() config = well_known_config.json()
try: try:
attrs["authorization_url"] = config["authorization_endpoint"] attrs["authorization_url"] = config["authorization_endpoint"]
@ -71,9 +70,7 @@ class OAuthSourceSerializer(SourceSerializer):
attrs["profile_url"] = config["userinfo_endpoint"] attrs["profile_url"] = config["userinfo_endpoint"]
attrs["oidc_jwks_url"] = config["jwks_uri"] attrs["oidc_jwks_url"] = config["jwks_uri"]
except (IndexError, KeyError) as exc: except (IndexError, KeyError) as exc:
raise ValidationError( raise ValidationError(f"Invalid well-known configuration: {exc}")
{"oidc_well_known_url": f"Invalid well-known configuration: {exc}"}
)
jwks_url = attrs.get("oidc_jwks_url") jwks_url = attrs.get("oidc_jwks_url")
if jwks_url and jwks_url != "": if jwks_url and jwks_url != "":
@ -81,8 +78,7 @@ class OAuthSourceSerializer(SourceSerializer):
jwks_config = session.get(jwks_url) jwks_config = session.get(jwks_url)
jwks_config.raise_for_status() jwks_config.raise_for_status()
except RequestException as exc: except RequestException as exc:
text = exc.response.text if exc.response else str(exc) raise ValidationError(exc.response.text)
raise ValidationError({"jwks_url": text})
config = jwks_config.json() config = jwks_config.json()
attrs["oidc_jwks"] = config attrs["oidc_jwks"] = config

View File

@ -30,7 +30,7 @@ def check_plex_token(self: MonitoredTask, source_slug: int):
self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, ["Plex token is valid."])) self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, ["Plex token is valid."]))
except RequestException as exc: except RequestException as exc:
error = exception_to_string(exc) error = exception_to_string(exc)
if len(source.plex_token) > 0: if len(source.plex_token) < 1:
error = error.replace(source.plex_token, "$PLEX_TOKEN") error = error.replace(source.plex_token, "$PLEX_TOKEN")
self.set_status( self.set_status(
TaskResult( TaskResult(

View File

@ -18,12 +18,7 @@ class AuthenticatorStaticStageSerializer(StageSerializer):
class Meta: class Meta:
model = AuthenticatorStaticStage model = AuthenticatorStaticStage
fields = StageSerializer.Meta.fields + [ fields = StageSerializer.Meta.fields + ["configure_flow", "friendly_name", "token_count"]
"configure_flow",
"friendly_name",
"token_count",
"token_length",
]
class AuthenticatorStaticStageViewSet(UsedByMixin, ModelViewSet): class AuthenticatorStaticStageViewSet(UsedByMixin, ModelViewSet):

View File

@ -1,22 +0,0 @@
# Generated by Django 4.2.4 on 2023-08-17 17:34
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_stages_authenticator_static", "0006_authenticatorstaticstage_friendly_name"),
]
operations = [
migrations.AddField(
model_name="authenticatorstaticstage",
name="token_length",
field=models.PositiveIntegerField(default=12),
),
migrations.AlterField(
model_name="authenticatorstaticstage",
name="token_count",
field=models.PositiveIntegerField(default=6),
),
]

View File

@ -13,8 +13,7 @@ from authentik.flows.models import ConfigurableStage, FriendlyNamedStage, Stage
class AuthenticatorStaticStage(ConfigurableStage, FriendlyNamedStage, Stage): class AuthenticatorStaticStage(ConfigurableStage, FriendlyNamedStage, Stage):
"""Generate static tokens for the user as a backup.""" """Generate static tokens for the user as a backup."""
token_count = models.PositiveIntegerField(default=6) token_count = models.IntegerField(default=6)
token_length = models.PositiveIntegerField(default=12)
@property @property
def serializer(self) -> type[BaseSerializer]: def serializer(self) -> type[BaseSerializer]:

View File

@ -5,7 +5,6 @@ from rest_framework.fields import CharField, ListField
from authentik.flows.challenge import ChallengeResponse, ChallengeTypes, WithUserInfoChallenge from authentik.flows.challenge import ChallengeResponse, ChallengeTypes, WithUserInfoChallenge
from authentik.flows.stage import ChallengeStageView from authentik.flows.stage import ChallengeStageView
from authentik.lib.generators import generate_id
from authentik.stages.authenticator_static.models import AuthenticatorStaticStage from authentik.stages.authenticator_static.models import AuthenticatorStaticStage
SESSION_STATIC_DEVICE = "static_device" SESSION_STATIC_DEVICE = "static_device"
@ -51,9 +50,7 @@ class AuthenticatorStaticStageView(ChallengeStageView):
device = StaticDevice(user=user, confirmed=False, name="Static Token") device = StaticDevice(user=user, confirmed=False, name="Static Token")
tokens = [] tokens = []
for _ in range(0, stage.token_count): for _ in range(0, stage.token_count):
tokens.append( tokens.append(StaticToken(device=device, token=StaticToken.random_token()))
StaticToken(device=device, token=generate_id(length=stage.token_length))
)
self.request.session[SESSION_STATIC_DEVICE] = device self.request.session[SESSION_STATIC_DEVICE] = device
self.request.session[SESSION_STATIC_TOKENS] = tokens self.request.session[SESSION_STATIC_TOKENS] = tokens
return super().get(request, *args, **kwargs) return super().get(request, *args, **kwargs)

View File

@ -1,4 +1,5 @@
"""Validation stage challenge checking""" """Validation stage challenge checking"""
from json import dumps, loads
from typing import Optional from typing import Optional
from urllib.parse import urlencode from urllib.parse import urlencode
@ -16,6 +17,7 @@ from webauthn.authentication.generate_authentication_options import generate_aut
from webauthn.authentication.verify_authentication_response import verify_authentication_response from webauthn.authentication.verify_authentication_response import verify_authentication_response
from webauthn.helpers.base64url_to_bytes import base64url_to_bytes from webauthn.helpers.base64url_to_bytes import base64url_to_bytes
from webauthn.helpers.exceptions import InvalidAuthenticationResponse from webauthn.helpers.exceptions import InvalidAuthenticationResponse
from webauthn.helpers.options_to_json import options_to_json
from webauthn.helpers.structs import AuthenticationCredential from webauthn.helpers.structs import AuthenticationCredential
from authentik.core.api.utils import PassiveSerializer from authentik.core.api.utils import PassiveSerializer
@ -66,12 +68,7 @@ def get_webauthn_challenge_without_user(
) )
request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = authentication_options.challenge request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = authentication_options.challenge
return authentication_options.model_dump( return loads(options_to_json(authentication_options))
mode="json",
by_alias=True,
exclude_unset=False,
exclude_none=True,
)
def get_webauthn_challenge( def get_webauthn_challenge(
@ -96,12 +93,7 @@ def get_webauthn_challenge(
request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = authentication_options.challenge request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = authentication_options.challenge
return authentication_options.model_dump( return loads(options_to_json(authentication_options))
mode="json",
by_alias=True,
exclude_unset=False,
exclude_none=True,
)
def select_challenge(request: HttpRequest, device: Device): def select_challenge(request: HttpRequest, device: Device):
@ -152,7 +144,7 @@ def validate_challenge_webauthn(data: dict, stage_view: StageView, user: User) -
try: try:
authentication_verification = verify_authentication_response( authentication_verification = verify_authentication_response(
credential=AuthenticationCredential.model_validate(data), credential=AuthenticationCredential.parse_raw(dumps(data)),
expected_challenge=challenge, expected_challenge=challenge,
expected_rp_id=get_rp_id(request), expected_rp_id=get_rp_id(request),
expected_origin=get_origin(request), expected_origin=get_origin(request),

View File

@ -234,12 +234,14 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
"assertionClientExtensions": "{}", "assertionClientExtensions": "{}",
"response": { "response": {
"clientDataJSON": ( "clientDataJSON": (
"eyJ0eXBlIjoid2ViYXV0aG4uZ2V0IiwiY2hhbGxlbmdlIjoiZzk4STUxbVF2WlhvN" (
"Wx4TGZockQyemZvbGhaYkxSeUNncWtrWWFwMWp3U2FKMTNCZ3VvSldDRjlfTGczQW" "eyJ0eXBlIjoid2ViYXV0aG4uZ2V0IiwiY2hhbGxlbmdlIjoiZzk4STUxbVF2WlhvN"
"dPNFdoLUJxYTU1NkpFMjBvS3NZYmw2UkEiLCJvcmlnaW4iOiJodHRwOi8vbG9jYWx" "Wx4TGZockQyemZvbGhaYkxSeUNncWtrWWFwMWp3U2FKMTNCZ3VvSldDRjlfTGczQW"
"ob3N0OjkwMDAiLCJjcm9zc09yaWdpbiI6ZmFsc2UsIm90aGVyX2tleXNfY2FuX2Jl" "dPNFdoLUJxYTU1NkpFMjBvS3NZYmw2UkEiLCJvcmlnaW4iOiJodHRwOi8vbG9jYWx"
"X2FkZGVkX2hlcmUiOiJkbyBub3QgY29tcGFyZSBjbGllbnREYXRhSlNPTiBhZ2Fpb" "ob3N0OjkwMDAiLCJjcm9zc09yaWdpbiI6ZmFsc2UsIm90aGVyX2tleXNfY2FuX2Jl"
"nN0IGEgdGVtcGxhdGUuIFNlZSBodHRwczovL2dvby5nbC95YWJQZXgifQ==" "X2FkZGVkX2hlcmUiOiJkbyBub3QgY29tcGFyZSBjbGllbnREYXRhSlNPTiBhZ2Fpb"
"nN0IGEgdGVtcGxhdGUuIFNlZSBodHRwczovL2dvby5nbC95YWJQZXgifQ=="
),
), ),
"signature": ( "signature": (
"MEQCIFNlrHf9ablJAalXLWkrqvHB8oIu8kwvRpH3X3rbJVpI" "MEQCIFNlrHf9ablJAalXLWkrqvHB8oIu8kwvRpH3X3rbJVpI"
@ -304,12 +306,14 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
"assertionClientExtensions": "{}", "assertionClientExtensions": "{}",
"response": { "response": {
"clientDataJSON": ( "clientDataJSON": (
"eyJ0eXBlIjoid2ViYXV0aG4uZ2V0IiwiY2hhbGxlbmdlIjoiZzk4STUxbVF2Wlhv" (
"NWx4TGZockQyemZvbGhaYkxSeUNncWtrWWFwMWp3U2FKMTNCZ3VvSldDRjlfTGcz" "eyJ0eXBlIjoid2ViYXV0aG4uZ2V0IiwiY2hhbGxlbmdlIjoiZzk4STUxbVF2Wlhv"
"QWdPNFdoLUJxYTU1NkpFMjBvS3NZYmw2UkEiLCJvcmlnaW4iOiJodHRwOi8vbG9j" "NWx4TGZockQyemZvbGhaYkxSeUNncWtrWWFwMWp3U2FKMTNCZ3VvSldDRjlfTGcz"
"YWxob3N0OjkwMDAiLCJjcm9zc09yaWdpbiI6ZmFsc2UsIm90aGVyX2tleXNfY2Fu" "QWdPNFdoLUJxYTU1NkpFMjBvS3NZYmw2UkEiLCJvcmlnaW4iOiJodHRwOi8vbG9j"
"X2JlX2FkZGVkX2hlcmUiOiJkbyBub3QgY29tcGFyZSBjbGllbnREYXRhSlNPTiBh" "YWxob3N0OjkwMDAiLCJjcm9zc09yaWdpbiI6ZmFsc2UsIm90aGVyX2tleXNfY2Fu"
"Z2FpbnN0IGEgdGVtcGxhdGUuIFNlZSBodHRwczovL2dvby5nbC95YWJQZXgifQ==" "X2JlX2FkZGVkX2hlcmUiOiJkbyBub3QgY29tcGFyZSBjbGllbnREYXRhSlNPTiBh"
"Z2FpbnN0IGEgdGVtcGxhdGUuIFNlZSBodHRwczovL2dvby5nbC95YWJQZXgifQ=="
),
), ),
"signature": ( "signature": (
"MEQCIFNlrHf9ablJAalXLWkrqvHB8oIu8kwvRpH3X3rbJVpI" "MEQCIFNlrHf9ablJAalXLWkrqvHB8oIu8kwvRpH3X3rbJVpI"

View File

@ -1,10 +1,13 @@
"""WebAuthn stage""" """WebAuthn stage"""
from json import dumps, loads
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django.http.request import QueryDict from django.http.request import QueryDict
from rest_framework.fields import CharField, JSONField from rest_framework.fields import CharField, JSONField
from rest_framework.serializers import ValidationError from rest_framework.serializers import ValidationError
from webauthn.helpers.bytes_to_base64url import bytes_to_base64url from webauthn.helpers.bytes_to_base64url import bytes_to_base64url
from webauthn.helpers.exceptions import InvalidRegistrationResponse from webauthn.helpers.exceptions import InvalidRegistrationResponse
from webauthn.helpers.options_to_json import options_to_json
from webauthn.helpers.structs import ( from webauthn.helpers.structs import (
AuthenticatorSelectionCriteria, AuthenticatorSelectionCriteria,
PublicKeyCredentialCreationOptions, PublicKeyCredentialCreationOptions,
@ -52,7 +55,7 @@ class AuthenticatorWebAuthnChallengeResponse(ChallengeResponse):
try: try:
registration: VerifiedRegistration = verify_registration_response( registration: VerifiedRegistration = verify_registration_response(
credential=RegistrationCredential.model_validate(response), credential=RegistrationCredential.parse_raw(dumps(response)),
expected_challenge=challenge, expected_challenge=challenge,
expected_rp_id=get_rp_id(self.request), expected_rp_id=get_rp_id(self.request),
expected_origin=get_origin(self.request), expected_origin=get_origin(self.request),
@ -105,12 +108,7 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView):
return AuthenticatorWebAuthnChallenge( return AuthenticatorWebAuthnChallenge(
data={ data={
"type": ChallengeTypes.NATIVE.value, "type": ChallengeTypes.NATIVE.value,
"registration": registration_options.model_dump( "registration": loads(options_to_json(registration_options)),
mode="json",
by_alias=True,
exclude_unset=False,
exclude_none=True,
),
} }
) )

View File

@ -108,12 +108,12 @@ class EmailStage(Stage):
CONFIG.refresh("email.password") CONFIG.refresh("email.password")
return self.backend_class( return self.backend_class(
host=CONFIG.get("email.host"), host=CONFIG.get("email.host"),
port=CONFIG.get_int("email.port"), port=int(CONFIG.get("email.port")),
username=CONFIG.get("email.username"), username=CONFIG.get("email.username"),
password=CONFIG.get("email.password"), password=CONFIG.get("email.password"),
use_tls=CONFIG.get_bool("email.use_tls", False), use_tls=CONFIG.get_bool("email.use_tls", False),
use_ssl=CONFIG.get_bool("email.use_ssl", False), use_ssl=CONFIG.get_bool("email.use_ssl", False),
timeout=CONFIG.get_int("email.timeout"), timeout=int(CONFIG.get("email.timeout")),
) )
return self.backend_class( return self.backend_class(
host=self.host, host=self.host,

View File

@ -12,7 +12,7 @@ from rest_framework.fields import CharField
from rest_framework.serializers import ValidationError from rest_framework.serializers import ValidationError
from authentik.flows.challenge import Challenge, ChallengeResponse, ChallengeTypes from authentik.flows.challenge import Challenge, ChallengeResponse, ChallengeTypes
from authentik.flows.models import FlowDesignation, FlowToken from authentik.flows.models import FlowToken
from authentik.flows.planner import PLAN_CONTEXT_IS_RESTORED, PLAN_CONTEXT_PENDING_USER from authentik.flows.planner import PLAN_CONTEXT_IS_RESTORED, PLAN_CONTEXT_PENDING_USER
from authentik.flows.stage import ChallengeStageView from authentik.flows.stage import ChallengeStageView
from authentik.flows.views.executor import QS_KEY_TOKEN from authentik.flows.views.executor import QS_KEY_TOKEN
@ -82,11 +82,6 @@ class EmailStageView(ChallengeStageView):
"""Helper function that sends the actual email. Implies that you've """Helper function that sends the actual email. Implies that you've
already checked that there is a pending user.""" already checked that there is a pending user."""
pending_user = self.get_pending_user() pending_user = self.get_pending_user()
if not pending_user.pk and self.executor.flow.designation == FlowDesignation.RECOVERY:
# Pending user does not have a primary key, and we're in a recovery flow,
# which means the user entered an invalid identifier, so we pretend to send the
# email, to not disclose if the user exists
return
email = self.executor.plan.context.get(PLAN_CONTEXT_EMAIL_OVERRIDE, None) email = self.executor.plan.context.get(PLAN_CONTEXT_EMAIL_OVERRIDE, None)
if not email: if not email:
email = pending_user.email email = pending_user.email

View File

@ -5,20 +5,18 @@ from unittest.mock import MagicMock, PropertyMock, patch
from django.core import mail from django.core import mail
from django.core.mail.backends.locmem import EmailBackend from django.core.mail.backends.locmem import EmailBackend
from django.urls import reverse from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.core.models import User
from authentik.core.tests.utils import create_test_admin_user, create_test_flow from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction
from authentik.flows.markers import StageMarker from authentik.flows.markers import StageMarker
from authentik.flows.models import FlowDesignation, FlowStageBinding from authentik.flows.models import FlowDesignation, FlowStageBinding
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
from authentik.flows.tests import FlowTestCase
from authentik.flows.views.executor import SESSION_KEY_PLAN from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.lib.generators import generate_id
from authentik.stages.email.models import EmailStage from authentik.stages.email.models import EmailStage
class TestEmailStageSending(FlowTestCase): class TestEmailStageSending(APITestCase):
"""Email tests""" """Email tests"""
def setUp(self): def setUp(self):
@ -46,13 +44,6 @@ class TestEmailStageSending(FlowTestCase):
): ):
response = self.client.post(url) response = self.client.post(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertStageResponse(
response,
self.flow,
response_errors={
"non_field_errors": [{"string": "email-sent", "code": "email-sent"}]
},
)
self.assertEqual(len(mail.outbox), 1) self.assertEqual(len(mail.outbox), 1)
self.assertEqual(mail.outbox[0].subject, "authentik") self.assertEqual(mail.outbox[0].subject, "authentik")
events = Event.objects.filter(action=EventAction.EMAIL_SENT) events = Event.objects.filter(action=EventAction.EMAIL_SENT)
@ -63,32 +54,6 @@ class TestEmailStageSending(FlowTestCase):
self.assertEqual(event.context["to_email"], [self.user.email]) self.assertEqual(event.context["to_email"], [self.user.email])
self.assertEqual(event.context["from_email"], "system@authentik.local") self.assertEqual(event.context["from_email"], "system@authentik.local")
def test_pending_fake_user(self):
"""Test with pending (fake) user"""
self.flow.designation = FlowDesignation.RECOVERY
self.flow.save()
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
plan.context[PLAN_CONTEXT_PENDING_USER] = User(username=generate_id())
session = self.client.session
session[SESSION_KEY_PLAN] = plan
session.save()
url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
with patch(
"authentik.stages.email.models.EmailStage.backend_class",
PropertyMock(return_value=EmailBackend),
):
response = self.client.post(url)
self.assertEqual(response.status_code, 200)
self.assertStageResponse(
response,
self.flow,
response_errors={
"non_field_errors": [{"string": "email-sent", "code": "email-sent"}]
},
)
self.assertEqual(len(mail.outbox), 0)
def test_send_error(self): def test_send_error(self):
"""Test error during sending (sending will be retried)""" """Test error during sending (sending will be retried)"""
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])

View File

@ -118,12 +118,8 @@ class IdentificationChallengeResponse(ChallengeResponse):
username=uid_field, username=uid_field,
email=uid_field, email=uid_field,
) )
self.pre_user = self.stage.executor.plan.context[PLAN_CONTEXT_PENDING_USER]
if not current_stage.show_matched_user: if not current_stage.show_matched_user:
self.stage.executor.plan.context[PLAN_CONTEXT_PENDING_USER_IDENTIFIER] = uid_field self.stage.executor.plan.context[PLAN_CONTEXT_PENDING_USER_IDENTIFIER] = uid_field
if self.stage.executor.flow.designation == FlowDesignation.RECOVERY:
# When used in a recovery flow, always continue to not disclose if a user exists
return attrs
raise ValidationError("Failed to authenticate.") raise ValidationError("Failed to authenticate.")
self.pre_user = pre_user self.pre_user = pre_user
if not current_stage.password_stage: if not current_stage.password_stage:

View File

@ -188,7 +188,7 @@ class TestIdentificationStage(FlowTestCase):
], ],
) )
def test_link_recovery_flow(self): def test_recovery_flow(self):
"""Test that recovery flow is linked correctly""" """Test that recovery flow is linked correctly"""
flow = create_test_flow() flow = create_test_flow()
self.stage.recovery_flow = flow self.stage.recovery_flow = flow
@ -226,38 +226,6 @@ class TestIdentificationStage(FlowTestCase):
], ],
) )
def test_recovery_flow_invalid_user(self):
"""Test that an invalid user can proceed in a recovery flow"""
self.flow.designation = FlowDesignation.RECOVERY
self.flow.save()
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
)
self.assertStageResponse(
response,
self.flow,
component="ak-stage-identification",
user_fields=["email"],
password_fields=False,
show_source_labels=False,
primary_action="Continue",
sources=[
{
"challenge": {
"component": "xak-flow-redirect",
"to": "/source/oauth/login/test/",
"type": ChallengeTypes.REDIRECT.value,
},
"icon_url": "/static/authentik/sources/default.svg",
"name": "test",
}
],
)
form_data = {"uid_field": generate_id()}
url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
response = self.client.post(url, form_data)
self.assertEqual(response.status_code, 200)
def test_api_validate(self): def test_api_validate(self):
"""Test API validation""" """Test API validation"""
self.assertTrue( self.assertTrue(

View File

@ -99,7 +99,6 @@ class TestUserLoginStage(FlowTestCase):
session[SESSION_KEY_PLAN] = plan session[SESSION_KEY_PLAN] = plan
session.save() session.save()
before_request = now()
response = self.client.get( response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}) reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
) )
@ -109,7 +108,7 @@ class TestUserLoginStage(FlowTestCase):
session_key = self.client.session.session_key session_key = self.client.session.session_key
session = AuthenticatedSession.objects.filter(session_key=session_key).first() session = AuthenticatedSession.objects.filter(session_key=session_key).first()
self.assertAlmostEqual( self.assertAlmostEqual(
session.expires.timestamp() - before_request.timestamp(), session.expires.timestamp() - now().timestamp(),
timedelta_from_string(self.stage.session_duration).total_seconds(), timedelta_from_string(self.stage.session_duration).total_seconds(),
delta=1, delta=1,
) )

View File

@ -36,7 +36,7 @@ class TenantSerializer(ModelSerializer):
if self.instance: if self.instance:
tenants = tenants.exclude(pk=self.instance.pk) tenants = tenants.exclude(pk=self.instance.pk)
if tenants.exists(): if tenants.exists():
raise ValidationError({"default": "Only a single Tenant can be set as default."}) raise ValidationError("Only a single Tenant can be set as default.")
return super().validate(attrs) return super().validate(attrs)
class Meta: class Meta:

View File

@ -2,12 +2,6 @@ version: 1
metadata: metadata:
name: Default - Events Transport & Rules name: Default - Events Transport & Rules
entries: entries:
# Run bootstrap blueprint first to ensure we have the group created
- model: authentik_blueprints.metaapplyblueprint
attrs:
identifiers:
path: system/bootstrap.yaml
required: false
- model: authentik_events.notificationtransport - model: authentik_events.notificationtransport
id: default-email-transport id: default-email-transport
attrs: attrs:
@ -22,7 +16,6 @@ entries:
name: default-local-transport name: default-local-transport
- model: authentik_core.group - model: authentik_core.group
id: group id: group
state: created
identifiers: identifiers:
name: authentik Admins name: authentik Admins

View File

@ -51,9 +51,6 @@ entries:
order: 20 order: 20
stage: !KeyOf default-authentication-password stage: !KeyOf default-authentication-password
target: !KeyOf flow target: !KeyOf flow
attrs:
re_evaluate_policies: true
id: default-authentication-flow-password-binding
model: authentik_flows.flowstagebinding model: authentik_flows.flowstagebinding
- identifiers: - identifiers:
order: 30 order: 30
@ -65,20 +62,3 @@ entries:
stage: !KeyOf default-authentication-login stage: !KeyOf default-authentication-login
target: !KeyOf flow target: !KeyOf flow
model: authentik_flows.flowstagebinding model: authentik_flows.flowstagebinding
- model: authentik_policies_expression.expressionpolicy
id: default-authentication-flow-password-optional
identifiers:
name: default-authentication-flow-password-stage
attrs:
expression: |
flow_plan = request.context.get("flow_plan")
if not flow_plan:
return True
# If the user does not have a backend attached to it, they haven't
# been authenticated yet and we need the password stage
return not hasattr(flow_plan.context.get("pending_user"), "backend")
- model: authentik_policies.policybinding
identifiers:
order: 10
target: !KeyOf default-authentication-flow-password-binding
policy: !KeyOf default-authentication-flow-password-optional

View File

@ -1,8 +1,8 @@
version: 1 version: 1
metadata: metadata:
name: Migration - Remove old prompt fields
labels: labels:
blueprints.goauthentik.io/description: Migrate to 2023.2, remove unused prompt fields blueprints.goauthentik.io/description: Migrate to 2023.2, remove unused prompt fields
name: Migration - Remove old prompt fields
entries: entries:
- model: authentik_stages_prompt.prompt - model: authentik_stages_prompt.prompt
identifiers: identifiers:

Some files were not shown because too many files have changed in this diff Show More