Compare commits

..

1 Commits

Author SHA1 Message Date
87bf75e51c add Registration closed note 2023-07-28 12:36:15 -05:00
598 changed files with 10473 additions and 21697 deletions

View File

@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 2023.8.3 current_version = 2023.6.1
tag = True tag = True
commit = True commit = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+) parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)

View File

@ -7,4 +7,3 @@ build/**
build_docs/** build_docs/**
Dockerfile Dockerfile
authentik/enterprise authentik/enterprise
blueprints/local

View File

@ -14,7 +14,7 @@ runs:
run: | run: |
pipx install poetry || true pipx install poetry || true
sudo apt update sudo apt update
sudo apt install -y libpq-dev openssl libxmlsec1-dev pkg-config gettext sudo apt install -y libxmlsec1-dev pkg-config gettext
- name: Setup python and restore poetry - name: Setup python and restore poetry
uses: actions/setup-python@v3 uses: actions/setup-python@v3
with: with:
@ -23,7 +23,7 @@ runs:
- name: Setup node - name: Setup node
uses: actions/setup-node@v3 uses: actions/setup-node@v3
with: with:
node-version: "20.5" node-version: "20"
cache: "npm" cache: "npm"
cache-dependency-path: web/package-lock.json cache-dependency-path: web/package-lock.json
- name: Setup dependencies - name: Setup dependencies

View File

@ -1,2 +0,0 @@
enabled: true
preservePullRequestTitle: true

View File

@ -8,8 +8,6 @@ updates:
open-pull-requests-limit: 10 open-pull-requests-limit: 10
commit-message: commit-message:
prefix: "ci:" prefix: "ci:"
labels:
- dependencies
- package-ecosystem: gomod - package-ecosystem: gomod
directory: "/" directory: "/"
schedule: schedule:
@ -18,15 +16,11 @@ updates:
open-pull-requests-limit: 10 open-pull-requests-limit: 10
commit-message: commit-message:
prefix: "core:" prefix: "core:"
labels:
- dependencies
- package-ecosystem: npm - package-ecosystem: npm
directory: "/web" directory: "/web"
schedule: schedule:
interval: daily interval: daily
time: "04:00" time: "04:00"
labels:
- dependencies
open-pull-requests-limit: 10 open-pull-requests-limit: 10
commit-message: commit-message:
prefix: "web:" prefix: "web:"
@ -38,18 +32,10 @@ updates:
patterns: patterns:
- "@babel/*" - "@babel/*"
- "babel-*" - "babel-*"
eslint:
patterns:
- "@typescript-eslint/eslint-*"
- "eslint"
- "eslint-*"
storybook: storybook:
patterns: patterns:
- "@storybook/*" - "@storybook/*"
- "*storybook*" - "*storybook*"
esbuild:
patterns:
- "@esbuild/*"
- package-ecosystem: npm - package-ecosystem: npm
directory: "/website" directory: "/website"
schedule: schedule:
@ -58,8 +44,6 @@ updates:
open-pull-requests-limit: 10 open-pull-requests-limit: 10
commit-message: commit-message:
prefix: "website:" prefix: "website:"
labels:
- dependencies
groups: groups:
docusaurus: docusaurus:
patterns: patterns:
@ -72,8 +56,6 @@ updates:
open-pull-requests-limit: 10 open-pull-requests-limit: 10
commit-message: commit-message:
prefix: "core:" prefix: "core:"
labels:
- dependencies
- package-ecosystem: docker - package-ecosystem: docker
directory: "/" directory: "/"
schedule: schedule:
@ -82,5 +64,3 @@ updates:
open-pull-requests-limit: 10 open-pull-requests-limit: 10
commit-message: commit-message:
prefix: "core:" prefix: "core:"
labels:
- dependencies

View File

@ -1,19 +1,23 @@
<!-- <!--
👋 Hi there! Welcome. 👋 Hello there! Welcome.
Please check the Contributing guidelines: https://goauthentik.io/developer-docs/#how-can-i-contribute Please check the [Contributing guidelines](https://goauthentik.io/developer-docs/#how-can-i-contribute).
--> -->
## Details ## Details
<!-- - **Does this resolve an issue?**
Explain what this PR changes, what the rationale behind the change is, if any new requirements are introduced or any breaking changes caused by this PR. Resolves #
Ideally also link an Issue for context that this PR will close using `closes #` ## Changes
-->
REPLACE ME
--- ### New Features
- Adds feature which does x, y, and z.
### Breaking Changes
- Adds breaking change which causes \<issue\>.
## Checklist ## Checklist

View File

@ -88,8 +88,8 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
psql: psql:
- 11-alpine
- 12-alpine - 12-alpine
- 15-alpine
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- name: Setup authentik env - name: Setup authentik env

View File

@ -120,9 +120,9 @@ jobs:
- uses: actions/setup-go@v4 - uses: actions/setup-go@v4
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
- uses: actions/setup-node@v3 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20.5" node-version: "20"
cache: "npm" cache: "npm"
cache-dependency-path: web/package-lock.json cache-dependency-path: web/package-lock.json
- name: Generate API - name: Generate API

View File

@ -14,10 +14,10 @@ jobs:
lint-eslint: lint-eslint:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v3
- uses: actions/setup-node@v3 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20.5" node-version: "20"
cache: "npm" cache: "npm"
cache-dependency-path: web/package-lock.json cache-dependency-path: web/package-lock.json
- working-directory: web/ - working-directory: web/
@ -30,10 +30,10 @@ jobs:
lint-build: lint-build:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v3
- uses: actions/setup-node@v3 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20.5" node-version: "20"
cache: "npm" cache: "npm"
cache-dependency-path: web/package-lock.json cache-dependency-path: web/package-lock.json
- working-directory: web/ - working-directory: web/
@ -46,10 +46,10 @@ jobs:
lint-prettier: lint-prettier:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v3
- uses: actions/setup-node@v3 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20.5" node-version: "20"
cache: "npm" cache: "npm"
cache-dependency-path: web/package-lock.json cache-dependency-path: web/package-lock.json
- working-directory: web/ - working-directory: web/
@ -62,10 +62,10 @@ jobs:
lint-lit-analyse: lint-lit-analyse:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v3
- uses: actions/setup-node@v3 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20.5" node-version: "20"
cache: "npm" cache: "npm"
cache-dependency-path: web/package-lock.json cache-dependency-path: web/package-lock.json
- working-directory: web/ - working-directory: web/
@ -94,10 +94,10 @@ jobs:
- ci-web-mark - ci-web-mark
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v3
- uses: actions/setup-node@v3 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20.5" node-version: "20"
cache: "npm" cache: "npm"
cache-dependency-path: web/package-lock.json cache-dependency-path: web/package-lock.json
- working-directory: web/ - working-directory: web/

View File

@ -14,10 +14,10 @@ jobs:
lint-prettier: lint-prettier:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v3
- uses: actions/setup-node@v3 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20.5" node-version: "20"
cache: "npm" cache: "npm"
cache-dependency-path: website/package-lock.json cache-dependency-path: website/package-lock.json
- working-directory: website/ - working-directory: website/
@ -28,10 +28,10 @@ jobs:
test: test:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v3
- uses: actions/setup-node@v3 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20.5" node-version: "20"
cache: "npm" cache: "npm"
cache-dependency-path: website/package-lock.json cache-dependency-path: website/package-lock.json
- working-directory: website/ - working-directory: website/
@ -49,10 +49,10 @@ jobs:
- build - build
- build-docs-only - build-docs-only
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v3
- uses: actions/setup-node@v3 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20.5" node-version: "20"
cache: "npm" cache: "npm"
cache-dependency-path: website/package-lock.json cache-dependency-path: website/package-lock.json
- working-directory: website/ - working-directory: website/

View File

@ -1,34 +0,0 @@
---
# See https://docs.github.com/en/actions/using-workflows/caching-dependencies-to-speed-up-workflows#force-deleting-cache-entries
name: Cleanup cache after PR is closed
on:
pull_request:
types:
- closed
jobs:
cleanup:
runs-on: ubuntu-latest
steps:
- name: Check out code
uses: actions/checkout@v3
- name: Cleanup
run: |
gh extension install actions/gh-actions-cache
REPO=${{ github.repository }}
BRANCH="refs/pull/${{ github.event.pull_request.number }}/merge"
echo "Fetching list of cache key"
cacheKeysForPR=$(gh actions-cache list -R $REPO -B $BRANCH -L 100 | cut -f 1 )
# Setting this to not fail the workflow while deleting cache keys.
set +e
echo "Deleting caches..."
for cacheKey in $cacheKeysForPR; do
gh actions-cache delete $cacheKey -R $REPO -B $BRANCH --confirm
done
echo "Done"
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@ -1,61 +0,0 @@
---
name: authentik-compress-images
on:
push:
branches:
- main
paths:
- "**.jpg"
- "**.jpeg"
- "**.png"
- "**.webp"
pull_request:
paths:
- "**.jpg"
- "**.jpeg"
- "**.png"
- "**.webp"
workflow_dispatch:
jobs:
compress:
name: compress
runs-on: ubuntu-latest
# Don't run on forks. Token will not be available. Will run on main and open a PR anyway
if: |
github.repository == 'goauthentik/authentik' &&
(github.event_name != 'pull_request' ||
github.event.pull_request.head.repo.full_name == github.repository)
steps:
- id: generate_token
uses: tibdex/github-app-token@v1
with:
app_id: ${{ secrets.GH_APP_ID }}
private_key: ${{ secrets.GH_APP_PRIVATE_KEY }}
- uses: actions/checkout@v3
with:
token: ${{ steps.generate_token.outputs.token }}
- name: Compress images
id: compress
uses: calibreapp/image-actions@main
with:
githubToken: ${{ steps.generate_token.outputs.token }}
compressOnly: ${{ github.event_name != 'pull_request' }}
- uses: peter-evans/create-pull-request@v5
if: "${{ github.event_name != 'pull_request' && steps.compress.outputs.markdown != '' }}"
id: cpr
with:
token: ${{ steps.generate_token.outputs.token }}
title: "*: Auto compress images"
branch-suffix: timestamp
commit-message: "*: compress images"
body: ${{ steps.compress.outputs.markdown }}
delete-branch: true
signoff: true
- uses: peter-evans/enable-pull-request-automerge@v3
if: "${{ github.event_name != 'pull_request' && steps.compress.outputs.markdown != '' }}"
with:
token: ${{ steps.generate_token.outputs.token }}
pull-request-number: ${{ steps.cpr.outputs.pull-request-number }}
merge-method: squash

View File

@ -1,31 +0,0 @@
name: authentik-publish-source-docs
on:
push:
branches:
- main
env:
POSTGRES_DB: authentik
POSTGRES_USER: authentik
POSTGRES_PASSWORD: "EK-5jnKfjrGRm<77"
jobs:
publish-source-docs:
runs-on: ubuntu-latest
timeout-minutes: 120
steps:
- uses: actions/checkout@v3
- name: Setup authentik env
uses: ./.github/actions/setup
- name: generate docs
run: |
poetry run make migrate
poetry run ak build_source_docs
- name: Publish
uses: netlify/actions/cli@master
with:
args: deploy --dir=source_docs --prod
env:
NETLIFY_SITE_ID: eb246b7b-1d83-4f69-89f7-01a936b4ca59
NETLIFY_AUTH_TOKEN: ${{ secrets.NETLIFY_AUTH_TOKEN }}

View File

@ -110,9 +110,9 @@ jobs:
- uses: actions/setup-go@v4 - uses: actions/setup-go@v4
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
- uses: actions/setup-node@v3 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20.5" node-version: "20"
cache: "npm" cache: "npm"
cache-dependency-path: web/package-lock.json cache-dependency-path: web/package-lock.json
- name: Build web - name: Build web

View File

@ -1,5 +1,4 @@
# Rename transifex pull requests to have a correct naming # Rename transifex pull requests to have a correct naming
# Also enables auto squash-merge
name: authentik-translation-transifex-rename name: authentik-translation-transifex-rename
on: on:
@ -38,8 +37,3 @@ jobs:
-H "X-GitHub-Api-Version: 2022-11-28" \ -H "X-GitHub-Api-Version: 2022-11-28" \
https://api.github.com/repos/${GITHUB_REPOSITORY}/pulls/${{ github.event.pull_request.number }} \ https://api.github.com/repos/${GITHUB_REPOSITORY}/pulls/${{ github.event.pull_request.number }} \
-d "{\"title\":\"translate: ${{ steps.title.outputs.title }}\"}" -d "{\"title\":\"translate: ${{ steps.title.outputs.title }}\"}"
- uses: peter-evans/enable-pull-request-automerge@v3
with:
token: ${{ steps.generate_token.outputs.token }}
pull-request-number: ${{ github.event.pull_request.number }}
merge-method: squash

View File

@ -17,9 +17,9 @@ jobs:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
with: with:
token: ${{ steps.generate_token.outputs.token }} token: ${{ steps.generate_token.outputs.token }}
- uses: actions/setup-node@v3 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20.5" node-version: "20"
registry-url: "https://registry.npmjs.org" registry-url: "https://registry.npmjs.org"
- name: Generate API Client - name: Generate API Client
run: make gen-client-ts run: make gen-client-ts

1
.gitignore vendored
View File

@ -205,4 +205,3 @@ data/
# Local Netlify folder # Local Netlify folder
.netlify .netlify
.ruff_cache .ruff_cache
source_docs/

View File

@ -31,8 +31,7 @@
"!Format sequence", "!Format sequence",
"!Condition sequence", "!Condition sequence",
"!Env sequence", "!Env sequence",
"!Env scalar", "!Env scalar"
"!If sequence"
], ],
"typescript.preferences.importModuleSpecifier": "non-relative", "typescript.preferences.importModuleSpecifier": "non-relative",
"typescript.preferences.importModuleSpecifierEnding": "index", "typescript.preferences.importModuleSpecifierEnding": "index",

View File

@ -1,5 +1,5 @@
# Stage 1: Build website # Stage 1: Build website
FROM --platform=${BUILDPLATFORM} docker.io/node:20.5 as website-builder FROM --platform=${BUILDPLATFORM} docker.io/node:20 as website-builder
COPY ./website /work/website/ COPY ./website /work/website/
COPY ./blueprints /work/blueprints/ COPY ./blueprints /work/blueprints/
@ -10,7 +10,7 @@ WORKDIR /work/website
RUN npm ci --include=dev && npm run build-docs-only RUN npm ci --include=dev && npm run build-docs-only
# Stage 2: Build webui # Stage 2: Build webui
FROM --platform=${BUILDPLATFORM} docker.io/node:20.5 as web-builder FROM --platform=${BUILDPLATFORM} docker.io/node:20 as web-builder
COPY ./web /work/web/ COPY ./web /work/web/
COPY ./website /work/website/ COPY ./website /work/website/
@ -20,7 +20,7 @@ WORKDIR /work/web
RUN npm ci --include=dev && npm run build RUN npm ci --include=dev && npm run build
# Stage 3: Poetry to requirements.txt export # Stage 3: Poetry to requirements.txt export
FROM docker.io/python:3.11.5-slim-bookworm AS poetry-locker FROM docker.io/python:3.11.4-slim-bullseye AS poetry-locker
WORKDIR /work WORKDIR /work
COPY ./pyproject.toml /work COPY ./pyproject.toml /work
@ -31,7 +31,7 @@ RUN pip install --no-cache-dir poetry && \
poetry export -f requirements.txt --dev --output requirements-dev.txt poetry export -f requirements.txt --dev --output requirements-dev.txt
# Stage 4: Build go proxy # Stage 4: Build go proxy
FROM docker.io/golang:1.21.0-bookworm AS go-builder FROM docker.io/golang:1.20.6-bullseye AS go-builder
WORKDIR /work WORKDIR /work
@ -39,13 +39,12 @@ COPY --from=web-builder /work/web/robots.txt /work/web/robots.txt
COPY --from=web-builder /work/web/security.txt /work/web/security.txt COPY --from=web-builder /work/web/security.txt /work/web/security.txt
COPY ./cmd /work/cmd COPY ./cmd /work/cmd
COPY ./authentik/lib /work/authentik/lib
COPY ./web/static.go /work/web/static.go COPY ./web/static.go /work/web/static.go
COPY ./internal /work/internal COPY ./internal /work/internal
COPY ./go.mod /work/go.mod COPY ./go.mod /work/go.mod
COPY ./go.sum /work/go.sum COPY ./go.sum /work/go.sum
RUN go build -o /work/bin/authentik ./cmd/server/ RUN go build -o /work/authentik ./cmd/server/
# Stage 5: MaxMind GeoIP # Stage 5: MaxMind GeoIP
FROM ghcr.io/maxmind/geoipupdate:v6.0 as geoip FROM ghcr.io/maxmind/geoipupdate:v6.0 as geoip
@ -62,7 +61,7 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
/bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0" /bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
# Stage 6: Run # Stage 6: Run
FROM docker.io/python:3.11.5-slim-bookworm AS final-image FROM docker.io/python:3.11.4-slim-bullseye AS final-image
ARG GIT_BUILD_HASH ARG GIT_BUILD_HASH
ARG VERSION ARG VERSION
@ -82,13 +81,13 @@ COPY --from=geoip /usr/share/GeoIP /geoip
RUN apt-get update && \ RUN apt-get update && \
# Required for installing pip packages # Required for installing pip packages
apt-get install -y --no-install-recommends build-essential pkg-config libxmlsec1-dev zlib1g-dev libpq-dev python3-dev && \ apt-get install -y --no-install-recommends build-essential pkg-config libxmlsec1-dev zlib1g-dev && \
# Required for runtime # Required for runtime
apt-get install -y --no-install-recommends libpq5 openssl libxmlsec1-openssl libmaxminddb0 && \ apt-get install -y --no-install-recommends libxmlsec1-openssl libmaxminddb0 && \
# Required for bootstrap & healthcheck # Required for bootstrap & healthcheck
apt-get install -y --no-install-recommends runit && \ apt-get install -y --no-install-recommends runit && \
pip install --no-cache-dir -r /requirements.txt && \ pip install --no-cache-dir -r /requirements.txt && \
apt-get remove --purge -y build-essential pkg-config libxmlsec1-dev libpq-dev python3-dev && \ apt-get remove --purge -y build-essential pkg-config libxmlsec1-dev && \
apt-get autoremove --purge -y && \ apt-get autoremove --purge -y && \
apt-get clean && \ apt-get clean && \
rm -rf /tmp/* /var/lib/apt/lists/* /var/tmp/ && \ rm -rf /tmp/* /var/lib/apt/lists/* /var/tmp/ && \
@ -105,7 +104,7 @@ COPY ./tests /tests
COPY ./manage.py / COPY ./manage.py /
COPY ./blueprints /blueprints COPY ./blueprints /blueprints
COPY ./lifecycle/ /lifecycle COPY ./lifecycle/ /lifecycle
COPY --from=go-builder /work/bin/authentik /bin/authentik COPY --from=go-builder /work/authentik /bin/authentik
COPY --from=web-builder /work/web/dist/ /web/dist/ COPY --from=web-builder /work/web/dist/ /web/dist/
COPY --from=web-builder /work/web/authentik/ /web/authentik/ COPY --from=web-builder /work/web/authentik/ /web/authentik/
COPY --from=website-builder /work/website/help/ /website/help/ COPY --from=website-builder /work/website/help/ /website/help/

View File

@ -140,15 +140,13 @@ web-watch:
touch web/dist/.gitkeep touch web/dist/.gitkeep
cd web && npm run watch cd web && npm run watch
web-storybook-watch:
cd web && npm run storybook
web-lint-fix: web-lint-fix:
cd web && npm run prettier cd web && npm run prettier
web-lint: web-lint:
cd web && npm run lint cd web && npm run lint
cd web && npm run lit-analyse # TODO: The analyzer hasn't run correctly in awhile.
# cd web && npm run lit-analyse
web-check-compile: web-check-compile:
cd web && npm run tsc cd web && npm run tsc

View File

@ -16,8 +16,8 @@ Even if the issue is not a CVE, we still greatly appreciate your help in hardening authentik!
| Version | Supported | | Version | Supported |
| --- | --- | | --- | --- |
| 2023.5.x | ✅ |
| 2023.6.x | ✅ | | 2023.6.x | ✅ |
| 2023.8.x | ✅ |
## Reporting a Vulnerability ## Reporting a Vulnerability
@ -27,8 +27,6 @@ To report a vulnerability, send an email to [security@goauthentik.io](mailto:security@goauthentik.io)
authentik reserves the right to reclassify CVSS as necessary. To determine severity, we will use the CVSS calculator from NVD (https://nvd.nist.gov/vuln-metrics/cvss/v3-calculator). The calculated CVSS score will then be translated into one of the following categories: authentik reserves the right to reclassify CVSS as necessary. To determine severity, we will use the CVSS calculator from NVD (https://nvd.nist.gov/vuln-metrics/cvss/v3-calculator). The calculated CVSS score will then be translated into one of the following categories:
| Score | Severity |
| --- | --- |
| 0.0 | None | | 0.0 | None |
| 0.1 – 3.9 | Low | | 0.1 – 3.9 | Low |
| 4.0 – 6.9 | Medium | | 4.0 – 6.9 | Medium |

View File

@ -1,8 +1,8 @@
"""authentik root module""" """authentik"""
from os import environ from os import environ
from typing import Optional from typing import Optional
__version__ = "2023.8.3" __version__ = "2023.6.1"
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"

View File

@ -2,43 +2,6 @@
from rest_framework import pagination from rest_framework import pagination
from rest_framework.response import Response from rest_framework.response import Response
PAGINATION_COMPONENT_NAME = "Pagination"
PAGINATION_SCHEMA = {
"type": "object",
"properties": {
"next": {
"type": "number",
},
"previous": {
"type": "number",
},
"count": {
"type": "number",
},
"current": {
"type": "number",
},
"total_pages": {
"type": "number",
},
"start_index": {
"type": "number",
},
"end_index": {
"type": "number",
},
},
"required": [
"next",
"previous",
"count",
"current",
"total_pages",
"start_index",
"end_index",
],
}
class Pagination(pagination.PageNumberPagination): class Pagination(pagination.PageNumberPagination):
"""Pagination which includes total pages and current page""" """Pagination which includes total pages and current page"""
@ -72,7 +35,41 @@ class Pagination(pagination.PageNumberPagination):
return { return {
"type": "object", "type": "object",
"properties": { "properties": {
"pagination": {"$ref": f"#/components/schemas/{PAGINATION_COMPONENT_NAME}"}, "pagination": {
"type": "object",
"properties": {
"next": {
"type": "number",
},
"previous": {
"type": "number",
},
"count": {
"type": "number",
},
"current": {
"type": "number",
},
"total_pages": {
"type": "number",
},
"start_index": {
"type": "number",
},
"end_index": {
"type": "number",
},
},
"required": [
"next",
"previous",
"count",
"current",
"total_pages",
"start_index",
"end_index",
],
},
"results": schema, "results": schema,
}, },
"required": ["pagination", "results"], "required": ["pagination", "results"],

View File

@ -1,6 +1,5 @@
"""Error Response schema, from https://github.com/axnsan12/drf-yasg/issues/224""" """Error Response schema, from https://github.com/axnsan12/drf-yasg/issues/224"""
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from drf_spectacular.generators import SchemaGenerator
from drf_spectacular.plumbing import ( from drf_spectacular.plumbing import (
ResolvedComponent, ResolvedComponent,
build_array_type, build_array_type,
@ -9,9 +8,6 @@ from drf_spectacular.plumbing import (
) )
from drf_spectacular.settings import spectacular_settings from drf_spectacular.settings import spectacular_settings
from drf_spectacular.types import OpenApiTypes from drf_spectacular.types import OpenApiTypes
from rest_framework.settings import api_settings
from authentik.api.pagination import PAGINATION_COMPONENT_NAME, PAGINATION_SCHEMA
def build_standard_type(obj, **kwargs): def build_standard_type(obj, **kwargs):
@ -32,7 +28,7 @@ GENERIC_ERROR = build_object_type(
VALIDATION_ERROR = build_object_type( VALIDATION_ERROR = build_object_type(
description=_("Validation Error"), description=_("Validation Error"),
properties={ properties={
api_settings.NON_FIELD_ERRORS_KEY: build_array_type(build_standard_type(OpenApiTypes.STR)), "non_field_errors": build_array_type(build_standard_type(OpenApiTypes.STR)),
"code": build_standard_type(OpenApiTypes.STR), "code": build_standard_type(OpenApiTypes.STR),
}, },
required=[], required=[],
@ -40,19 +36,7 @@ VALIDATION_ERROR = build_object_type(
) )
def create_component(generator: SchemaGenerator, name, schema, type_=ResolvedComponent.SCHEMA): def postprocess_schema_responses(result, generator, **kwargs): # noqa: W0613
"""Register a component and return a reference to it."""
component = ResolvedComponent(
name=name,
type=type_,
schema=schema,
object=name,
)
generator.registry.register_on_missing(component)
return component
def postprocess_schema_responses(result, generator: SchemaGenerator, **kwargs): # noqa: W0613
"""Workaround to set a default response for endpoints. """Workaround to set a default response for endpoints.
Workaround suggested at Workaround suggested at
<https://github.com/tfranzel/drf-spectacular/issues/119#issuecomment-656970357> <https://github.com/tfranzel/drf-spectacular/issues/119#issuecomment-656970357>
@ -60,10 +44,19 @@ def postprocess_schema_responses(result, generator: SchemaGenerator, **kwargs):
<https://github.com/tfranzel/drf-spectacular/issues/101>. <https://github.com/tfranzel/drf-spectacular/issues/101>.
""" """
create_component(generator, PAGINATION_COMPONENT_NAME, PAGINATION_SCHEMA) def create_component(name, schema, type_=ResolvedComponent.SCHEMA):
"""Register a component and return a reference to it."""
component = ResolvedComponent(
name=name,
type=type_,
schema=schema,
object=name,
)
generator.registry.register_on_missing(component)
return component
generic_error = create_component(generator, "GenericError", GENERIC_ERROR) generic_error = create_component("GenericError", GENERIC_ERROR)
validation_error = create_component(generator, "ValidationError", VALIDATION_ERROR) validation_error = create_component("ValidationError", VALIDATION_ERROR)
for path in result["paths"].values(): for path in result["paths"].values():
for method in path.values(): for method in path.values():

View File

@ -93,10 +93,10 @@ class ConfigView(APIView):
"traces_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.4)), "traces_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.4)),
}, },
"capabilities": self.get_capabilities(), "capabilities": self.get_capabilities(),
"cache_timeout": CONFIG.get_int("redis.cache_timeout"), "cache_timeout": int(CONFIG.get("redis.cache_timeout")),
"cache_timeout_flows": CONFIG.get_int("redis.cache_timeout_flows"), "cache_timeout_flows": int(CONFIG.get("redis.cache_timeout_flows")),
"cache_timeout_policies": CONFIG.get_int("redis.cache_timeout_policies"), "cache_timeout_policies": int(CONFIG.get("redis.cache_timeout_policies")),
"cache_timeout_reputation": CONFIG.get_int("redis.cache_timeout_reputation"), "cache_timeout_reputation": int(CONFIG.get("redis.cache_timeout_reputation")),
} }
) )

View File

@ -45,8 +45,3 @@ entries:
attrs: attrs:
name: "%(uid)s" name: "%(uid)s"
password: "%(uid)s" password: "%(uid)s"
- model: authentik_core.user
identifiers:
username: "%(uid)s-no-password"
attrs:
name: "%(uid)s"

View File

@ -7,5 +7,7 @@ entries:
state: absent state: absent
- identifiers: - identifiers:
name: "%(id)s" name: "%(id)s"
expression: |
return True
model: authentik_policies_expression.expressionpolicy model: authentik_policies_expression.expressionpolicy
state: absent state: absent

View File

@ -9,8 +9,6 @@ context:
mapping: mapping:
key1: value key1: value
key2: 2 key2: 2
context1: context-nested-value
context2: !Context context1
entries: entries:
- model: !Format ["%s", authentik_sources_oauth.oauthsource] - model: !Format ["%s", authentik_sources_oauth.oauthsource]
state: !Format ["%s", present] state: !Format ["%s", present]
@ -36,7 +34,6 @@ entries:
model: authentik_policies_expression.expressionpolicy model: authentik_policies_expression.expressionpolicy
- attrs: - attrs:
attributes: attributes:
env_null: !Env [bar-baz, null]
policy_pk1: policy_pk1:
!Format [ !Format [
"%s-%s", "%s-%s",
@ -100,7 +97,6 @@ entries:
[list, with, items, !Format ["foo-%s", !Context foo]], [list, with, items, !Format ["foo-%s", !Context foo]],
] ]
if_true_simple: !If [!Context foo, true, text] if_true_simple: !If [!Context foo, true, text]
if_short: !If [!Context foo]
if_false_simple: !If [null, false, 2] if_false_simple: !If [null, false, 2]
enumerate_mapping_to_mapping: !Enumerate [ enumerate_mapping_to_mapping: !Enumerate [
!Context mapping, !Context mapping,
@ -145,7 +141,6 @@ entries:
] ]
] ]
] ]
nested_context: !Context context2
identifiers: identifiers:
name: test name: test
conditions: conditions:

View File

@ -155,7 +155,6 @@ class TestBlueprintsV1(TransactionTestCase):
}, },
"if_false_complex": ["list", "with", "items", "foo-bar"], "if_false_complex": ["list", "with", "items", "foo-bar"],
"if_true_simple": True, "if_true_simple": True,
"if_short": True,
"if_false_simple": 2, "if_false_simple": 2,
"enumerate_mapping_to_mapping": { "enumerate_mapping_to_mapping": {
"prefix-key1": "other-prefix-value", "prefix-key1": "other-prefix-value",
@ -212,10 +211,8 @@ class TestBlueprintsV1(TransactionTestCase):
], ],
}, },
}, },
"nested_context": "context-nested-value",
"env_null": None,
} }
).exists() )
) )
self.assertTrue( self.assertTrue(
OAuthSource.objects.filter( OAuthSource.objects.filter(

View File

@ -51,9 +51,3 @@ class TestBlueprintsV1ConditionalFields(TransactionTestCase):
user: User = User.objects.filter(username=self.uid).first() user: User = User.objects.filter(username=self.uid).first()
self.assertIsNotNone(user) self.assertIsNotNone(user)
self.assertTrue(user.check_password(self.uid)) self.assertTrue(user.check_password(self.uid))
def test_user_null(self):
"""Test user"""
user: User = User.objects.filter(username=f"{self.uid}-no-password").first()
self.assertIsNotNone(user)
self.assertFalse(user.has_usable_password())

View File

@ -223,11 +223,11 @@ class Env(YAMLTag):
if isinstance(node, ScalarNode): if isinstance(node, ScalarNode):
self.key = node.value self.key = node.value
if isinstance(node, SequenceNode): if isinstance(node, SequenceNode):
self.key = loader.construct_object(node.value[0]) self.key = node.value[0].value
self.default = loader.construct_object(node.value[1]) self.default = node.value[1].value
def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any: def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
return getenv(self.key) or self.default return getenv(self.key, self.default)
class Context(YAMLTag): class Context(YAMLTag):
@ -242,15 +242,13 @@ class Context(YAMLTag):
if isinstance(node, ScalarNode): if isinstance(node, ScalarNode):
self.key = node.value self.key = node.value
if isinstance(node, SequenceNode): if isinstance(node, SequenceNode):
self.key = loader.construct_object(node.value[0]) self.key = node.value[0].value
self.default = loader.construct_object(node.value[1]) self.default = node.value[1].value
def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any: def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
value = self.default value = self.default
if self.key in blueprint.context: if self.key in blueprint.context:
value = blueprint.context[self.key] value = blueprint.context[self.key]
if isinstance(value, YAMLTag):
return value.resolve(entry, blueprint)
return value return value
@ -262,7 +260,7 @@ class Format(YAMLTag):
def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None: def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None:
super().__init__() super().__init__()
self.format_string = loader.construct_object(node.value[0]) self.format_string = node.value[0].value
self.args = [] self.args = []
for raw_node in node.value[1:]: for raw_node in node.value[1:]:
self.args.append(loader.construct_object(raw_node)) self.args.append(loader.construct_object(raw_node))
@ -341,7 +339,7 @@ class Condition(YAMLTag):
def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None: def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None:
super().__init__() super().__init__()
self.mode = loader.construct_object(node.value[0]) self.mode = node.value[0].value
self.args = [] self.args = []
for raw_node in node.value[1:]: for raw_node in node.value[1:]:
self.args.append(loader.construct_object(raw_node)) self.args.append(loader.construct_object(raw_node))
@ -374,12 +372,8 @@ class If(YAMLTag):
def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None: def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None:
super().__init__() super().__init__()
self.condition = loader.construct_object(node.value[0]) self.condition = loader.construct_object(node.value[0])
if len(node.value) == 1: self.when_true = loader.construct_object(node.value[1])
self.when_true = True self.when_false = loader.construct_object(node.value[2])
self.when_false = False
else:
self.when_true = loader.construct_object(node.value[1])
self.when_false = loader.construct_object(node.value[2])
def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any: def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
if isinstance(self.condition, YAMLTag): if isinstance(self.condition, YAMLTag):
@ -416,7 +410,7 @@ class Enumerate(YAMLTag, YAMLTagContext):
def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None: def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None:
super().__init__() super().__init__()
self.iterable = loader.construct_object(node.value[0]) self.iterable = loader.construct_object(node.value[0])
self.output_body = loader.construct_object(node.value[1]) self.output_body = node.value[1].value
self.item_body = loader.construct_object(node.value[2]) self.item_body = loader.construct_object(node.value[2])
self.__current_context: tuple[Any, Any] = tuple() self.__current_context: tuple[Any, Any] = tuple()

View File

@ -35,7 +35,6 @@ from authentik.core.models import (
Source, Source,
UserSourceConnection, UserSourceConnection,
) )
from authentik.events.utils import cleanse_dict
from authentik.flows.models import FlowToken, Stage from authentik.flows.models import FlowToken, Stage
from authentik.lib.models import SerializerModel from authentik.lib.models import SerializerModel
from authentik.outposts.models import OutpostServiceConnection from authentik.outposts.models import OutpostServiceConnection
@ -200,6 +199,9 @@ class Importer:
serializer_kwargs = {} serializer_kwargs = {}
model_instance = existing_models.first() model_instance = existing_models.first()
if not isinstance(model(), BaseMetaModel) and model_instance: if not isinstance(model(), BaseMetaModel) and model_instance:
if entry.get_state(self.__import) == BlueprintEntryDesiredState.CREATED:
self.logger.debug("instance exists, skipping")
return None
self.logger.debug( self.logger.debug(
"initialise serializer with instance", "initialise serializer with instance",
model=model, model=model,
@ -210,9 +212,7 @@ class Importer:
serializer_kwargs["partial"] = True serializer_kwargs["partial"] = True
else: else:
self.logger.debug( self.logger.debug(
"initialised new serializer instance", "initialised new serializer instance", model=model, **updated_identifiers
model=model,
**cleanse_dict(updated_identifiers),
) )
model_instance = model() model_instance = model()
# pk needs to be set on the model instance otherwise a new one will be generated # pk needs to be set on the model instance otherwise a new one will be generated
@ -268,34 +268,21 @@ class Importer:
try: try:
serializer = self._validate_single(entry) serializer = self._validate_single(entry)
except EntryInvalidError as exc: except EntryInvalidError as exc:
# For deleting objects we don't need the serializer to be valid
if entry.get_state(self.__import) == BlueprintEntryDesiredState.ABSENT:
continue
self.logger.warning(f"entry invalid: {exc}", entry=entry, error=exc) self.logger.warning(f"entry invalid: {exc}", entry=entry, error=exc)
return False return False
if not serializer: if not serializer:
continue continue
state = entry.get_state(self.__import) state = entry.get_state(self.__import)
if state in [BlueprintEntryDesiredState.PRESENT, BlueprintEntryDesiredState.CREATED]: if state in [
instance = serializer.instance BlueprintEntryDesiredState.PRESENT,
if ( BlueprintEntryDesiredState.CREATED,
instance ]:
and not instance._state.adding model = serializer.save()
and state == BlueprintEntryDesiredState.CREATED
):
self.logger.debug(
"instance exists, skipping",
model=model,
instance=instance,
pk=instance.pk,
)
else:
instance = serializer.save()
self.logger.debug("updated model", model=instance)
if "pk" in entry.identifiers: if "pk" in entry.identifiers:
self.__pk_map[entry.identifiers["pk"]] = instance.pk self.__pk_map[entry.identifiers["pk"]] = model.pk
entry._state = BlueprintEntryState(instance) entry._state = BlueprintEntryState(model)
self.logger.debug("updated model", model=model)
elif state == BlueprintEntryDesiredState.ABSENT: elif state == BlueprintEntryDesiredState.ABSENT:
instance: Optional[Model] = serializer.instance instance: Optional[Model] = serializer.instance
if instance.pk: if instance.pk:
@ -322,6 +309,5 @@ class Importer:
self.logger.debug("Blueprint validation failed") self.logger.debug("Blueprint validation failed")
for log in logs: for log in logs:
getattr(self.logger, log.get("log_level"))(**log) getattr(self.logger, log.get("log_level"))(**log)
self.logger.debug("Finished blueprint import validation")
self.__import = orig_import self.__import = orig_import
return successful, logs return successful, logs

View File

@ -31,7 +31,7 @@ class ApplyBlueprintMetaSerializer(PassiveSerializer):
required = attrs["required"] required = attrs["required"]
instance = BlueprintInstance.objects.filter(**identifiers).first() instance = BlueprintInstance.objects.filter(**identifiers).first()
if not instance and required: if not instance and required:
raise ValidationError({"identifiers": "Required blueprint does not exist"}) raise ValidationError("Required blueprint does not exist")
self.blueprint_instance = instance self.blueprint_instance = instance
return super().validate(attrs) return super().validate(attrs)

View File

@ -49,7 +49,7 @@ class GroupSerializer(ModelSerializer):
users_obj = ListSerializer( users_obj = ListSerializer(
child=GroupMemberSerializer(), read_only=True, source="users", required=False child=GroupMemberSerializer(), read_only=True, source="users", required=False
) )
parent_name = CharField(source="parent.name", read_only=True, allow_null=True) parent_name = CharField(source="parent.name", read_only=True)
num_pk = IntegerField(read_only=True) num_pk = IntegerField(read_only=True)

View File

@ -47,7 +47,7 @@ class TokenSerializer(ManagedSerializer, ModelSerializer):
attrs.setdefault("user", request.user) attrs.setdefault("user", request.user)
attrs.setdefault("intent", TokenIntents.INTENT_API) attrs.setdefault("intent", TokenIntents.INTENT_API)
if attrs.get("intent") not in [TokenIntents.INTENT_API, TokenIntents.INTENT_APP_PASSWORD]: if attrs.get("intent") not in [TokenIntents.INTENT_API, TokenIntents.INTENT_APP_PASSWORD]:
raise ValidationError({"intent": f"Invalid intent {attrs.get('intent')}"}) raise ValidationError(f"Invalid intent {attrs.get('intent')}")
return attrs return attrs
class Meta: class Meta:

View File

@ -15,13 +15,7 @@ from django.utils.http import urlencode
from django.utils.text import slugify from django.utils.text import slugify
from django.utils.timezone import now from django.utils.timezone import now
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from django_filters.filters import ( from django_filters.filters import BooleanFilter, CharFilter, ModelMultipleChoiceFilter, UUIDFilter
BooleanFilter,
CharFilter,
ModelMultipleChoiceFilter,
MultipleChoiceFilter,
UUIDFilter,
)
from django_filters.filterset import FilterSet from django_filters.filterset import FilterSet
from drf_spectacular.types import OpenApiTypes from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import ( from drf_spectacular.utils import (
@ -123,35 +117,27 @@ class UserSerializer(ModelSerializer):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if SERIALIZER_CONTEXT_BLUEPRINT in self.context: if SERIALIZER_CONTEXT_BLUEPRINT in self.context:
self.fields["password"] = CharField(required=False, allow_null=True) self.fields["password"] = CharField(required=False)
def create(self, validated_data: dict) -> User: def create(self, validated_data: dict) -> User:
"""If this serializer is used in the blueprint context, we allow for """If this serializer is used in the blueprint context, we allow for
directly setting a password. However should be done via the `set_password` directly setting a password. However should be done via the `set_password`
method instead of directly setting it like rest_framework.""" method instead of directly setting it like rest_framework."""
password = validated_data.pop("password", None)
instance: User = super().create(validated_data) instance: User = super().create(validated_data)
self._set_password(instance, password) if SERIALIZER_CONTEXT_BLUEPRINT in self.context and "password" in validated_data:
instance.set_password(validated_data["password"])
instance.save()
return instance return instance
def update(self, instance: User, validated_data: dict) -> User: def update(self, instance: User, validated_data: dict) -> User:
"""Same as `create` above, set the password directly if we're in a blueprint """Same as `create` above, set the password directly if we're in a blueprint
context""" context"""
password = validated_data.pop("password", None)
instance = super().update(instance, validated_data) instance = super().update(instance, validated_data)
self._set_password(instance, password) if SERIALIZER_CONTEXT_BLUEPRINT in self.context and "password" in validated_data:
instance.set_password(validated_data["password"])
instance.save()
return instance return instance
def _set_password(self, instance: User, password: Optional[str]):
"""Set password of user if we're in a blueprint context, and if it's an empty
string then use an unusable password"""
if SERIALIZER_CONTEXT_BLUEPRINT in self.context and password:
instance.set_password(password)
instance.save()
if len(instance.password) == 0:
instance.set_unusable_password()
instance.save()
def validate_path(self, path: str) -> str: def validate_path(self, path: str) -> str:
"""Validate path""" """Validate path"""
if path[:1] == "/" or path[-1] == "/": if path[:1] == "/" or path[-1] == "/":
@ -215,7 +201,7 @@ class UserSelfSerializer(ModelSerializer):
) )
def get_groups(self, _: User): def get_groups(self, _: User):
"""Return only the group names a user is member of""" """Return only the group names a user is member of"""
for group in self.instance.all_groups().order_by("name"): for group in self.instance.ak_groups.all():
yield { yield {
"name": group.name, "name": group.name,
"pk": group.pk, "pk": group.pk,
@ -314,11 +300,11 @@ class UsersFilter(FilterSet):
is_superuser = BooleanFilter(field_name="ak_groups", lookup_expr="is_superuser") is_superuser = BooleanFilter(field_name="ak_groups", lookup_expr="is_superuser")
uuid = UUIDFilter(field_name="uuid") uuid = UUIDFilter(field_name="uuid")
path = CharFilter(field_name="path") path = CharFilter(
field_name="path",
)
path_startswith = CharFilter(field_name="path", lookup_expr="startswith") path_startswith = CharFilter(field_name="path", lookup_expr="startswith")
type = MultipleChoiceFilter(choices=UserTypes.choices, field_name="type")
groups_by_name = ModelMultipleChoiceFilter( groups_by_name = ModelMultipleChoiceFilter(
field_name="ak_groups__name", field_name="ak_groups__name",
to_field_name="name", to_field_name="name",

View File

@ -1,21 +0,0 @@
"""Build source docs"""
from pathlib import Path
from django.core.management.base import BaseCommand
from pdoc import pdoc
from pdoc.render import configure
class Command(BaseCommand):
"""Build source docs"""
def handle(self, **options):
configure(
docformat="markdown",
mermaid=True,
logo="https://goauthentik.io/img/icon_top_brand_colour.svg",
)
pdoc(
"authentik",
output_directory=Path("./source_docs"),
)

View File

@ -26,6 +26,7 @@ class Command(BaseCommand):
no_color=False, no_color=False,
quiet=True, quiet=True,
optimization="fair", optimization="fair",
max_tasks_per_child=1,
autoscale=(3, 1), autoscale=(3, 1),
task_events=True, task_events=True,
beat=True, beat=True,

View File

@ -1,11 +1,55 @@
# Generated by Django 3.2.8 on 2021-10-10 16:16 # Generated by Django 3.2.8 on 2021-10-10 16:16
from os import environ
import django.db.models.deletion import django.db.models.deletion
from django.apps.registry import Apps
from django.conf import settings
from django.db import migrations, models from django.db import migrations, models
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
import authentik.core.models import authentik.core.models
def create_default_user(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
from django.contrib.auth.hashers import make_password
User = apps.get_model("authentik_core", "User")
db_alias = schema_editor.connection.alias
akadmin, _ = User.objects.using(db_alias).get_or_create(
username="akadmin",
email=environ.get("AUTHENTIK_BOOTSTRAP_EMAIL", "root@localhost"),
name="authentik Default Admin",
)
password = None
if "TF_BUILD" in environ or settings.TEST:
password = "akadmin" # noqa # nosec
if "AUTHENTIK_BOOTSTRAP_PASSWORD" in environ:
password = environ["AUTHENTIK_BOOTSTRAP_PASSWORD"]
if password:
akadmin.password = make_password(password)
else:
akadmin.password = make_password(None)
akadmin.save()
def create_default_admin_group(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
db_alias = schema_editor.connection.alias
Group = apps.get_model("authentik_core", "Group")
User = apps.get_model("authentik_core", "User")
# Creates a default admin group
group, _ = Group.objects.using(db_alias).get_or_create(
is_superuser=True,
defaults={
"name": "authentik Admins",
},
)
group.users.set(User.objects.filter(username="akadmin"))
group.save()
class Migration(migrations.Migration): class Migration(migrations.Migration):
replaces = [ replaces = [
("authentik_core", "0002_auto_20200523_1133"), ("authentik_core", "0002_auto_20200523_1133"),
@ -75,6 +119,9 @@ class Migration(migrations.Migration):
model_name="user", model_name="user",
name="is_staff", name="is_staff",
), ),
migrations.RunPython(
code=create_default_user,
),
migrations.AddField( migrations.AddField(
model_name="user", model_name="user",
name="is_superuser", name="is_superuser",
@ -154,6 +201,9 @@ class Migration(migrations.Migration):
default=False, help_text="Users added to this group will be superusers." default=False, help_text="Users added to this group will be superusers."
), ),
), ),
migrations.RunPython(
code=create_default_admin_group,
),
migrations.AlterModelManagers( migrations.AlterModelManagers(
name="user", name="user",
managers=[ managers=[

View File

@ -1,6 +1,7 @@
# Generated by Django 3.2.8 on 2021-10-10 16:12 # Generated by Django 3.2.8 on 2021-10-10 16:12
import uuid import uuid
from os import environ
import django.db.models.deletion import django.db.models.deletion
from django.apps.registry import Apps from django.apps.registry import Apps
@ -34,6 +35,29 @@ def fix_duplicates(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
Token.objects.using(db_alias).filter(identifier=ident["identifier"]).delete() Token.objects.using(db_alias).filter(identifier=ident["identifier"]).delete()
def create_default_user_token(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
from authentik.core.models import TokenIntents
User = apps.get_model("authentik_core", "User")
Token = apps.get_model("authentik_core", "Token")
db_alias = schema_editor.connection.alias
akadmin = User.objects.using(db_alias).filter(username="akadmin")
if not akadmin.exists():
return
if "AUTHENTIK_BOOTSTRAP_TOKEN" not in environ:
return
key = environ["AUTHENTIK_BOOTSTRAP_TOKEN"]
Token.objects.using(db_alias).create(
identifier="authentik-bootstrap-token",
user=akadmin.first(),
intent=TokenIntents.INTENT_API,
expiring=False,
key=key,
)
class Migration(migrations.Migration): class Migration(migrations.Migration):
replaces = [ replaces = [
("authentik_core", "0018_auto_20210330_1345"), ("authentik_core", "0018_auto_20210330_1345"),
@ -190,6 +214,9 @@ class Migration(migrations.Migration):
"verbose_name_plural": "Authenticated Sessions", "verbose_name_plural": "Authenticated Sessions",
}, },
), ),
migrations.RunPython(
code=create_default_user_token,
),
migrations.AlterField( migrations.AlterField(
model_name="token", model_name="token",
name="intent", name="intent",

View File

@ -60,7 +60,7 @@ def default_token_key():
"""Default token key""" """Default token key"""
# We use generate_id since the chars in the key should be easy # We use generate_id since the chars in the key should be easy
# to use in Emails (for verification) and URLs (for recovery) # to use in Emails (for verification) and URLs (for recovery)
return generate_id(CONFIG.get_int("default_token_length")) return generate_id(int(CONFIG.get("default_token_length")))
class UserTypes(models.TextChoices): class UserTypes(models.TextChoices):
@ -79,7 +79,7 @@ class UserTypes(models.TextChoices):
class Group(SerializerModel): class Group(SerializerModel):
"""Group model which supports a basic hierarchy and has attributes""" """Custom Group model which supports a basic hierarchy"""
group_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) group_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
@ -113,7 +113,27 @@ class Group(SerializerModel):
def is_member(self, user: "User") -> bool: def is_member(self, user: "User") -> bool:
"""Recursively check if `user` is member of us, or any parent.""" """Recursively check if `user` is member of us, or any parent."""
return user.all_groups().filter(group_uuid=self.group_uuid).exists() query = """
WITH RECURSIVE parents AS (
SELECT authentik_core_group.*, 0 AS relative_depth
FROM authentik_core_group
WHERE authentik_core_group.group_uuid = %s
UNION ALL
SELECT authentik_core_group.*, parents.relative_depth - 1
FROM authentik_core_group,parents
WHERE (
authentik_core_group.parent_id = parents.group_uuid and
parents.relative_depth > -20
)
)
SELECT group_uuid
FROM parents
GROUP BY group_uuid;
"""
groups = Group.objects.raw(query, [self.group_uuid])
return user.ak_groups.filter(pk__in=[group.pk for group in groups]).exists()
def __str__(self): def __str__(self):
return f"Group {self.name}" return f"Group {self.name}"
@ -128,15 +148,15 @@ class Group(SerializerModel):
class UserManager(DjangoUserManager): class UserManager(DjangoUserManager):
"""User manager that doesn't assign is_superuser and is_staff""" """Custom user manager that doesn't assign is_superuser and is_staff"""
def create_user(self, username, email=None, password=None, **extra_fields): def create_user(self, username, email=None, password=None, **extra_fields):
"""User manager that doesn't assign is_superuser and is_staff""" """Custom user manager that doesn't assign is_superuser and is_staff"""
return self._create_user(username, email, password, **extra_fields) return self._create_user(username, email, password, **extra_fields)
class User(SerializerModel, GuardianUserMixin, AbstractUser): class User(SerializerModel, GuardianUserMixin, AbstractUser):
"""authentik User model, based on django's contrib auth user model.""" """Custom User model to allow easier adding of user-based settings"""
uuid = models.UUIDField(default=uuid4, editable=False, unique=True) uuid = models.UUIDField(default=uuid4, editable=False, unique=True)
name = models.TextField(help_text=_("User's display name.")) name = models.TextField(help_text=_("User's display name."))
@ -156,45 +176,13 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser):
"""Get the default user path""" """Get the default user path"""
return User._meta.get_field("path").default return User._meta.get_field("path").default
def all_groups(self) -> QuerySet[Group]:
"""Recursively get all groups this user is a member of.
At least one query is done to get the direct groups of the user, with groups
there are at most 3 queries done"""
direct_groups = list(
x for x in self.ak_groups.all().values_list("pk", flat=True).iterator()
)
if len(direct_groups) < 1:
return Group.objects.none()
query = """
WITH RECURSIVE parents AS (
SELECT authentik_core_group.*, 0 AS relative_depth
FROM authentik_core_group
WHERE authentik_core_group.group_uuid = ANY(%s)
UNION ALL
SELECT authentik_core_group.*, parents.relative_depth + 1
FROM authentik_core_group, parents
WHERE (
authentik_core_group.group_uuid = parents.parent_id and
parents.relative_depth < 20
)
)
SELECT group_uuid
FROM parents
GROUP BY group_uuid, name
ORDER BY name;
"""
group_pks = [group.pk for group in Group.objects.raw(query, [direct_groups]).iterator()]
return Group.objects.filter(pk__in=group_pks)
def group_attributes(self, request: Optional[HttpRequest] = None) -> dict[str, Any]: def group_attributes(self, request: Optional[HttpRequest] = None) -> dict[str, Any]:
"""Get a dictionary containing the attributes from all groups the user belongs to, """Get a dictionary containing the attributes from all groups the user belongs to,
including the users attributes""" including the users attributes"""
final_attributes = {} final_attributes = {}
if request and hasattr(request, "tenant"): if request and hasattr(request, "tenant"):
always_merger.merge(final_attributes, request.tenant.attributes) always_merger.merge(final_attributes, request.tenant.attributes)
for group in self.all_groups().order_by("name"): for group in self.ak_groups.all().order_by("name"):
always_merger.merge(final_attributes, group.attributes) always_merger.merge(final_attributes, group.attributes)
always_merger.merge(final_attributes, self.attributes) always_merger.merge(final_attributes, self.attributes)
return final_attributes return final_attributes
@ -208,7 +196,7 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser):
@cached_property @cached_property
def is_superuser(self) -> bool: def is_superuser(self) -> bool:
"""Get supseruser status based on membership in a group with superuser status""" """Get supseruser status based on membership in a group with superuser status"""
return self.all_groups().filter(is_superuser=True).exists() return self.ak_groups.filter(is_superuser=True).exists()
@property @property
def is_staff(self) -> bool: def is_staff(self) -> bool:

View File

@ -78,6 +78,7 @@
</main> </main>
{% endblock %} {% endblock %}
<footer class="pf-c-login__footer"> <footer class="pf-c-login__footer">
<p></p>
<ul class="pf-c-list pf-m-inline"> <ul class="pf-c-list pf-m-inline">
{% for link in footer_links %} {% for link in footer_links %}
<li> <li>

View File

@ -13,9 +13,7 @@ class TestGroups(TestCase):
user = User.objects.create(username=generate_id()) user = User.objects.create(username=generate_id())
user2 = User.objects.create(username=generate_id()) user2 = User.objects.create(username=generate_id())
group = Group.objects.create(name=generate_id()) group = Group.objects.create(name=generate_id())
other_group = Group.objects.create(name=generate_id())
group.users.add(user) group.users.add(user)
other_group.users.add(user)
self.assertTrue(group.is_member(user)) self.assertTrue(group.is_member(user))
self.assertFalse(group.is_member(user2)) self.assertFalse(group.is_member(user2))
@ -23,26 +21,22 @@ class TestGroups(TestCase):
"""Test parent membership""" """Test parent membership"""
user = User.objects.create(username=generate_id()) user = User.objects.create(username=generate_id())
user2 = User.objects.create(username=generate_id()) user2 = User.objects.create(username=generate_id())
parent = Group.objects.create(name=generate_id()) first = Group.objects.create(name=generate_id())
child = Group.objects.create(name=generate_id(), parent=parent) second = Group.objects.create(name=generate_id(), parent=first)
child.users.add(user) second.users.add(user)
self.assertTrue(child.is_member(user)) self.assertTrue(first.is_member(user))
self.assertTrue(parent.is_member(user)) self.assertFalse(first.is_member(user2))
self.assertFalse(child.is_member(user2))
self.assertFalse(parent.is_member(user2))
def test_group_membership_parent_extra(self): def test_group_membership_parent_extra(self):
"""Test parent membership""" """Test parent membership"""
user = User.objects.create(username=generate_id()) user = User.objects.create(username=generate_id())
user2 = User.objects.create(username=generate_id()) user2 = User.objects.create(username=generate_id())
parent = Group.objects.create(name=generate_id()) first = Group.objects.create(name=generate_id())
second = Group.objects.create(name=generate_id(), parent=parent) second = Group.objects.create(name=generate_id(), parent=first)
third = Group.objects.create(name=generate_id(), parent=second) third = Group.objects.create(name=generate_id(), parent=second)
second.users.add(user) second.users.add(user)
self.assertTrue(parent.is_member(user)) self.assertTrue(first.is_member(user))
self.assertFalse(parent.is_member(user2)) self.assertFalse(first.is_member(user2))
self.assertTrue(second.is_member(user))
self.assertFalse(second.is_member(user2))
self.assertFalse(third.is_member(user)) self.assertFalse(third.is_member(user))
self.assertFalse(third.is_member(user2)) self.assertFalse(third.is_member(user2))

View File

@ -28,19 +28,6 @@ class TestUsersAPI(APITestCase):
self.admin = create_test_admin_user() self.admin = create_test_admin_user()
self.user = User.objects.create(username="test-user") self.user = User.objects.create(username="test-user")
def test_filter_type(self):
"""Test API filtering by type"""
self.client.force_login(self.admin)
user = create_test_admin_user(type=UserTypes.EXTERNAL)
response = self.client.get(
reverse("authentik_api:user-list"),
data={
"type": UserTypes.EXTERNAL,
"username": user.username,
},
)
self.assertEqual(response.status_code, 200)
def test_metrics(self): def test_metrics(self):
"""Test user's metrics""" """Test user's metrics"""
self.client.force_login(self.admin) self.client.force_login(self.admin)

View File

@ -21,7 +21,7 @@ def create_test_flow(
) )
def create_test_admin_user(name: Optional[str] = None, **kwargs) -> User: def create_test_admin_user(name: Optional[str] = None) -> User:
"""Generate a test-admin user""" """Generate a test-admin user"""
uid = generate_id(20) if not name else name uid = generate_id(20) if not name else name
group = Group.objects.create(name=uid, is_superuser=True) group = Group.objects.create(name=uid, is_superuser=True)
@ -29,7 +29,6 @@ def create_test_admin_user(name: Optional[str] = None, **kwargs) -> User:
username=uid, username=uid,
name=uid, name=uid,
email=f"{uid}@goauthentik.io", email=f"{uid}@goauthentik.io",
**kwargs,
) )
user.set_password(uid) user.set_password(uid)
user.save() user.save()
@ -37,12 +36,12 @@ def create_test_admin_user(name: Optional[str] = None, **kwargs) -> User:
return user return user
def create_test_tenant(**kwargs) -> Tenant: def create_test_tenant() -> Tenant:
"""Generate a test tenant, removing all other tenants to make sure this one """Generate a test tenant, removing all other tenants to make sure this one
matches.""" matches."""
uid = generate_id(20) uid = generate_id(20)
Tenant.objects.all().delete() Tenant.objects.all().delete()
return Tenant.objects.create(domain=uid, default=True, **kwargs) return Tenant.objects.create(domain=uid, default=True)
def create_test_cert(use_ec_private_key=False) -> CertificateKeyPair: def create_test_cert(use_ec_private_key=False) -> CertificateKeyPair:

View File

@ -189,8 +189,6 @@ class CertificateKeyPairFilter(FilterSet):
def filter_has_key(self, queryset, name, value): # pragma: no cover def filter_has_key(self, queryset, name, value): # pragma: no cover
"""Only return certificate-key pairs with keys""" """Only return certificate-key pairs with keys"""
if not value:
return queryset
return queryset.exclude(key_data__exact="") return queryset.exclude(key_data__exact="")
class Meta: class Meta:

View File

@ -128,26 +128,8 @@ class TestCrypto(APITestCase):
response = self.client.get( response = self.client.get(
reverse( reverse(
"authentik_api:certificatekeypair-list", "authentik_api:certificatekeypair-list",
), )
data={"name": cert.name}, + f"?name={cert.name}"
)
self.assertEqual(200, response.status_code)
body = loads(response.content.decode())
api_cert = [x for x in body["results"] if x["name"] == cert.name][0]
self.assertEqual(api_cert["fingerprint_sha1"], cert.fingerprint_sha1)
self.assertEqual(api_cert["fingerprint_sha256"], cert.fingerprint_sha256)
def test_list_has_key_false(self):
"""Test API List with has_key set to false"""
cert = create_test_cert()
cert.key_data = ""
cert.save()
self.client.force_login(create_test_admin_user())
response = self.client.get(
reverse(
"authentik_api:certificatekeypair-list",
),
data={"name": cert.name, "has_key": False},
) )
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
body = loads(response.content.decode()) body = loads(response.content.decode())
@ -162,8 +144,8 @@ class TestCrypto(APITestCase):
response = self.client.get( response = self.client.get(
reverse( reverse(
"authentik_api:certificatekeypair-list", "authentik_api:certificatekeypair-list",
), )
data={"name": cert.name, "include_details": False}, + f"?name={cert.name}&include_details=false"
) )
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
body = loads(response.content.decode()) body = loads(response.content.decode())
@ -186,8 +168,8 @@ class TestCrypto(APITestCase):
reverse( reverse(
"authentik_api:certificatekeypair-view-certificate", "authentik_api:certificatekeypair-view-certificate",
kwargs={"pk": keypair.pk}, kwargs={"pk": keypair.pk},
), )
data={"download": True}, + "?download",
) )
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
self.assertIn("Content-Disposition", response) self.assertIn("Content-Disposition", response)
@ -207,8 +189,8 @@ class TestCrypto(APITestCase):
reverse( reverse(
"authentik_api:certificatekeypair-view-private-key", "authentik_api:certificatekeypair-view-private-key",
kwargs={"pk": keypair.pk}, kwargs={"pk": keypair.pk},
), )
data={"download": True}, + "?download",
) )
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
self.assertIn("Content-Disposition", response) self.assertIn("Content-Disposition", response)
@ -218,7 +200,7 @@ class TestCrypto(APITestCase):
self.client.force_login(create_test_admin_user()) self.client.force_login(create_test_admin_user())
keypair = create_test_cert() keypair = create_test_cert()
provider = OAuth2Provider.objects.create( provider = OAuth2Provider.objects.create(
name=generate_id(), name="test",
client_id="test", client_id="test",
client_secret=generate_key(), client_secret=generate_key(),
authorization_flow=create_test_flow(), authorization_flow=create_test_flow(),

View File

@ -35,13 +35,13 @@ class LicenseSerializer(ModelSerializer):
"name", "name",
"key", "key",
"expiry", "expiry",
"internal_users", "users",
"external_users", "external_users",
] ]
extra_kwargs = { extra_kwargs = {
"name": {"read_only": True}, "name": {"read_only": True},
"expiry": {"read_only": True}, "expiry": {"read_only": True},
"internal_users": {"read_only": True}, "users": {"read_only": True},
"external_users": {"read_only": True}, "external_users": {"read_only": True},
} }
@ -49,7 +49,7 @@ class LicenseSerializer(ModelSerializer):
class LicenseSummary(PassiveSerializer): class LicenseSummary(PassiveSerializer):
"""Serializer for license status""" """Serializer for license status"""
internal_users = IntegerField(required=True) users = IntegerField(required=True)
external_users = IntegerField(required=True) external_users = IntegerField(required=True)
valid = BooleanField() valid = BooleanField()
show_admin_warning = BooleanField() show_admin_warning = BooleanField()
@ -62,9 +62,9 @@ class LicenseSummary(PassiveSerializer):
class LicenseForecastSerializer(PassiveSerializer): class LicenseForecastSerializer(PassiveSerializer):
"""Serializer for license forecast""" """Serializer for license forecast"""
internal_users = IntegerField(required=True) users = IntegerField(required=True)
external_users = IntegerField(required=True) external_users = IntegerField(required=True)
forecasted_internal_users = IntegerField(required=True) forecasted_users = IntegerField(required=True)
forecasted_external_users = IntegerField(required=True) forecasted_external_users = IntegerField(required=True)
@ -111,7 +111,7 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
latest_valid = datetime.fromtimestamp(total.exp) latest_valid = datetime.fromtimestamp(total.exp)
response = LicenseSummary( response = LicenseSummary(
data={ data={
"internal_users": total.internal_users, "users": total.users,
"external_users": total.external_users, "external_users": total.external_users,
"valid": total.is_valid(), "valid": total.is_valid(),
"show_admin_warning": show_admin_warning, "show_admin_warning": show_admin_warning,
@ -135,8 +135,8 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
def forecast(self, request: Request) -> Response: def forecast(self, request: Request) -> Response:
"""Forecast how many users will be required in a year""" """Forecast how many users will be required in a year"""
last_month = now() - timedelta(days=30) last_month = now() - timedelta(days=30)
# Forecast for internal users # Forecast for default users
internal_in_last_month = User.objects.filter( users_in_last_month = User.objects.filter(
type=UserTypes.INTERNAL, date_joined__gte=last_month type=UserTypes.INTERNAL, date_joined__gte=last_month
).count() ).count()
# Forecast for external users # Forecast for external users
@ -144,9 +144,9 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
forecast_for_months = 12 forecast_for_months = 12
response = LicenseForecastSerializer( response = LicenseForecastSerializer(
data={ data={
"internal_users": LicenseKey.get_default_user_count(), "users": LicenseKey.get_default_user_count(),
"external_users": LicenseKey.get_external_user_count(), "external_users": LicenseKey.get_external_user_count(),
"forecasted_internal_users": (internal_in_last_month * forecast_for_months), "forecasted_users": (users_in_last_month * forecast_for_months),
"forecasted_external_users": (external_in_last_month * forecast_for_months), "forecasted_external_users": (external_in_last_month * forecast_for_months),
} }
) )

View File

@ -1,36 +0,0 @@
# Generated by Django 4.2.4 on 2023-08-23 10:06
import django.contrib.postgres.indexes
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_enterprise", "0001_initial"),
]
operations = [
migrations.RenameField(
model_name="license",
old_name="users",
new_name="internal_users",
),
migrations.AlterField(
model_name="license",
name="key",
field=models.TextField(),
),
migrations.AddIndex(
model_name="license",
index=django.contrib.postgres.indexes.HashIndex(
fields=["key"], name="authentik_e_key_523e13_hash"
),
),
migrations.AlterModelOptions(
name="licenseusage",
options={
"verbose_name": "License Usage",
"verbose_name_plural": "License Usage Records",
},
),
]

View File

@ -11,11 +11,9 @@ from uuid import uuid4
from cryptography.exceptions import InvalidSignature from cryptography.exceptions import InvalidSignature
from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate
from dacite import from_dict from dacite import from_dict
from django.contrib.postgres.indexes import HashIndex
from django.db import models from django.db import models
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from django.utils.timezone import now from django.utils.timezone import now
from django.utils.translation import gettext as _
from guardian.shortcuts import get_anonymous_user from guardian.shortcuts import get_anonymous_user
from jwt import PyJWTError, decode, get_unverified_header from jwt import PyJWTError, decode, get_unverified_header
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
@ -48,8 +46,8 @@ class LicenseKey:
exp: int exp: int
name: str name: str
internal_users: int = 0 users: int
external_users: int = 0 external_users: int
flags: list[LicenseFlags] = field(default_factory=list) flags: list[LicenseFlags] = field(default_factory=list)
@staticmethod @staticmethod
@ -89,7 +87,7 @@ class LicenseKey:
active_licenses = License.objects.filter(expiry__gte=now()) active_licenses = License.objects.filter(expiry__gte=now())
total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0) total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
for lic in active_licenses: for lic in active_licenses:
total.internal_users += lic.internal_users total.users += lic.users
total.external_users += lic.external_users total.external_users += lic.external_users
exp_ts = int(mktime(lic.expiry.timetuple())) exp_ts = int(mktime(lic.expiry.timetuple()))
if total.exp == 0: if total.exp == 0:
@ -125,7 +123,7 @@ class LicenseKey:
Only checks the current count, no historical data is checked""" Only checks the current count, no historical data is checked"""
default_users = self.get_default_user_count() default_users = self.get_default_user_count()
if default_users > self.internal_users: if default_users > self.users:
return False return False
active_users = self.get_external_user_count() active_users = self.get_external_user_count()
if active_users > self.external_users: if active_users > self.external_users:
@ -155,11 +153,11 @@ class License(models.Model):
"""An authentik enterprise license""" """An authentik enterprise license"""
license_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) license_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
key = models.TextField() key = models.TextField(unique=True)
name = models.TextField() name = models.TextField()
expiry = models.DateTimeField() expiry = models.DateTimeField()
internal_users = models.BigIntegerField() users = models.BigIntegerField()
external_users = models.BigIntegerField() external_users = models.BigIntegerField()
@property @property
@ -167,9 +165,6 @@ class License(models.Model):
"""Get parsed license status""" """Get parsed license status"""
return LicenseKey.validate(self.key) return LicenseKey.validate(self.key)
class Meta:
indexes = (HashIndex(fields=("key",)),)
def usage_expiry(): def usage_expiry():
"""Keep license usage records for 3 months""" """Keep license usage records for 3 months"""
@ -188,7 +183,3 @@ class LicenseUsage(ExpiringModel):
within_limits = models.BooleanField() within_limits = models.BooleanField()
record_date = models.DateTimeField(auto_now_add=True) record_date = models.DateTimeField(auto_now_add=True)
class Meta:
verbose_name = _("License Usage")
verbose_name_plural = _("License Usage Records")

View File

@ -13,6 +13,6 @@ def pre_save_license(sender: type[License], instance: License, **_):
"""Extract data from license jwt and save it into model""" """Extract data from license jwt and save it into model"""
status = instance.status status = instance.status
instance.name = status.name instance.name = status.name
instance.internal_users = status.internal_users instance.users = status.users
instance.external_users = status.external_users instance.external_users = status.external_users
instance.expiry = datetime.fromtimestamp(status.exp, tz=get_current_timezone()) instance.expiry = datetime.fromtimestamp(status.exp, tz=get_current_timezone())

View File

@ -23,7 +23,7 @@ class TestEnterpriseLicense(TestCase):
aud="", aud="",
exp=_exp, exp=_exp,
name=generate_id(), name=generate_id(),
internal_users=100, users=100,
external_users=100, external_users=100,
) )
), ),
@ -32,7 +32,7 @@ class TestEnterpriseLicense(TestCase):
"""Check license verification""" """Check license verification"""
lic = License.objects.create(key=generate_id()) lic = License.objects.create(key=generate_id())
self.assertTrue(lic.status.is_valid()) self.assertTrue(lic.status.is_valid())
self.assertEqual(lic.internal_users, 100) self.assertEqual(lic.users, 100)
def test_invalid(self): def test_invalid(self):
"""Test invalid license""" """Test invalid license"""
@ -46,7 +46,7 @@ class TestEnterpriseLicense(TestCase):
aud="", aud="",
exp=_exp, exp=_exp,
name=generate_id(), name=generate_id(),
internal_users=100, users=100,
external_users=100, external_users=100,
) )
), ),
@ -58,7 +58,7 @@ class TestEnterpriseLicense(TestCase):
lic2 = License.objects.create(key=generate_id()) lic2 = License.objects.create(key=generate_id())
self.assertTrue(lic2.status.is_valid()) self.assertTrue(lic2.status.is_valid())
total = LicenseKey.get_total() total = LicenseKey.get_total()
self.assertEqual(total.internal_users, 200) self.assertEqual(total.users, 200)
self.assertEqual(total.external_users, 200) self.assertEqual(total.external_users, 200)
self.assertEqual(total.exp, _exp) self.assertEqual(total.exp, _exp)
self.assertTrue(total.is_valid()) self.assertTrue(total.is_valid())

View File

@ -4,7 +4,7 @@ from json import loads
import django_filters import django_filters
from django.db.models.aggregates import Count from django.db.models.aggregates import Count
from django.db.models.fields.json import KeyTextTransform, KeyTransform from django.db.models.fields.json import KeyTextTransform
from django.db.models.functions import ExtractDay from django.db.models.functions import ExtractDay
from drf_spectacular.types import OpenApiTypes from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema from drf_spectacular.utils import OpenApiParameter, extend_schema
@ -134,11 +134,11 @@ class EventViewSet(ModelViewSet):
"""Get the top_n events grouped by user count""" """Get the top_n events grouped by user count"""
filtered_action = request.query_params.get("action", EventAction.LOGIN) filtered_action = request.query_params.get("action", EventAction.LOGIN)
top_n = int(request.query_params.get("top_n", "15")) top_n = int(request.query_params.get("top_n", "15"))
events = ( return Response(
get_objects_for_user(request.user, "authentik_events.view_event") get_objects_for_user(request.user, "authentik_events.view_event")
.filter(action=filtered_action) .filter(action=filtered_action)
.exclude(context__authorized_application=None) .exclude(context__authorized_application=None)
.annotate(application=KeyTransform("authorized_application", "context")) .annotate(application=KeyTextTransform("authorized_application", "context"))
.annotate(user_pk=KeyTextTransform("pk", "user")) .annotate(user_pk=KeyTextTransform("pk", "user"))
.values("application") .values("application")
.annotate(counted_events=Count("application")) .annotate(counted_events=Count("application"))
@ -146,7 +146,6 @@ class EventViewSet(ModelViewSet):
.values("unique_users", "application", "counted_events") .values("unique_users", "application", "counted_events")
.order_by("-counted_events")[:top_n] .order_by("-counted_events")[:top_n]
) )
return Response(EventTopPerUserSerializer(instance=events, many=True).data)
@extend_schema( @extend_schema(
methods=["GET"], methods=["GET"],

View File

@ -39,7 +39,7 @@ class NotificationTransportSerializer(ModelSerializer):
mode = attrs.get("mode") mode = attrs.get("mode")
if mode in [TransportMode.WEBHOOK, TransportMode.WEBHOOK_SLACK]: if mode in [TransportMode.WEBHOOK, TransportMode.WEBHOOK_SLACK]:
if "webhook_url" not in attrs or attrs.get("webhook_url", "") == "": if "webhook_url" not in attrs or attrs.get("webhook_url", "") == "":
raise ValidationError({"webhook_url": "Webhook URL may not be empty."}) raise ValidationError("Webhook URL may not be empty.")
return attrs return attrs
class Meta: class Meta:

View File

@ -33,7 +33,7 @@ PLAN_CONTEXT_SOURCE = "source"
# Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan # Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan
# was restored. # was restored.
PLAN_CONTEXT_IS_RESTORED = "is_restored" PLAN_CONTEXT_IS_RESTORED = "is_restored"
CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_flows") CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_flows"))
CACHE_PREFIX = "goauthentik.io/flows/planner/" CACHE_PREFIX = "goauthentik.io/flows/planner/"

View File

@ -1,10 +0,0 @@
package lib
import _ "embed"
//go:embed default.yml
var defaultConfig []byte
func DefaultConfig() []byte {
return defaultConfig
}

View File

@ -213,14 +213,6 @@ class ConfigLoader:
attr: Attr = get_path_from_dict(root, path, sep=sep, default=Attr(default)) attr: Attr = get_path_from_dict(root, path, sep=sep, default=Attr(default))
return attr.value return attr.value
def get_int(self, path: str, default=0) -> int:
"""Wrapper for get that converts value into int"""
try:
return int(self.get(path, default))
except ValueError as exc:
self.log("warning", "Failed to parse config as int", path=path, exc=str(exc))
return default
def get_bool(self, path: str, default=False) -> bool: def get_bool(self, path: str, default=False) -> bool:
"""Wrapper for get that converts value into boolean""" """Wrapper for get that converts value into boolean"""
return str(self.get(path, default)).lower() == "true" return str(self.get(path, default)).lower() == "true"

View File

@ -11,11 +11,7 @@ postgresql:
listen: listen:
listen_http: 0.0.0.0:9000 listen_http: 0.0.0.0:9000
listen_https: 0.0.0.0:9443 listen_https: 0.0.0.0:9443
listen_ldap: 0.0.0.0:3389
listen_ldaps: 0.0.0.0:6636
listen_radius: 0.0.0.0:1812
listen_metrics: 0.0.0.0:9300 listen_metrics: 0.0.0.0:9300
listen_debug: 0.0.0.0:9900
trusted_proxy_cidrs: trusted_proxy_cidrs:
- 127.0.0.0/8 - 127.0.0.0/8
- 10.0.0.0/8 - 10.0.0.0/8
@ -36,9 +32,6 @@ redis:
cache_timeout_policies: 300 cache_timeout_policies: 300
cache_timeout_reputation: 300 cache_timeout_reputation: 300
paths:
media: ./media
debug: false debug: false
remote_debug: false remote_debug: false
@ -84,9 +77,6 @@ ldap:
tls: tls:
ciphers: null ciphers: null
reputation:
expiry: 86400
cookie_domain: null cookie_domain: null
disable_update_check: false disable_update_check: false
disable_startup_analytics: false disable_startup_analytics: false

View File

@ -112,7 +112,7 @@ class BaseEvaluator:
@staticmethod @staticmethod
def expr_is_group_member(user: User, **group_filters) -> bool: def expr_is_group_member(user: User, **group_filters) -> bool:
"""Check if `user` is member of group with name `group_name`""" """Check if `user` is member of group with name `group_name`"""
return user.all_groups().filter(**group_filters).exists() return user.ak_groups.filter(**group_filters).exists()
@staticmethod @staticmethod
def expr_user_by(**filters) -> Optional[User]: def expr_user_by(**filters) -> Optional[User]:

View File

@ -98,7 +98,7 @@ def traces_sampler(sampling_context: dict) -> float:
def before_send(event: dict, hint: dict) -> Optional[dict]: def before_send(event: dict, hint: dict) -> Optional[dict]:
"""Check if error is database error, and ignore if so""" """Check if error is database error, and ignore if so"""
# pylint: disable=no-name-in-module # pylint: disable=no-name-in-module
from psycopg.errors import Error from psycopg2.errors import Error
ignored_classes = ( ignored_classes = (
# Inbuilt types # Inbuilt types

View File

@ -79,15 +79,3 @@ class TestConfig(TestCase):
config.update_from_file(file2_name) config.update_from_file(file2_name)
unlink(file_name) unlink(file_name)
unlink(file2_name) unlink(file2_name)
def test_get_int(self):
"""Test get_int"""
config = ConfigLoader()
config.set("foo", 1234)
self.assertEqual(config.get_int("foo"), 1234)
def test_get_int_invalid(self):
"""Test get_int"""
config = ConfigLoader()
config.set("foo", "bar")
self.assertEqual(config.get_int("foo", 1234), 1234)

View File

@ -19,7 +19,7 @@ from authentik.policies.types import CACHE_PREFIX, PolicyRequest, PolicyResult
LOGGER = get_logger() LOGGER = get_logger()
FORK_CTX = get_context("fork") FORK_CTX = get_context("fork")
CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_policies") CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_policies"))
PROCESS_CLASS = FORK_CTX.Process PROCESS_CLASS = FORK_CTX.Process

View File

@ -1,7 +1,5 @@
"""Reputation policy API Views""" """Reputation policy API Views"""
from django.utils.translation import gettext_lazy as _
from rest_framework import mixins from rest_framework import mixins
from rest_framework.exceptions import ValidationError
from rest_framework.serializers import ModelSerializer from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import GenericViewSet, ModelViewSet from rest_framework.viewsets import GenericViewSet, ModelViewSet
@ -13,11 +11,6 @@ from authentik.policies.reputation.models import Reputation, ReputationPolicy
class ReputationPolicySerializer(PolicySerializer): class ReputationPolicySerializer(PolicySerializer):
"""Reputation Policy Serializer""" """Reputation Policy Serializer"""
def validate(self, attrs: dict) -> dict:
if not attrs.get("check_ip", False) and not attrs.get("check_username", False):
raise ValidationError(_("Either IP or Username must be checked"))
return super().validate(attrs)
class Meta: class Meta:
model = ReputationPolicy model = ReputationPolicy
fields = PolicySerializer.Meta.fields + [ fields = PolicySerializer.Meta.fields + [

View File

@ -1,33 +0,0 @@
# Generated by Django 4.2.4 on 2023-08-31 10:42
from django.db import migrations, models
import authentik.policies.reputation.models
class Migration(migrations.Migration):
dependencies = [
("authentik_policies_reputation", "0004_reputationpolicy_authentik_p_policy__8f0d70_idx"),
]
operations = [
migrations.AddField(
model_name="reputation",
name="expires",
field=models.DateTimeField(
default=authentik.policies.reputation.models.reputation_expiry
),
),
migrations.AddField(
model_name="reputation",
name="expiring",
field=models.BooleanField(default=True),
),
migrations.AlterModelOptions(
name="reputation",
options={
"verbose_name": "Reputation Score",
"verbose_name_plural": "Reputation Scores",
},
),
]

View File

@ -1,17 +1,13 @@
"""authentik reputation request policy""" """authentik reputation request policy"""
from datetime import timedelta
from uuid import uuid4 from uuid import uuid4
from django.db import models from django.db import models
from django.db.models import Sum from django.db.models import Sum
from django.db.models.query_utils import Q from django.db.models.query_utils import Q
from django.utils.timezone import now
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from rest_framework.serializers import BaseSerializer from rest_framework.serializers import BaseSerializer
from structlog import get_logger from structlog import get_logger
from authentik.core.models import ExpiringModel
from authentik.lib.config import CONFIG
from authentik.lib.models import SerializerModel from authentik.lib.models import SerializerModel
from authentik.lib.utils.http import get_client_ip from authentik.lib.utils.http import get_client_ip
from authentik.policies.models import Policy from authentik.policies.models import Policy
@ -21,11 +17,6 @@ LOGGER = get_logger()
CACHE_KEY_PREFIX = "goauthentik.io/policies/reputation/scores/" CACHE_KEY_PREFIX = "goauthentik.io/policies/reputation/scores/"
def reputation_expiry():
"""Reputation expiry"""
return now() + timedelta(seconds=CONFIG.get_int("reputation.expiry"))
class ReputationPolicy(Policy): class ReputationPolicy(Policy):
"""Return true if request IP/target username's score is below a certain threshold""" """Return true if request IP/target username's score is below a certain threshold"""
@ -68,7 +59,7 @@ class ReputationPolicy(Policy):
verbose_name_plural = _("Reputation Policies") verbose_name_plural = _("Reputation Policies")
class Reputation(ExpiringModel, SerializerModel): class Reputation(SerializerModel):
"""Reputation for user and or IP.""" """Reputation for user and or IP."""
reputation_uuid = models.UUIDField(primary_key=True, unique=True, default=uuid4) reputation_uuid = models.UUIDField(primary_key=True, unique=True, default=uuid4)
@ -78,8 +69,6 @@ class Reputation(ExpiringModel, SerializerModel):
ip_geo_data = models.JSONField(default=dict) ip_geo_data = models.JSONField(default=dict)
score = models.BigIntegerField(default=0) score = models.BigIntegerField(default=0)
expires = models.DateTimeField(default=reputation_expiry)
updated = models.DateTimeField(auto_now_add=True) updated = models.DateTimeField(auto_now_add=True)
@property @property
@ -92,6 +81,4 @@ class Reputation(ExpiringModel, SerializerModel):
return f"Reputation {self.identifier}/{self.ip} @ {self.score}" return f"Reputation {self.identifier}/{self.ip} @ {self.score}"
class Meta: class Meta:
verbose_name = _("Reputation Score")
verbose_name_plural = _("Reputation Scores")
unique_together = ("identifier", "ip") unique_together = ("identifier", "ip")

View File

@ -13,7 +13,7 @@ from authentik.policies.reputation.tasks import save_reputation
from authentik.stages.identification.signals import identification_failed from authentik.stages.identification.signals import identification_failed
LOGGER = get_logger() LOGGER = get_logger()
CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_reputation") CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_reputation"))
def update_score(request: HttpRequest, identifier: str, amount: int): def update_score(request: HttpRequest, identifier: str, amount: int):

View File

@ -3,8 +3,6 @@ from django.core.cache import cache
from django.test import RequestFactory, TestCase from django.test import RequestFactory, TestCase
from authentik.core.models import User from authentik.core.models import User
from authentik.lib.generators import generate_id
from authentik.policies.reputation.api import ReputationPolicySerializer
from authentik.policies.reputation.models import CACHE_KEY_PREFIX, Reputation, ReputationPolicy from authentik.policies.reputation.models import CACHE_KEY_PREFIX, Reputation, ReputationPolicy
from authentik.policies.reputation.tasks import save_reputation from authentik.policies.reputation.tasks import save_reputation
from authentik.policies.types import PolicyRequest from authentik.policies.types import PolicyRequest
@ -63,8 +61,3 @@ class TestReputationPolicy(TestCase):
name="reputation-test", threshold=0 name="reputation-test", threshold=0
) )
self.assertTrue(policy.passes(request).passing) self.assertTrue(policy.passes(request).passing)
def test_api(self):
"""Test API Validation"""
no_toggle = ReputationPolicySerializer(data={"name": generate_id(), "threshold": -5})
self.assertFalse(no_toggle.is_valid())

View File

@ -1,6 +1,6 @@
"""id_token utils""" """id_token utils"""
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional
from django.db import models from django.db import models
from django.http import HttpRequest from django.http import HttpRequest
@ -57,7 +57,7 @@ class IDToken:
# Subject, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.2 # Subject, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.2
sub: Optional[str] = None sub: Optional[str] = None
# Audience, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.3 # Audience, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.3
aud: Optional[Union[str, list[str]]] = None aud: Optional[str] = None
# Expiration time, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.4 # Expiration time, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.4
exp: Optional[int] = None exp: Optional[int] = None
# Issued at, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.6 # Issued at, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.6

View File

@ -2,7 +2,6 @@
import base64 import base64
import binascii import binascii
import json import json
from dataclasses import asdict
from functools import cached_property from functools import cached_property
from hashlib import sha256 from hashlib import sha256
from typing import Any, Optional from typing import Any, Optional
@ -359,7 +358,7 @@ class AccessToken(SerializerModel, ExpiringModel, BaseGrantModel):
@id_token.setter @id_token.setter
def id_token(self, value: IDToken): def id_token(self, value: IDToken):
self.token = value.to_access_token(self.provider) self.token = value.to_access_token(self.provider)
self._id_token = json.dumps(asdict(value)) self._id_token = json.dumps(value.to_dict())
@property @property
def at_hash(self): def at_hash(self):
@ -401,7 +400,7 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel):
@id_token.setter @id_token.setter
def id_token(self, value: IDToken): def id_token(self, value: IDToken):
self._id_token = json.dumps(asdict(value)) self._id_token = json.dumps(value.to_dict())
@property @property
def serializer(self) -> Serializer: def serializer(self) -> Serializer:

View File

@ -151,14 +151,6 @@ class TestTokenClientCredentials(OAuthTestCase):
) )
self.assertEqual(jwt["given_name"], self.user.name) self.assertEqual(jwt["given_name"], self.user.name)
self.assertEqual(jwt["preferred_username"], self.user.username) self.assertEqual(jwt["preferred_username"], self.user.username)
jwt = decode(
body["id_token"],
key=self.provider.signing_key.public_key,
algorithms=[alg],
audience=self.provider.client_id,
)
self.assertEqual(jwt["given_name"], self.user.name)
self.assertEqual(jwt["preferred_username"], self.user.username)
def test_successful_password(self): def test_successful_password(self):
"""test successful (password grant)""" """test successful (password grant)"""

View File

@ -375,9 +375,7 @@ class AuthorizationFlowInitView(PolicyAccessView):
): ):
self.request.session[SESSION_KEY_LAST_LOGIN_UID] = login_uid self.request.session[SESSION_KEY_LAST_LOGIN_UID] = login_uid
return self.handle_no_permission() return self.handle_no_permission()
scope_descriptions = UserInfoView().get_scope_descriptions( scope_descriptions = UserInfoView().get_scope_descriptions(self.params.scope)
self.params.scope, self.params.provider
)
# Regardless, we start the planner and return to it # Regardless, we start the planner and return to it
planner = FlowPlanner(self.provider.authorization_flow) planner = FlowPlanner(self.provider.authorization_flow)
planner.allow_empty_flows = True planner.allow_empty_flows = True

View File

@ -55,7 +55,7 @@ def validate_code(code: int, request: HttpRequest) -> Optional[HttpResponse]:
if not app: if not app:
return None return None
scope_descriptions = UserInfoView().get_scope_descriptions(token.scope, token.provider) scope_descriptions = UserInfoView().get_scope_descriptions(token.scope)
planner = FlowPlanner(token.provider.authorization_flow) planner = FlowPlanner(token.provider.authorization_flow)
planner.allow_empty_flows = True planner.allow_empty_flows = True
try: try:

View File

@ -40,14 +40,10 @@ class UserInfoView(View):
token: Optional[RefreshToken] token: Optional[RefreshToken]
def get_scope_descriptions( def get_scope_descriptions(self, scopes: list[str]) -> list[PermissionDict]:
self, scopes: list[str], provider: OAuth2Provider
) -> list[PermissionDict]:
"""Get a list of all Scopes's descriptions""" """Get a list of all Scopes's descriptions"""
scope_descriptions = [] scope_descriptions = []
for scope in ScopeMapping.objects.filter(scope_name__in=scopes, provider=provider).order_by( for scope in ScopeMapping.objects.filter(scope_name__in=scopes).order_by("scope_name"):
"scope_name"
):
scope_descriptions.append(PermissionDict(id=scope.scope_name, name=scope.description)) scope_descriptions.append(PermissionDict(id=scope.scope_name, name=scope.description))
# GitHub Compatibility Scopes are handled differently, since they required custom paths # GitHub Compatibility Scopes are handled differently, since they required custom paths
# Hence they don't exist as Scope objects # Hence they don't exist as Scope objects

View File

@ -59,9 +59,7 @@ class ProxyProviderSerializer(ProviderSerializer):
attrs.get("mode", ProxyMode.PROXY) == ProxyMode.PROXY attrs.get("mode", ProxyMode.PROXY) == ProxyMode.PROXY
and attrs.get("internal_host", "") == "" and attrs.get("internal_host", "") == ""
): ):
raise ValidationError( raise ValidationError(_("Internal host cannot be empty when forward auth is disabled."))
{"internal_host": _("Internal host cannot be empty when forward auth is disabled.")}
)
return attrs return attrs
def create(self, validated_data: dict): def create(self, validated_data: dict):

View File

@ -69,7 +69,7 @@ class ProxyProviderTests(APITestCase):
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
self.assertJSONEqual( self.assertJSONEqual(
response.content.decode(), response.content.decode(),
{"internal_host": ["Internal host cannot be empty when forward auth is disabled."]}, {"non_field_errors": ["Internal host cannot be empty when forward auth is disabled."]},
) )
def test_create_defaults(self): def test_create_defaults(self):

View File

@ -13,9 +13,10 @@ from rest_framework.decorators import action
from rest_framework.fields import CharField, FileField, SerializerMethodField from rest_framework.fields import CharField, FileField, SerializerMethodField
from rest_framework.parsers import MultiPartParser from rest_framework.parsers import MultiPartParser
from rest_framework.permissions import AllowAny from rest_framework.permissions import AllowAny
from rest_framework.relations import SlugRelatedField
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.serializers import PrimaryKeyRelatedField, ValidationError from rest_framework.serializers import ValidationError
from rest_framework.viewsets import ModelViewSet from rest_framework.viewsets import ModelViewSet
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -167,8 +168,10 @@ class SAMLProviderImportSerializer(PassiveSerializer):
"""Import saml provider from XML Metadata""" """Import saml provider from XML Metadata"""
name = CharField(required=True) name = CharField(required=True)
authorization_flow = PrimaryKeyRelatedField( # Using SlugField because https://github.com/OpenAPITools/openapi-generator/issues/3278
authorization_flow = SlugRelatedField(
queryset=Flow.objects.filter(designation=FlowDesignation.AUTHORIZATION), queryset=Flow.objects.filter(designation=FlowDesignation.AUTHORIZATION),
slug_field="slug",
) )
file = FileField() file = FileField()

View File

@ -89,7 +89,7 @@ class TestSAMLProviderAPI(APITestCase):
{ {
"file": metadata, "file": metadata,
"name": generate_id(), "name": generate_id(),
"authorization_flow": create_test_flow(FlowDesignation.AUTHORIZATION).pk, "authorization_flow": create_test_flow(FlowDesignation.AUTHORIZATION).slug,
}, },
format="multipart", format="multipart",
) )
@ -106,7 +106,7 @@ class TestSAMLProviderAPI(APITestCase):
{ {
"file": metadata, "file": metadata,
"name": generate_id(), "name": generate_id(),
"authorization_flow": create_test_flow().pk, "authorization_flow": create_test_flow().slug,
}, },
format="multipart", format="multipart",
) )

View File

@ -68,7 +68,7 @@ class SCIMClient(Generic[T, SchemaType]):
"""Get Service provider config""" """Get Service provider config"""
default_config = ServiceProviderConfiguration.default() default_config = ServiceProviderConfiguration.default()
try: try:
return ServiceProviderConfiguration.model_validate( return ServiceProviderConfiguration.parse_obj(
self._request("GET", "/ServiceProviderConfig") self._request("GET", "/ServiceProviderConfig")
) )
except (ValidationError, SCIMRequestException) as exc: except (ValidationError, SCIMRequestException) as exc:

View File

@ -74,7 +74,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
if not raw_scim_group: if not raw_scim_group:
raise StopSync(ValueError("No group mappings configured"), obj) raise StopSync(ValueError("No group mappings configured"), obj)
try: try:
scim_group = SCIMGroupSchema.model_validate(delete_none_values(raw_scim_group)) scim_group = SCIMGroupSchema.parse_obj(delete_none_values(raw_scim_group))
except ValidationError as exc: except ValidationError as exc:
raise StopSync(exc, obj) from exc raise StopSync(exc, obj) from exc
if not scim_group.externalId: if not scim_group.externalId:
@ -99,8 +99,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
response = self._request( response = self._request(
"POST", "POST",
"/Groups", "/Groups",
json=scim_group.model_dump( data=scim_group.json(
mode="json",
exclude_unset=True, exclude_unset=True,
), ),
) )
@ -114,8 +113,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
return self._request( return self._request(
"PUT", "PUT",
f"/Groups/{scim_group.id}", f"/Groups/{scim_group.id}",
json=scim_group.model_dump( data=scim_group.json(
mode="json",
exclude_unset=True, exclude_unset=True,
), ),
) )
@ -162,13 +160,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
*ops: PatchOperation, *ops: PatchOperation,
): ):
req = PatchRequest(Operations=ops) req = PatchRequest(Operations=ops)
self._request( self._request("PATCH", f"/Groups/{group_id}", data=req.json())
"PATCH",
f"/Groups/{group_id}",
json=req.model_dump(
mode="json",
),
)
def _patch_add_users(self, group: Group, users_set: set[int]): def _patch_add_users(self, group: Group, users_set: set[int]):
"""Add users in users_set to group""" """Add users in users_set to group"""

View File

@ -52,7 +52,7 @@ class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
class PatchRequest(BasePatchRequest): class PatchRequest(BasePatchRequest):
"""PatchRequest which correctly sets schemas""" """PatchRequest which correctly sets schemas"""
schemas: tuple[str] = ("urn:ietf:params:scim:api:messages:2.0:PatchOp",) schemas: tuple[str] = ["urn:ietf:params:scim:api:messages:2.0:PatchOp"]
class SCIMError(BaseSCIMError): class SCIMError(BaseSCIMError):

View File

@ -64,7 +64,7 @@ class SCIMUserClient(SCIMClient[User, SCIMUserSchema]):
if not raw_scim_user: if not raw_scim_user:
raise StopSync(ValueError("No user mappings configured"), obj) raise StopSync(ValueError("No user mappings configured"), obj)
try: try:
scim_user = SCIMUserSchema.model_validate(delete_none_values(raw_scim_user)) scim_user = SCIMUserSchema.parse_obj(delete_none_values(raw_scim_user))
except ValidationError as exc: except ValidationError as exc:
raise StopSync(exc, obj) from exc raise StopSync(exc, obj) from exc
if not scim_user.externalId: if not scim_user.externalId:
@ -77,8 +77,7 @@ class SCIMUserClient(SCIMClient[User, SCIMUserSchema]):
response = self._request( response = self._request(
"POST", "POST",
"/Users", "/Users",
json=scim_user.model_dump( data=scim_user.json(
mode="json",
exclude_unset=True, exclude_unset=True,
), ),
) )
@ -91,8 +90,7 @@ class SCIMUserClient(SCIMClient[User, SCIMUserSchema]):
self._request( self._request(
"PUT", "PUT",
f"/Users/{connection.id}", f"/Users/{connection.id}",
json=scim_user.model_dump( data=scim_user.json(
mode="json",
exclude_unset=True, exclude_unset=True,
), ),
) )

View File

@ -23,8 +23,6 @@ def post_save_provider(sender: type[Model], instance, created: bool, **_):
@receiver(post_save, sender=Group) @receiver(post_save, sender=Group)
def post_save_scim(sender: type[Model], instance: User | Group, created: bool, **_): def post_save_scim(sender: type[Model], instance: User | Group, created: bool, **_):
"""Post save handler""" """Post save handler"""
if not SCIMProvider.objects.filter(backchannel_application__isnull=False).exists():
return
scim_signal_direct.delay(class_to_path(instance.__class__), instance.pk, PatchOp.add.value) scim_signal_direct.delay(class_to_path(instance.__class__), instance.pk, PatchOp.add.value)
@ -32,8 +30,6 @@ def post_save_scim(sender: type[Model], instance: User | Group, created: bool, *
@receiver(pre_delete, sender=Group) @receiver(pre_delete, sender=Group)
def pre_delete_scim(sender: type[Model], instance: User | Group, **_): def pre_delete_scim(sender: type[Model], instance: User | Group, **_):
"""Pre-delete handler""" """Pre-delete handler"""
if not SCIMProvider.objects.filter(backchannel_application__isnull=False).exists():
return
scim_signal_direct.delay(class_to_path(instance.__class__), instance.pk, PatchOp.remove.value) scim_signal_direct.delay(class_to_path(instance.__class__), instance.pk, PatchOp.remove.value)
@ -44,8 +40,6 @@ def m2m_changed_scim(
"""Sync group membership""" """Sync group membership"""
if action not in ["post_add", "post_remove"]: if action not in ["post_add", "post_remove"]:
return return
if not SCIMProvider.objects.filter(backchannel_application__isnull=False).exists():
return
# reverse: instance is a Group, pk_set is a list of user pks # reverse: instance is a Group, pk_set is a list of user pks
# non-reverse: instance is a User, pk_set is a list of groups # non-reverse: instance is a User, pk_set is a list of groups
if reverse: if reverse:

View File

@ -47,7 +47,6 @@ class SCIMMembershipTests(TestCase):
def test_member_add(self): def test_member_add(self):
"""Test member add""" """Test member add"""
config = ServiceProviderConfiguration.default() config = ServiceProviderConfiguration.default()
# pylint: disable=assigning-non-slot
config.patch.supported = True config.patch.supported = True
user_scim_id = generate_id() user_scim_id = generate_id()
group_scim_id = generate_id() group_scim_id = generate_id()
@ -61,7 +60,7 @@ class SCIMMembershipTests(TestCase):
with Mocker() as mocker: with Mocker() as mocker:
mocker.get( mocker.get(
"https://localhost/ServiceProviderConfig", "https://localhost/ServiceProviderConfig",
json=config.model_dump(), json=config.dict(),
) )
mocker.post( mocker.post(
"https://localhost/Users", "https://localhost/Users",
@ -105,7 +104,7 @@ class SCIMMembershipTests(TestCase):
with Mocker() as mocker: with Mocker() as mocker:
mocker.get( mocker.get(
"https://localhost/ServiceProviderConfig", "https://localhost/ServiceProviderConfig",
json=config.model_dump(), json=config.dict(),
) )
mocker.patch( mocker.patch(
f"https://localhost/Groups/{group_scim_id}", f"https://localhost/Groups/{group_scim_id}",
@ -132,7 +131,6 @@ class SCIMMembershipTests(TestCase):
def test_member_remove(self): def test_member_remove(self):
"""Test member remove""" """Test member remove"""
config = ServiceProviderConfiguration.default() config = ServiceProviderConfiguration.default()
# pylint: disable=assigning-non-slot
config.patch.supported = True config.patch.supported = True
user_scim_id = generate_id() user_scim_id = generate_id()
group_scim_id = generate_id() group_scim_id = generate_id()
@ -146,7 +144,7 @@ class SCIMMembershipTests(TestCase):
with Mocker() as mocker: with Mocker() as mocker:
mocker.get( mocker.get(
"https://localhost/ServiceProviderConfig", "https://localhost/ServiceProviderConfig",
json=config.model_dump(), json=config.dict(),
) )
mocker.post( mocker.post(
"https://localhost/Users", "https://localhost/Users",
@ -190,7 +188,7 @@ class SCIMMembershipTests(TestCase):
with Mocker() as mocker: with Mocker() as mocker:
mocker.get( mocker.get(
"https://localhost/ServiceProviderConfig", "https://localhost/ServiceProviderConfig",
json=config.model_dump(), json=config.dict(),
) )
mocker.patch( mocker.patch(
f"https://localhost/Groups/{group_scim_id}", f"https://localhost/Groups/{group_scim_id}",
@ -217,7 +215,7 @@ class SCIMMembershipTests(TestCase):
with Mocker() as mocker: with Mocker() as mocker:
mocker.get( mocker.get(
"https://localhost/ServiceProviderConfig", "https://localhost/ServiceProviderConfig",
json=config.model_dump(), json=config.dict(),
) )
mocker.patch( mocker.patch(
f"https://localhost/Groups/{group_scim_id}", f"https://localhost/Groups/{group_scim_id}",

View File

@ -44,11 +44,7 @@ def config_loggers(*args, **kwargs):
def after_task_publish_hook(sender=None, headers=None, body=None, **kwargs): def after_task_publish_hook(sender=None, headers=None, body=None, **kwargs):
"""Log task_id after it was published""" """Log task_id after it was published"""
info = headers if "task" in headers else body info = headers if "task" in headers else body
LOGGER.info( LOGGER.info("Task published", task_id=info.get("id", ""), task_name=info.get("task", ""))
"Task published",
task_id=info.get("id", "").replace("-", ""),
task_name=info.get("task", ""),
)
@task_prerun.connect @task_prerun.connect
@ -63,9 +59,7 @@ def task_prerun_hook(task_id: str, task, *args, **kwargs):
def task_postrun_hook(task_id, task, *args, retval=None, state=None, **kwargs): def task_postrun_hook(task_id, task, *args, retval=None, state=None, **kwargs):
"""Log task_id on worker""" """Log task_id on worker"""
CTX_TASK_ID.set(...) CTX_TASK_ID.set(...)
LOGGER.info( LOGGER.info("Task finished", task_id=task_id, task_name=task.__name__, state=state)
"Task finished", task_id=task_id.replace("-", ""), task_name=task.__name__, state=state
)
@task_failure.connect @task_failure.connect

View File

@ -2,7 +2,7 @@
from functools import lru_cache from functools import lru_cache
from uuid import uuid4 from uuid import uuid4
from psycopg import connect from psycopg2 import connect
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
@ -30,7 +30,7 @@ def get_install_id_raw():
user=CONFIG.get("postgresql.user"), user=CONFIG.get("postgresql.user"),
password=CONFIG.get("postgresql.password"), password=CONFIG.get("postgresql.password"),
host=CONFIG.get("postgresql.host"), host=CONFIG.get("postgresql.host"),
port=CONFIG.get_int("postgresql.port"), port=int(CONFIG.get("postgresql.port")),
sslmode=CONFIG.get("postgresql.sslmode"), sslmode=CONFIG.get("postgresql.sslmode"),
sslrootcert=CONFIG.get("postgresql.sslrootcert"), sslrootcert=CONFIG.get("postgresql.sslrootcert"),
sslcert=CONFIG.get("postgresql.sslcert"), sslcert=CONFIG.get("postgresql.sslcert"),

View File

@ -190,14 +190,14 @@ if CONFIG.get_bool("redis.tls", False):
_redis_url = ( _redis_url = (
f"{_redis_protocol_prefix}:" f"{_redis_protocol_prefix}:"
f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:" f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:"
f"{CONFIG.get_int('redis.port')}" f"{int(CONFIG.get('redis.port'))}"
) )
CACHES = { CACHES = {
"default": { "default": {
"BACKEND": "django_redis.cache.RedisCache", "BACKEND": "django_redis.cache.RedisCache",
"LOCATION": f"{_redis_url}/{CONFIG.get('redis.db')}", "LOCATION": f"{_redis_url}/{CONFIG.get('redis.db')}",
"TIMEOUT": CONFIG.get_int("redis.cache_timeout", 300), "TIMEOUT": int(CONFIG.get("redis.cache_timeout", 300)),
"OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"}, "OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"},
"KEY_PREFIX": "authentik_cache", "KEY_PREFIX": "authentik_cache",
} }
@ -274,7 +274,7 @@ DATABASES = {
"NAME": CONFIG.get("postgresql.name"), "NAME": CONFIG.get("postgresql.name"),
"USER": CONFIG.get("postgresql.user"), "USER": CONFIG.get("postgresql.user"),
"PASSWORD": CONFIG.get("postgresql.password"), "PASSWORD": CONFIG.get("postgresql.password"),
"PORT": CONFIG.get_int("postgresql.port"), "PORT": int(CONFIG.get("postgresql.port")),
"SSLMODE": CONFIG.get("postgresql.sslmode"), "SSLMODE": CONFIG.get("postgresql.sslmode"),
"SSLROOTCERT": CONFIG.get("postgresql.sslrootcert"), "SSLROOTCERT": CONFIG.get("postgresql.sslrootcert"),
"SSLCERT": CONFIG.get("postgresql.sslcert"), "SSLCERT": CONFIG.get("postgresql.sslcert"),
@ -293,12 +293,12 @@ if CONFIG.get_bool("postgresql.use_pgbouncer", False):
# loads the config directly from CONFIG # loads the config directly from CONFIG
# See authentik/stages/email/models.py, line 105 # See authentik/stages/email/models.py, line 105
EMAIL_HOST = CONFIG.get("email.host") EMAIL_HOST = CONFIG.get("email.host")
EMAIL_PORT = CONFIG.get_int("email.port") EMAIL_PORT = int(CONFIG.get("email.port"))
EMAIL_HOST_USER = CONFIG.get("email.username") EMAIL_HOST_USER = CONFIG.get("email.username")
EMAIL_HOST_PASSWORD = CONFIG.get("email.password") EMAIL_HOST_PASSWORD = CONFIG.get("email.password")
EMAIL_USE_TLS = CONFIG.get_bool("email.use_tls", False) EMAIL_USE_TLS = CONFIG.get_bool("email.use_tls", False)
EMAIL_USE_SSL = CONFIG.get_bool("email.use_ssl", False) EMAIL_USE_SSL = CONFIG.get_bool("email.use_ssl", False)
EMAIL_TIMEOUT = CONFIG.get_int("email.timeout") EMAIL_TIMEOUT = int(CONFIG.get("email.timeout"))
DEFAULT_FROM_EMAIL = CONFIG.get("email.from") DEFAULT_FROM_EMAIL = CONFIG.get("email.from")
SERVER_EMAIL = DEFAULT_FROM_EMAIL SERVER_EMAIL = DEFAULT_FROM_EMAIL
EMAIL_SUBJECT_PREFIX = "[authentik] " EMAIL_SUBJECT_PREFIX = "[authentik] "
@ -411,7 +411,7 @@ LOGGING = {
"json": { "json": {
"()": structlog.stdlib.ProcessorFormatter, "()": structlog.stdlib.ProcessorFormatter,
"processor": structlog.processors.JSONRenderer(sort_keys=True), "processor": structlog.processors.JSONRenderer(sort_keys=True),
"foreign_pre_chain": LOG_PRE_CHAIN + [structlog.processors.dict_tracebacks], "foreign_pre_chain": LOG_PRE_CHAIN,
}, },
"console": { "console": {
"()": structlog.stdlib.ProcessorFormatter, "()": structlog.stdlib.ProcessorFormatter,

View File

@ -44,11 +44,7 @@ class LDAPSourceSerializer(SourceSerializer):
sources = sources.exclude(pk=self.instance.pk) sources = sources.exclude(pk=self.instance.pk)
if sources.exists(): if sources.exists():
raise ValidationError( raise ValidationError(
{ "Only a single LDAP Source with password synchronization is allowed"
"sync_users_password": (
"Only a single LDAP Source with password synchronization is allowed"
)
}
) )
return super().validate(attrs) return super().validate(attrs)

View File

@ -93,7 +93,7 @@ class BaseLDAPSynchronizer:
types_only=False, types_only=False,
get_operational_attributes=False, get_operational_attributes=False,
controls=None, controls=None,
paged_size=CONFIG.get_int("ldap.page_size", 50), paged_size=int(CONFIG.get("ldap.page_size", 50)),
paged_criticality=False, paged_criticality=False,
): ):
"""Search in pages, returns each page""" """Search in pages, returns each page"""

View File

@ -18,9 +18,6 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
return "groups" return "groups"
def get_objects(self, **kwargs) -> Generator: def get_objects(self, **kwargs) -> Generator:
if not self._source.sync_groups:
self.message("Group syncing is disabled for this Source")
return iter(())
return self.search_paginator( return self.search_paginator(
search_base=self.base_dn_groups, search_base=self.base_dn_groups,
search_filter=self._source.group_object_filter, search_filter=self._source.group_object_filter,

View File

@ -24,9 +24,6 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer):
return "membership" return "membership"
def get_objects(self, **kwargs) -> Generator: def get_objects(self, **kwargs) -> Generator:
if not self._source.sync_groups:
self.message("Group syncing is disabled for this Source")
return iter(())
return self.search_paginator( return self.search_paginator(
search_base=self.base_dn_groups, search_base=self.base_dn_groups,
search_filter=self._source.group_object_filter, search_filter=self._source.group_object_filter,

View File

@ -20,9 +20,6 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer):
return "users" return "users"
def get_objects(self, **kwargs) -> Generator: def get_objects(self, **kwargs) -> Generator:
if not self._source.sync_users:
self.message("User syncing is disabled for this Source")
return iter(())
return self.search_paginator( return self.search_paginator(
search_base=self.base_dn_users, search_base=self.base_dn_users,
search_filter=self._source.user_object_filter, search_filter=self._source.user_object_filter,

View File

@ -45,11 +45,7 @@ class FreeIPA(BaseLDAPSynchronizer):
# 389-ds and this will trigger regardless # 389-ds and this will trigger regardless
if "nsaccountlock" not in attributes: if "nsaccountlock" not in attributes:
return return
# For some reason, nsaccountlock is not defined properly in the schema as bool is_active = attributes.get("nsaccountlock", False)
# hence we get it as a list of strings
_is_active = str(self._flatten(attributes.get("nsaccountlock", ["FALSE"])))
# So we have to attempt to convert it to a bool
is_active = _is_active.lower() == "true"
if is_active != user.is_active: if is_active != user.is_active:
user.is_active = is_active user.is_active = is_active
user.save() user.save()

View File

@ -33,13 +33,7 @@ def ldap_sync_all():
ldap_sync_single(source.pk) ldap_sync_single(source.pk)
@CELERY_APP.task( @CELERY_APP.task()
# We take the configured hours timeout time by 2.5 as we run user and
# group in parallel and then membership, so 2x is to cover the serial tasks,
# and 0.5x on top of that to give some more leeway
soft_time_limit=(60 * 60 * CONFIG.get_int("ldap.task_timeout_hours")) * 2.5,
task_time_limit=(60 * 60 * CONFIG.get_int("ldap.task_timeout_hours")) * 2.5,
)
def ldap_sync_single(source_pk: str): def ldap_sync_single(source_pk: str):
"""Sync a single source""" """Sync a single source"""
source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first() source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first()
@ -65,7 +59,7 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) ->
signatures = [] signatures = []
for page in sync_inst.get_objects(): for page in sync_inst.get_objects():
page_cache_key = CACHE_KEY_PREFIX + str(uuid4()) page_cache_key = CACHE_KEY_PREFIX + str(uuid4())
cache.set(page_cache_key, page, 60 * 60 * CONFIG.get_int("ldap.task_timeout_hours")) cache.set(page_cache_key, page, 60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")))
page_sync = ldap_sync.si(source.pk, class_to_path(sync), page_cache_key) page_sync = ldap_sync.si(source.pk, class_to_path(sync), page_cache_key)
signatures.append(page_sync) signatures.append(page_sync)
return signatures return signatures
@ -74,12 +68,12 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) ->
@CELERY_APP.task( @CELERY_APP.task(
bind=True, bind=True,
base=MonitoredTask, base=MonitoredTask,
soft_time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"), soft_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")),
task_time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"), task_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")),
) )
def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str): def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str):
"""Synchronization of an LDAP Source""" """Synchronization of an LDAP Source"""
self.result_timeout_hours = CONFIG.get_int("ldap.task_timeout_hours") self.result_timeout_hours = int(CONFIG.get("ldap.task_timeout_hours"))
source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first() source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first()
if not source: if not source:
# Because the source couldn't be found, we don't have a UID # Because the source couldn't be found, we don't have a UID

View File

@ -1,111 +0,0 @@
"""ldap testing utils"""
from ldap3 import MOCK_SYNC, OFFLINE_DS389_1_3_3, Connection, Server
def mock_freeipa_connection(password: str) -> Connection:
"""Create mock FreeIPA-ish connection"""
server = Server("my_fake_server", get_info=OFFLINE_DS389_1_3_3)
_pass = "foo" # noqa # nosec
connection = Connection(
server,
user="cn=my_user,dc=goauthentik,dc=io",
password=_pass,
client_strategy=MOCK_SYNC,
)
# Entry for password checking
connection.strategy.add_entry(
"cn=user,ou=users,dc=goauthentik,dc=io",
{
"name": "test-user",
"uid": "unique-test-group",
"objectClass": "person",
"displayName": "Erin M. Hagens",
},
)
connection.strategy.add_entry(
"cn=group1,ou=groups,dc=goauthentik,dc=io",
{
"cn": "group1",
"uid": "unique-test-group",
"objectClass": "groupOfNames",
"member": ["cn=user0,ou=users,dc=goauthentik,dc=io"],
},
)
# Group without SID
connection.strategy.add_entry(
"cn=group2,ou=groups,dc=goauthentik,dc=io",
{
"cn": "group2",
"objectClass": "groupOfNames",
},
)
connection.strategy.add_entry(
"cn=user0,ou=users,dc=goauthentik,dc=io",
{
"userPassword": password,
"name": "user0_sn",
"uid": "user0_sn",
"objectClass": "person",
},
)
# User without SID
connection.strategy.add_entry(
"cn=user1,ou=users,dc=goauthentik,dc=io",
{
"userPassword": "test1111",
"name": "user1_sn",
"objectClass": "person",
},
)
# Duplicate users
connection.strategy.add_entry(
"cn=user2,ou=users,dc=goauthentik,dc=io",
{
"userPassword": "test2222",
"name": "user2_sn",
"uid": "unique-test2222",
"objectClass": "person",
},
)
connection.strategy.add_entry(
"cn=user3,ou=users,dc=goauthentik,dc=io",
{
"userPassword": "test2222",
"name": "user2_sn",
"uid": "unique-test2222",
"objectClass": "person",
},
)
# Group with posixGroup and memberUid
connection.strategy.add_entry(
"cn=group-posix,ou=groups,dc=goauthentik,dc=io",
{
"cn": "group-posix",
"objectClass": "posixGroup",
"memberUid": ["user-posix"],
},
)
# User with posixAccount
connection.strategy.add_entry(
"cn=user-posix,ou=users,dc=goauthentik,dc=io",
{
"userPassword": password,
"uid": "user-posix",
"cn": "user-posix",
"objectClass": "posixAccount",
},
)
# Locked out user
connection.strategy.add_entry(
"cn=user-nsaccountlock,ou=users,dc=goauthentik,dc=io",
{
"userPassword": password,
"uid": "user-nsaccountlock",
"cn": "user-nsaccountlock",
"objectClass": "person",
"nsaccountlock": ["TRUE"],
},
)
connection.bind()
return connection

View File

@ -4,7 +4,7 @@ from ldap3 import MOCK_SYNC, OFFLINE_SLAPD_2_4, Connection, Server
def mock_slapd_connection(password: str) -> Connection: def mock_slapd_connection(password: str) -> Connection:
"""Create mock SLAPD connection""" """Create mock AD connection"""
server = Server("my_fake_server", get_info=OFFLINE_SLAPD_2_4) server = Server("my_fake_server", get_info=OFFLINE_SLAPD_2_4)
_pass = "foo" # noqa # nosec _pass = "foo" # noqa # nosec
connection = Connection( connection = Connection(

View File

@ -17,7 +17,6 @@ from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer
from authentik.sources.ldap.sync.users import UserLDAPSynchronizer from authentik.sources.ldap.sync.users import UserLDAPSynchronizer
from authentik.sources.ldap.tasks import ldap_sync, ldap_sync_all from authentik.sources.ldap.tasks import ldap_sync, ldap_sync_all
from authentik.sources.ldap.tests.mock_ad import mock_ad_connection from authentik.sources.ldap.tests.mock_ad import mock_ad_connection
from authentik.sources.ldap.tests.mock_freeipa import mock_freeipa_connection
from authentik.sources.ldap.tests.mock_slapd import mock_slapd_connection from authentik.sources.ldap.tests.mock_slapd import mock_slapd_connection
LDAP_PASSWORD = generate_key() LDAP_PASSWORD = generate_key()
@ -121,23 +120,6 @@ class LDAPSyncTests(TestCase):
self.assertTrue(User.objects.filter(username="user0_sn").exists()) self.assertTrue(User.objects.filter(username="user0_sn").exists())
self.assertFalse(User.objects.filter(username="user1_sn").exists()) self.assertFalse(User.objects.filter(username="user1_sn").exists())
def test_sync_users_freeipa_ish(self):
"""Test user sync (FreeIPA-ish), mainly testing vendor quirks"""
self.source.object_uniqueness_field = "uid"
self.source.property_mappings.set(
LDAPPropertyMapping.objects.filter(
Q(managed__startswith="goauthentik.io/sources/ldap/default")
| Q(managed__startswith="goauthentik.io/sources/ldap/openldap")
)
)
self.source.save()
connection = MagicMock(return_value=mock_freeipa_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source)
user_sync.sync_full()
self.assertTrue(User.objects.filter(username="user0_sn").exists())
self.assertFalse(User.objects.filter(username="user1_sn").exists())
def test_sync_groups_ad(self): def test_sync_groups_ad(self):
"""Test group sync""" """Test group sync"""
self.source.property_mappings.set( self.source.property_mappings.set(

View File

@ -62,8 +62,7 @@ class OAuthSourceSerializer(SourceSerializer):
well_known_config = session.get(well_known) well_known_config = session.get(well_known)
well_known_config.raise_for_status() well_known_config.raise_for_status()
except RequestException as exc: except RequestException as exc:
text = exc.response.text if exc.response else str(exc) raise ValidationError(exc.response.text)
raise ValidationError({"oidc_well_known_url": text})
config = well_known_config.json() config = well_known_config.json()
try: try:
attrs["authorization_url"] = config["authorization_endpoint"] attrs["authorization_url"] = config["authorization_endpoint"]
@ -71,9 +70,7 @@ class OAuthSourceSerializer(SourceSerializer):
attrs["profile_url"] = config["userinfo_endpoint"] attrs["profile_url"] = config["userinfo_endpoint"]
attrs["oidc_jwks_url"] = config["jwks_uri"] attrs["oidc_jwks_url"] = config["jwks_uri"]
except (IndexError, KeyError) as exc: except (IndexError, KeyError) as exc:
raise ValidationError( raise ValidationError(f"Invalid well-known configuration: {exc}")
{"oidc_well_known_url": f"Invalid well-known configuration: {exc}"}
)
jwks_url = attrs.get("oidc_jwks_url") jwks_url = attrs.get("oidc_jwks_url")
if jwks_url and jwks_url != "": if jwks_url and jwks_url != "":
@ -81,8 +78,7 @@ class OAuthSourceSerializer(SourceSerializer):
jwks_config = session.get(jwks_url) jwks_config = session.get(jwks_url)
jwks_config.raise_for_status() jwks_config.raise_for_status()
except RequestException as exc: except RequestException as exc:
text = exc.response.text if exc.response else str(exc) raise ValidationError(exc.response.text)
raise ValidationError({"jwks_url": text})
config = jwks_config.json() config = jwks_config.json()
attrs["oidc_jwks"] = config attrs["oidc_jwks"] = config

View File

@ -30,7 +30,7 @@ def check_plex_token(self: MonitoredTask, source_slug: int):
self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, ["Plex token is valid."])) self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, ["Plex token is valid."]))
except RequestException as exc: except RequestException as exc:
error = exception_to_string(exc) error = exception_to_string(exc)
if len(source.plex_token) > 0: if len(source.plex_token) < 1:
error = error.replace(source.plex_token, "$PLEX_TOKEN") error = error.replace(source.plex_token, "$PLEX_TOKEN")
self.set_status( self.set_status(
TaskResult( TaskResult(

View File

@ -18,12 +18,7 @@ class AuthenticatorStaticStageSerializer(StageSerializer):
class Meta: class Meta:
model = AuthenticatorStaticStage model = AuthenticatorStaticStage
fields = StageSerializer.Meta.fields + [ fields = StageSerializer.Meta.fields + ["configure_flow", "friendly_name", "token_count"]
"configure_flow",
"friendly_name",
"token_count",
"token_length",
]
class AuthenticatorStaticStageViewSet(UsedByMixin, ModelViewSet): class AuthenticatorStaticStageViewSet(UsedByMixin, ModelViewSet):

View File

@ -1,22 +0,0 @@
# Generated by Django 4.2.4 on 2023-08-17 17:34
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_stages_authenticator_static", "0006_authenticatorstaticstage_friendly_name"),
]
operations = [
migrations.AddField(
model_name="authenticatorstaticstage",
name="token_length",
field=models.PositiveIntegerField(default=12),
),
migrations.AlterField(
model_name="authenticatorstaticstage",
name="token_count",
field=models.PositiveIntegerField(default=6),
),
]

Some files were not shown because too many files have changed in this diff Show More