Compare commits

..

17 Commits

SHA1 Message Date
1dc4fbbb2b Attempting coercion to ESM with Vite & Wdio 2024-08-09 10:55:44 -07:00
3332de267d Testing needs to be able to import from dependent packages. 2024-08-09 10:53:20 -07:00
ab366d0ec2 Testing needs to be able to import from dependent packages. 2024-08-09 10:50:51 -07:00
ac162582aa Modernization continues. 2024-08-09 10:43:28 -07:00
7d82e029d5 wdio does not need the storybook cssimport hack. 2024-08-09 10:37:17 -07:00
9b40ecb023 web: restricted all eslints to local checks only and blocked them from cache analysis 2024-08-09 09:56:55 -07:00
0cc0fdaae3 Not ready for primetime. 2024-08-09 09:35:46 -07:00
b55b168718 web: move common into its own package.
```
$ mkdir ./packages/common
$ git mv ./src/common ./packages/common/src
```

... and then added all of the boilerplate needed to drive the build with Wireit, lint with ESLint,
typecheck with TSC, spell-check the documentation and comments, run security checks on package.json
and package-lock.json, and format the source.

... and _then_ fixed all of the minor, nitpicky things ESLint 9 found in the package.

... and _then_ wired the whole thing into our build so that we can find it as a package, removing
it as an alias from the base package definition and turning it into a workspace. Although it is
a workspace package, it's currently configured to build completely independently.
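
For illustration, the wiring looks roughly like this (the package and workspace names below are a
sketch, not necessarily the exact ones in the repo): the base `package.json` declares the workspace,
and consumers depend on the package by name instead of resolving it through an alias.

```
{
  "name": "@goauthentik/web",
  "private": true,
  "workspaces": ["packages/common"],
  "dependencies": {
    "@goauthentik/common": "*"
  }
}
```

With npm workspaces, the `*` range resolves to the local `packages/common` checkout at install time.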

It could be published as an independent NPM package, although I don't recommend that at this time.

I've wanted to break the UI up into smaller, more digestible chunks for a while, but was always
reluctant to, since I didn't want to mess with other teams' mental models of the code layout.
@Beryju, seeing the success of the Simple Flow Executor as an independent package, thought it might
be worthwhile to see what effort it took to break the graph of our independent apps (User, Flow, and
Admin) and their dependencies (Common <- Elements <- Components, Common <- Locales) into packages.

Turns out, it's not too bad.  It's going to be fiddly for a while until things settle down, but
overall the experiment has been a success.

The `tsconfig.json` doesn't refer to the base because we want this package to build independently;
tooling will be needed in the future to keep all of our `tsconfig` files consistent across
packages.
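
A minimal sketch of what such a standalone `tsconfig.json` can look like (the specific compiler
options below are assumptions for illustration, not the file as committed):

```
{
  // Deliberately no "extends": the package must typecheck and build on its own.
  "compilerOptions": {
    "target": "ES2022",
    "module": "ES2022",
    "moduleResolution": "bundler",
    "declaration": true, // dependent packages consume the emitted .d.ts files
    "outDir": "./dist",
    "strict": true
  },
  "include": ["src"]
}
```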

- We can use the ESLint boilerplate as-is.
- We have to run TSC as a separate (but fortunately parallel) build step, as client code will need
  the built types. Final builds will be fractionally slower, but Wireit can detect when a monorepo
  package is unchanged and can skip rebuilding `common` if it's not needed, so the development loop
  will be faster (see the sketch after this list).
- The ESBuild boilerplate is different for libraries with UI, libraries without UI (like this one),
  and apps, and we'll have to have three different routines for them. Once we are building
  independent _apps_, getting them into the `dist` folder will be an interesting challenge; we may
  end up with two different builds, one to bundle it *in the app*, and another to bundle it *for
  Django*. That's mostly an issue of targeting and integration, and shouldn't take too much time.
- Spelling, formatting, and package checking aren't affected.
- `Locales` is our biggest challenge, as usual. I have found only [one article on it
  anywhere](https://medium.com/tech-at-zet/streamlining-localization-in-a-monorepo-using-i18n-js-e7c521ff69d4),
  and it recommends creating a single package in which to keep all of the localizations and the
  localization machinery. That seems like a sound approach, but we haven't (yet) gotten there.
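
To make the TSC/Wireit point above concrete, here is a rough sketch of the wiring (the script
names, relative path, and esbuild command are illustrative assumptions): the consuming package
declares a cross-package dependency on `common`'s build, and Wireit's `files`/`output`
fingerprinting is what lets it skip rebuilding `common` when nothing there has changed.

```
{
  "scripts": {
    "build": "wireit"
  },
  "wireit": {
    "build": {
      "command": "esbuild ./src/main.ts --bundle --outdir=./dist",
      "dependencies": ["./packages/common:build"],
      "files": ["src/**/*.ts", "tsconfig.json"],
      "output": ["dist/**"]
    }
  }
}
```

The `<path>:<script>` form is how Wireit expresses a dependency on a script defined in another
package; when that package's declared files and output are unchanged, the step is skipped.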

`common` is a bit of a junk drawer: there are global utilities in there, there are app-specific
helpers, there are plug-in specific helpers, and so on. Figuring out exactly what does what and
making more specific packages may be in our future.
2024-08-09 08:38:30 -07:00
c46dc8f290 Not sure how that happened. 2024-08-08 16:10:07 -07:00
e48da3520c Fix type checking issues at the TSC level. 2024-08-08 16:03:37 -07:00
1ec4652c60 Fix dependent types needed before attempting typecheck. 2024-08-08 15:51:09 -07:00
e375646705 Made linting the subpackages a requirement of success. 2024-08-08 15:47:54 -07:00
b84652d9d3 Fix eslint so it only lints the local package. Other packages have their own responsibilities. 2024-08-08 15:44:43 -07:00
74b8da28ca Added "common:build" to the list of dependencies for building, which is what you want, right? 2024-08-08 15:32:11 -07:00
9084c7c6b4 web: all the basic commands are working: build, build types, lint source, lint lockfile, lint packagefile, lint types, lint spelling, format source, format package. 2024-08-08 15:08:05 -07:00
7a0b227b46 Interim commit 2024-08-08 14:25:14 -07:00
cc9128fd46 Move begun; sfe cleanup completed. 2024-08-08 11:14:50 -07:00
799 changed files with 23311 additions and 52724 deletions

@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 2024.8.2 current_version = 2024.6.3
tag = True tag = True
commit = True commit = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))?

@ -29,9 +29,9 @@ outputs:
imageTags: imageTags:
description: "Docker image tags" description: "Docker image tags"
value: ${{ steps.ev.outputs.imageTags }} value: ${{ steps.ev.outputs.imageTags }}
attestImageNames: imageNames:
description: "Docker image names used for attestation" description: "Docker image names"
value: ${{ steps.ev.outputs.attestImageNames }} value: ${{ steps.ev.outputs.imageNames }}
imageMainTag: imageMainTag:
description: "Docker image main tag" description: "Docker image main tag"
value: ${{ steps.ev.outputs.imageMainTag }} value: ${{ steps.ev.outputs.imageMainTag }}

@ -51,24 +51,15 @@ else:
] ]
image_main_tag = image_tags[0].split(":")[-1] image_main_tag = image_tags[0].split(":")[-1]
image_tags_rendered = ",".join(image_tags)
image_names_rendered = ",".join(set(name.split(":")[0] for name in image_tags))
def get_attest_image_names(image_with_tags: list[str]):
"""Attestation only for GHCR"""
image_tags = []
for image_name in set(name.split(":")[0] for name in image_with_tags):
if not image_name.startswith("ghcr.io"):
continue
image_tags.append(image_name)
return ",".join(set(image_tags))
with open(os.environ["GITHUB_OUTPUT"], "a+", encoding="utf-8") as _output: with open(os.environ["GITHUB_OUTPUT"], "a+", encoding="utf-8") as _output:
print(f"shouldBuild={should_build}", file=_output) print(f"shouldBuild={should_build}", file=_output)
print(f"sha={sha}", file=_output) print(f"sha={sha}", file=_output)
print(f"version={version}", file=_output) print(f"version={version}", file=_output)
print(f"prerelease={prerelease}", file=_output) print(f"prerelease={prerelease}", file=_output)
print(f"imageTags={','.join(image_tags)}", file=_output) print(f"imageTags={image_tags_rendered}", file=_output)
print(f"attestImageNames={get_attest_image_names(image_tags)}", file=_output) print(f"imageNames={image_names_rendered}", file=_output)
print(f"imageMainTag={image_main_tag}", file=_output) print(f"imageMainTag={image_main_tag}", file=_output)
print(f"imageMainName={image_tags[0]}", file=_output) print(f"imageMainName={image_tags[0]}", file=_output)

@ -44,11 +44,9 @@ updates:
- "babel-*" - "babel-*"
eslint: eslint:
patterns: patterns:
- "@eslint/*"
- "@typescript-eslint/*" - "@typescript-eslint/*"
- "eslint-*"
- "eslint" - "eslint"
- "typescript-eslint" - "eslint-*"
storybook: storybook:
patterns: patterns:
- "@storybook/*" - "@storybook/*"
@ -56,16 +54,10 @@ updates:
esbuild: esbuild:
patterns: patterns:
- "@esbuild/*" - "@esbuild/*"
- "esbuild*"
rollup: rollup:
patterns: patterns:
- "@rollup/*" - "@rollup/*"
- "rollup-*" - "rollup-*"
- "rollup*"
swc:
patterns:
- "@swc/*"
- "swc-*"
wdio: wdio:
patterns: patterns:
- "@wdio/*" - "@wdio/*"

@ -40,7 +40,7 @@ jobs:
run: | run: |
export VERSION=`node -e 'console.log(require("../gen-ts-api/package.json").version)'` export VERSION=`node -e 'console.log(require("../gen-ts-api/package.json").version)'`
npm i @goauthentik/api@$VERSION npm i @goauthentik/api@$VERSION
- uses: peter-evans/create-pull-request@v7 - uses: peter-evans/create-pull-request@v6
id: cpr id: cpr
with: with:
token: ${{ steps.generate_token.outputs.token }} token: ${{ steps.generate_token.outputs.token }}

@ -120,12 +120,6 @@ jobs:
with: with:
flags: unit flags: unit
token: ${{ secrets.CODECOV_TOKEN }} token: ${{ secrets.CODECOV_TOKEN }}
- if: ${{ !cancelled() }}
uses: codecov/test-results-action@v1
with:
flags: unit
file: unittest.xml
token: ${{ secrets.CODECOV_TOKEN }}
test-integration: test-integration:
runs-on: ubuntu-latest runs-on: ubuntu-latest
timeout-minutes: 30 timeout-minutes: 30
@ -144,12 +138,6 @@ jobs:
with: with:
flags: integration flags: integration
token: ${{ secrets.CODECOV_TOKEN }} token: ${{ secrets.CODECOV_TOKEN }}
- if: ${{ !cancelled() }}
uses: codecov/test-results-action@v1
with:
flags: integration
file: unittest.xml
token: ${{ secrets.CODECOV_TOKEN }}
test-e2e: test-e2e:
name: test-e2e (${{ matrix.job.name }}) name: test-e2e (${{ matrix.job.name }})
runs-on: ubuntu-latest runs-on: ubuntu-latest
@ -202,12 +190,6 @@ jobs:
with: with:
flags: e2e flags: e2e
token: ${{ secrets.CODECOV_TOKEN }} token: ${{ secrets.CODECOV_TOKEN }}
- if: ${{ !cancelled() }}
uses: codecov/test-results-action@v1
with:
flags: e2e
file: unittest.xml
token: ${{ secrets.CODECOV_TOKEN }}
ci-core-mark: ci-core-mark:
needs: needs:
- lint - lint
@ -279,7 +261,7 @@ jobs:
id: attest id: attest
if: ${{ steps.ev.outputs.shouldBuild == 'true' }} if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
with: with:
subject-name: ${{ steps.ev.outputs.attestImageNames }} subject-name: ${{ steps.ev.outputs.imageNames }}
subject-digest: ${{ steps.push.outputs.digest }} subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true push-to-registry: true
pr-comment: pr-comment:

@ -31,7 +31,7 @@ jobs:
- name: golangci-lint - name: golangci-lint
uses: golangci/golangci-lint-action@v6 uses: golangci/golangci-lint-action@v6
with: with:
version: latest version: v1.54.2
args: --timeout 5000s --verbose args: --timeout 5000s --verbose
skip-cache: true skip-cache: true
test-unittest: test-unittest:
@ -115,7 +115,7 @@ jobs:
id: attest id: attest
if: ${{ steps.ev.outputs.shouldBuild == 'true' }} if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
with: with:
subject-name: ${{ steps.ev.outputs.attestImageNames }} subject-name: ${{ steps.ev.outputs.imageNames }}
subject-digest: ${{ steps.push.outputs.digest }} subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true push-to-registry: true
build-binary: build-binary:

@ -45,6 +45,7 @@ jobs:
- working-directory: ${{ matrix.project }}/ - working-directory: ${{ matrix.project }}/
run: | run: |
npm ci npm ci
${{ matrix.extra_setup }}
- name: Generate API - name: Generate API
run: make gen-client-ts run: make gen-client-ts
- name: Lint - name: Lint
@ -91,4 +92,4 @@ jobs:
run: make gen-client-ts run: make gen-client-ts
- name: test - name: test
working-directory: web/ working-directory: web/
run: npm run test || exit 0 run: npm run test

@ -24,7 +24,7 @@ jobs:
- name: Setup authentik env - name: Setup authentik env
uses: ./.github/actions/setup uses: ./.github/actions/setup
- run: poetry run ak update_webauthn_mds - run: poetry run ak update_webauthn_mds
- uses: peter-evans/create-pull-request@v7 - uses: peter-evans/create-pull-request@v6
id: cpr id: cpr
with: with:
token: ${{ steps.generate_token.outputs.token }} token: ${{ steps.generate_token.outputs.token }}

@ -42,7 +42,7 @@ jobs:
with: with:
githubToken: ${{ steps.generate_token.outputs.token }} githubToken: ${{ steps.generate_token.outputs.token }}
compressOnly: ${{ github.event_name != 'pull_request' }} compressOnly: ${{ github.event_name != 'pull_request' }}
- uses: peter-evans/create-pull-request@v7 - uses: peter-evans/create-pull-request@v6
if: "${{ github.event_name != 'pull_request' && steps.compress.outputs.markdown != '' }}" if: "${{ github.event_name != 'pull_request' && steps.compress.outputs.markdown != '' }}"
id: cpr id: cpr
with: with:

@ -51,14 +51,12 @@ jobs:
secrets: | secrets: |
GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }} GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }}
GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }} GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }}
build-args: |
VERSION=${{ github.ref }}
tags: ${{ steps.ev.outputs.imageTags }} tags: ${{ steps.ev.outputs.imageTags }}
platforms: linux/amd64,linux/arm64 platforms: linux/amd64,linux/arm64
- uses: actions/attest-build-provenance@v1 - uses: actions/attest-build-provenance@v1
id: attest id: attest
with: with:
subject-name: ${{ steps.ev.outputs.attestImageNames }} subject-name: ${{ steps.ev.outputs.imageNames }}
subject-digest: ${{ steps.push.outputs.digest }} subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true push-to-registry: true
build-outpost: build-outpost:
@ -113,8 +111,6 @@ jobs:
id: push id: push
with: with:
push: true push: true
build-args: |
VERSION=${{ github.ref }}
tags: ${{ steps.ev.outputs.imageTags }} tags: ${{ steps.ev.outputs.imageTags }}
file: ${{ matrix.type }}.Dockerfile file: ${{ matrix.type }}.Dockerfile
platforms: linux/amd64,linux/arm64 platforms: linux/amd64,linux/arm64
@ -122,7 +118,7 @@ jobs:
- uses: actions/attest-build-provenance@v1 - uses: actions/attest-build-provenance@v1
id: attest id: attest
with: with:
subject-name: ${{ steps.ev.outputs.attestImageNames }} subject-name: ${{ steps.ev.outputs.imageNames }}
subject-digest: ${{ steps.push.outputs.digest }} subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true push-to-registry: true
build-outpost-binary: build-outpost-binary:

@ -32,7 +32,7 @@ jobs:
poetry run ak compilemessages poetry run ak compilemessages
make web-check-compile make web-check-compile
- name: Create Pull Request - name: Create Pull Request
uses: peter-evans/create-pull-request@v7 uses: peter-evans/create-pull-request@v6
with: with:
token: ${{ steps.generate_token.outputs.token }} token: ${{ steps.generate_token.outputs.token }}
branch: extract-compile-backend-translation branch: extract-compile-backend-translation

@ -1,7 +1,7 @@
# syntax=docker/dockerfile:1 # syntax=docker/dockerfile:1
# Stage 1: Build website # Stage 1: Build website
FROM --platform=${BUILDPLATFORM} docker.io/library/node:22 AS website-builder FROM --platform=${BUILDPLATFORM} docker.io/library/node:22 as website-builder
ENV NODE_ENV=production ENV NODE_ENV=production
@ -20,7 +20,7 @@ COPY ./SECURITY.md /work/
RUN npm run build-bundled RUN npm run build-bundled
# Stage 2: Build webui # Stage 2: Build webui
FROM --platform=${BUILDPLATFORM} docker.io/library/node:22 AS web-builder FROM --platform=${BUILDPLATFORM} docker.io/library/node:22 as web-builder
ARG GIT_BUILD_HASH ARG GIT_BUILD_HASH
ENV GIT_BUILD_HASH=$GIT_BUILD_HASH ENV GIT_BUILD_HASH=$GIT_BUILD_HASH
@ -43,7 +43,7 @@ COPY ./gen-ts-api /work/web/node_modules/@goauthentik/api
RUN npm run build RUN npm run build
# Stage 3: Build go proxy # Stage 3: Build go proxy
FROM --platform=${BUILDPLATFORM} mcr.microsoft.com/oss/go/microsoft/golang:1.23-fips-bookworm AS go-builder FROM --platform=${BUILDPLATFORM} mcr.microsoft.com/oss/go/microsoft/golang:1.22-fips-bookworm AS go-builder
ARG TARGETOS ARG TARGETOS
ARG TARGETARCH ARG TARGETARCH
@ -80,7 +80,7 @@ RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \
go build -o /go/authentik ./cmd/server go build -o /go/authentik ./cmd/server
# Stage 4: MaxMind GeoIP # Stage 4: MaxMind GeoIP
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.0.1 AS geoip FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.0.1 as geoip
ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City GeoLite2-ASN" ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City GeoLite2-ASN"
ENV GEOIPUPDATE_VERBOSE="1" ENV GEOIPUPDATE_VERBOSE="1"
@ -94,10 +94,7 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
/bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0" /bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
# Stage 5: Python dependencies # Stage 5: Python dependencies
FROM ghcr.io/goauthentik/fips-python:3.12.6-slim-bookworm-fips-full AS python-deps FROM ghcr.io/goauthentik/fips-python:3.12.3-slim-bookworm-fips-full AS python-deps
ARG TARGETARCH
ARG TARGETVARIANT
WORKDIR /ak-root/poetry WORKDIR /ak-root/poetry
@ -124,17 +121,17 @@ RUN --mount=type=bind,target=./pyproject.toml,src=./pyproject.toml \
pip install --force-reinstall /wheels/*" pip install --force-reinstall /wheels/*"
# Stage 6: Run # Stage 6: Run
FROM ghcr.io/goauthentik/fips-python:3.12.6-slim-bookworm-fips-full AS final-image FROM ghcr.io/goauthentik/fips-python:3.12.3-slim-bookworm-fips-full AS final-image
ARG VERSION
ARG GIT_BUILD_HASH ARG GIT_BUILD_HASH
ARG VERSION
ENV GIT_BUILD_HASH=$GIT_BUILD_HASH ENV GIT_BUILD_HASH=$GIT_BUILD_HASH
LABEL org.opencontainers.image.url=https://goauthentik.io LABEL org.opencontainers.image.url https://goauthentik.io
LABEL org.opencontainers.image.description="goauthentik.io Main server image, see https://goauthentik.io for more info." LABEL org.opencontainers.image.description goauthentik.io Main server image, see https://goauthentik.io for more info.
LABEL org.opencontainers.image.source=https://github.com/goauthentik/authentik LABEL org.opencontainers.image.source https://github.com/goauthentik/authentik
LABEL org.opencontainers.image.version=${VERSION} LABEL org.opencontainers.image.version ${VERSION}
LABEL org.opencontainers.image.revision=${GIT_BUILD_HASH} LABEL org.opencontainers.image.revision ${GIT_BUILD_HASH}
WORKDIR / WORKDIR /

@ -43,7 +43,7 @@ help: ## Show this help
sort sort
@echo "" @echo ""
go-test: test-go:
go test -timeout 0 -v -race -cover ./... go test -timeout 0 -v -race -cover ./...
test-docker: ## Run all tests in a docker-compose test-docker: ## Run all tests in a docker-compose
@ -210,9 +210,6 @@ web: web-lint-fix web-lint web-check-compile ## Automatically fix formatting is
web-install: ## Install the necessary libraries to build the Authentik UI web-install: ## Install the necessary libraries to build the Authentik UI
cd web && npm ci cd web && npm ci
web-test: ## Run tests for the Authentik UI
cd web && npm run test
web-watch: ## Build and watch the Authentik UI for changes, updating automatically web-watch: ## Build and watch the Authentik UI for changes, updating automatically
rm -rf web/dist/ rm -rf web/dist/
mkdir web/dist/ mkdir web/dist/

@ -15,9 +15,7 @@
## What is authentik? ## What is authentik?
authentik is an open-source Identity Provider that emphasizes flexibility and versatility, with support for a wide set of protocols. authentik is an open-source Identity Provider that emphasizes flexibility and versatility. It can be seamlessly integrated into existing environments to support new protocols. authentik is also a great solution for implementing sign-up, recovery, and other similar features in your application, saving you the hassle of dealing with them.
Our [enterprise offer](https://goauthentik.io/pricing) can also be used as a self-hosted replacement for large-scale deployments of Okta/Auth0, Entra ID, Ping Identity, or other legacy IdPs for employees and B2B2C use.
## Installation ## Installation

@ -20,8 +20,8 @@ Even if the issue is not a CVE, we still greatly appreciate your help in hardeni
| Version | Supported | | Version | Supported |
| -------- | --------- | | -------- | --------- |
| 2024.4.x | ✅ |
| 2024.6.x | ✅ | | 2024.6.x | ✅ |
| 2024.8.x | ✅ |
## Reporting a Vulnerability ## Reporting a Vulnerability

@ -2,7 +2,7 @@
from os import environ from os import environ
__version__ = "2024.8.2" __version__ = "2024.6.3"
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"

@ -1,20 +0,0 @@
"""authentik admin analytics"""
from typing import Any
from django.utils.translation import gettext_lazy as _
from authentik.root.celery import CELERY_APP
def get_analytics_description() -> dict[str, str]:
return {
"worker_count": _("Number of running workers"),
}
def get_analytics_data() -> dict[str, Any]:
worker_count = len(CELERY_APP.control.ping(timeout=0.5))
return {
"worker_count": worker_count,
}

@ -73,7 +73,7 @@ class SystemInfoSerializer(PassiveSerializer):
"authentik_version": get_full_version(), "authentik_version": get_full_version(),
"environment": get_env(), "environment": get_env(),
"openssl_fips_enabled": ( "openssl_fips_enabled": (
backend._fips_enabled if LicenseKey.get_total().status().is_valid else None backend._fips_enabled if LicenseKey.get_total().is_valid() else None
), ),
"openssl_version": OPENSSL_VERSION, "openssl_version": OPENSSL_VERSION,
"platform": platform.platform(), "platform": platform.platform(),

@ -12,7 +12,6 @@ from rest_framework.views import APIView
from authentik import __version__, get_build_hash from authentik import __version__, get_build_hash
from authentik.admin.tasks import VERSION_CACHE_KEY, VERSION_NULL, update_latest_version from authentik.admin.tasks import VERSION_CACHE_KEY, VERSION_NULL, update_latest_version
from authentik.core.api.utils import PassiveSerializer from authentik.core.api.utils import PassiveSerializer
from authentik.outposts.models import Outpost
class VersionSerializer(PassiveSerializer): class VersionSerializer(PassiveSerializer):
@ -23,7 +22,6 @@ class VersionSerializer(PassiveSerializer):
version_latest_valid = SerializerMethodField() version_latest_valid = SerializerMethodField()
build_hash = SerializerMethodField() build_hash = SerializerMethodField()
outdated = SerializerMethodField() outdated = SerializerMethodField()
outpost_outdated = SerializerMethodField()
def get_build_hash(self, _) -> str: def get_build_hash(self, _) -> str:
"""Get build hash, if version is not latest or released""" """Get build hash, if version is not latest or released"""
@ -49,15 +47,6 @@ class VersionSerializer(PassiveSerializer):
"""Check if we're running the latest version""" """Check if we're running the latest version"""
return parse(self.get_version_current(instance)) < parse(self.get_version_latest(instance)) return parse(self.get_version_current(instance)) < parse(self.get_version_latest(instance))
def get_outpost_outdated(self, _) -> bool:
"""Check if any outpost is outdated/has a version mismatch"""
any_outdated = False
for outpost in Outpost.objects.all():
for state in outpost.state:
if state.version_outdated:
any_outdated = True
return any_outdated
class VersionView(APIView): class VersionView(APIView):
"""Get running and latest version.""" """Get running and latest version."""

@ -1,8 +1,10 @@
"""authentik admin tasks""" """authentik admin tasks"""
import re
from django.core.cache import cache from django.core.cache import cache
from django.core.validators import URLValidator
from django.db import DatabaseError, InternalError, ProgrammingError from django.db import DatabaseError, InternalError, ProgrammingError
from django.utils.translation import gettext_lazy as _
from packaging.version import parse from packaging.version import parse
from requests import RequestException from requests import RequestException
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -19,6 +21,8 @@ LOGGER = get_logger()
VERSION_NULL = "0.0.0" VERSION_NULL = "0.0.0"
VERSION_CACHE_KEY = "authentik_latest_version" VERSION_CACHE_KEY = "authentik_latest_version"
VERSION_CACHE_TIMEOUT = 8 * 60 * 60 # 8 hours VERSION_CACHE_TIMEOUT = 8 * 60 * 60 # 8 hours
# Chop of the first ^ because we want to search the entire string
URL_FINDER = URLValidator.regex.pattern[1:]
LOCAL_VERSION = parse(__version__) LOCAL_VERSION = parse(__version__)
@ -74,16 +78,10 @@ def update_latest_version(self: SystemTask):
context__new_version=upstream_version, context__new_version=upstream_version,
).exists(): ).exists():
return return
Event.new( event_dict = {"new_version": upstream_version}
EventAction.UPDATE_AVAILABLE, if match := re.search(URL_FINDER, data.get("stable", {}).get("changelog", "")):
message=_( event_dict["message"] = f"Changelog: {match.group()}"
"New version {version} available!".format( Event.new(EventAction.UPDATE_AVAILABLE, **event_dict).save()
version=upstream_version,
)
),
new_version=upstream_version,
changelog=data.get("stable", {}).get("changelog_url"),
).save()
except (RequestException, IndexError) as exc: except (RequestException, IndexError) as exc:
cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT) cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT)
self.set_error(exc) self.set_error(exc)

@ -17,7 +17,6 @@ RESPONSE_VALID = {
"stable": { "stable": {
"version": "99999999.9999999", "version": "99999999.9999999",
"changelog": "See https://goauthentik.io/test", "changelog": "See https://goauthentik.io/test",
"changelog_url": "https://goauthentik.io/test",
"reason": "bugfix", "reason": "bugfix",
}, },
} }
@ -36,7 +35,7 @@ class TestAdminTasks(TestCase):
Event.objects.filter( Event.objects.filter(
action=EventAction.UPDATE_AVAILABLE, action=EventAction.UPDATE_AVAILABLE,
context__new_version="99999999.9999999", context__new_version="99999999.9999999",
context__message="New version 99999999.9999999 available!", context__message="Changelog: https://goauthentik.io/test",
).exists() ).exists()
) )
# test that a consecutive check doesn't create a duplicate event # test that a consecutive check doesn't create a duplicate event
@ -46,7 +45,7 @@ class TestAdminTasks(TestCase):
Event.objects.filter( Event.objects.filter(
action=EventAction.UPDATE_AVAILABLE, action=EventAction.UPDATE_AVAILABLE,
context__new_version="99999999.9999999", context__new_version="99999999.9999999",
context__message="New version 99999999.9999999 available!", context__message="Changelog: https://goauthentik.io/test",
) )
), ),
1, 1,

@ -1,54 +0,0 @@
"""authentik analytics api"""
from drf_spectacular.utils import extend_schema, inline_serializer
from rest_framework.fields import CharField, DictField
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import ViewSet
from authentik.analytics.utils import get_analytics_data, get_analytics_description
from authentik.core.api.utils import PassiveSerializer
from authentik.rbac.permissions import HasPermission
class AnalyticsDescriptionSerializer(PassiveSerializer):
label = CharField()
desc = CharField()
class AnalyticsDescriptionViewSet(ViewSet):
"""Read-only view of analytics descriptions"""
permission_classes = [HasPermission("authentik_rbac.view_system_settings")]
@extend_schema(responses={200: AnalyticsDescriptionSerializer})
def list(self, request: Request) -> Response:
"""Read-only view of analytics descriptions"""
data = []
for label, desc in get_analytics_description().items():
data.append({"label": label, "desc": desc})
return Response(AnalyticsDescriptionSerializer(data, many=True).data)
class AnalyticsDataViewSet(ViewSet):
"""Read-only view of analytics descriptions"""
permission_classes = [HasPermission("authentik_rbac.edit_system_settings")]
@extend_schema(
responses={
200: inline_serializer(
name="AnalyticsData",
fields={
"data": DictField(),
},
)
}
)
def list(self, request: Request) -> Response:
"""Read-only view of analytics descriptions"""
return Response(
{
"data": get_analytics_data(force=True),
}
)

@ -1,12 +0,0 @@
"""authentik analytics app config"""
from authentik.blueprints.apps import ManagedAppConfig
class AuthentikAdminConfig(ManagedAppConfig):
"""authentik analytics app config"""
name = "authentik.analytics"
label = "authentik_analytics"
verbose_name = "authentik Analytics"
default = True

@ -1,19 +0,0 @@
"""authentik analytics mixins"""
from typing import Any
from django.utils.translation import gettext_lazy as _
class AnalyticsMixin:
@classmethod
def get_analytics_description(cls) -> dict[str, str]:
object_name = _(cls._meta.verbose_name)
count_desc = _("Number of {object_name} objects".format_map({"object_name": object_name}))
return {
"count": count_desc,
}
@classmethod
def get_analytics_data(cls) -> dict[str, Any]:
return {"count": cls.objects.all().count()}

@ -1,17 +0,0 @@
"""authentik admin settings"""
from celery.schedules import crontab
from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"analytics_send": {
"task": "authentik.analytics.tasks.send_analytics",
"schedule": crontab(
minute=fqdn_rand("analytics_send"),
hour=fqdn_rand("analytics_send", stop=24),
day_of_week=fqdn_rand("analytics_send", 7),
),
"options": {"queue": "authentik_scheduled"},
}
}

@ -1,45 +0,0 @@
"""authentik admin tasks"""
import orjson
from django.utils.translation import gettext_lazy as _
from requests import RequestException
from structlog.stdlib import get_logger
from authentik.analytics.utils import get_analytics_data
from authentik.events.models import Event, EventAction
from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task
from authentik.lib.utils.http import get_http_session
from authentik.root.celery import CELERY_APP
from authentik.tenants.models import Tenant
LOGGER = get_logger()
@CELERY_APP.task(bind=True, base=SystemTask)
@prefill_task
def send_analytics(self: SystemTask):
"""Send analytics"""
for tenant in Tenant.objects.filter(ready=True):
data = get_analytics_data(current_tenant=tenant)
if not tenant.analytics_enabled or not data:
self.set_status(TaskStatus.WARNING, "Analytics disabled. Nothing was sent.")
return
try:
response = get_http_session().post(
"https://customers.goauthentik.io/api/analytics/post/", json=data
)
response.raise_for_status()
self.set_status(
TaskStatus.SUCCESSFUL,
"Successfully sent analytics",
orjson.dumps(
data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS | orjson.OPT_UTC_Z
).decode(),
)
Event.new(
EventAction.ANALYTICS_SENT,
message=_("Analytics sent"),
analytics_data=data,
).save()
except (RequestException, IndexError) as exc:
self.set_error(exc)

@ -1,76 +0,0 @@
"""authentik analytics tests"""
from json import loads
from requests_mock import Mocker
from django.test import TestCase
from django.urls import reverse
from authentik import __version__
from authentik.analytics.tasks import send_analytics
from authentik.analytics.utils import get_analytics_apps_data, get_analytics_apps_description, get_analytics_data, get_analytics_description, get_analytics_models_data, get_analytics_models_description
from authentik.core.models import Group, User
from authentik.events.models import Event, EventAction
from authentik.lib.generators import generate_id
from authentik.tenants.utils import get_current_tenant
class TestAnalytics(TestCase):
"""test analytics api"""
def setUp(self) -> None:
super().setUp()
self.user = User.objects.create(username=generate_id())
self.group = Group.objects.create(name=generate_id(), is_superuser=True)
self.group.users.add(self.user)
self.client.force_login(self.user)
self.tenant = get_current_tenant()
def test_description_api(self):
"""Test Version API"""
response = self.client.get(reverse("authentik_api:analytics-description-list"))
self.assertEqual(response.status_code, 200)
loads(response.content)
def test_data_api(self):
"""Test Version API"""
response = self.client.get(reverse("authentik_api:analytics-data-list"))
self.assertEqual(response.status_code, 200)
body = loads(response.content)
self.assertEqual(body["data"]["version"], __version__)
def test_sending_enabled(self):
"""Test analytics sending"""
self.tenant.analytics_enabled = True
self.tenant.save()
with Mocker() as mocker:
mocker.post("https://customers.goauthentik.io/api/analytics/post/", status_code=200)
send_analytics.delay().get()
self.assertTrue(
Event.objects.filter(
action=EventAction.ANALYTICS_SENT
).exists()
)
def test_sending_disabled(self):
"""Test analytics sending"""
self.tenant.analytics_enabled = False
self.tenant.save()
send_analytics.delay().get()
self.assertFalse(
Event.objects.filter(
action=EventAction.ANALYTICS_SENT
).exists()
)
def test_description_data_match_apps(self):
"""Test description and data keys match"""
description = get_analytics_apps_description()
data = get_analytics_apps_data()
self.assertEqual(data.keys(), description.keys())
def test_description_data_match_models(self):
"""Test description and data keys match"""
description = get_analytics_models_description()
data = get_analytics_models_data()
self.assertEqual(data.keys(), description.keys())

@ -1,8 +0,0 @@
"""API URLs"""
from authentik.analytics.api import AnalyticsDataViewSet, AnalyticsDescriptionViewSet
api_urlpatterns = [
("analytics/description", AnalyticsDescriptionViewSet, "analytics-description"),
("analytics/data", AnalyticsDataViewSet, "analytics-data"),
]

@ -1,112 +0,0 @@
"""authentik analytics utils"""
from hashlib import sha256
from importlib import import_module
from typing import Any
from structlog import get_logger
from authentik import get_full_version
from authentik.analytics.models import AnalyticsMixin
from authentik.lib.utils.reflection import get_apps
from authentik.root.install_id import get_install_id
from authentik.tenants.models import Tenant
from authentik.tenants.utils import get_current_tenant
LOGGER = get_logger()
def get_analytics_apps() -> dict:
modules = {}
for _authentik_app in get_apps():
try:
module = import_module(f"{_authentik_app.name}.analytics")
except ModuleNotFoundError:
continue
except ImportError as exc:
LOGGER.warning(
"Could not import app's analytics", app_name=_authentik_app.name, exc=exc
)
continue
if not hasattr(module, "get_analytics_description") or not hasattr(
module, "get_analytics_data"
):
LOGGER.debug(
"App does not define API URLs",
app_name=_authentik_app.name,
)
continue
modules[_authentik_app.label] = module
return modules
def get_analytics_apps_description() -> dict[str, str]:
result = {}
for app_label, module in get_analytics_apps().items():
for k, v in module.get_analytics_description().items():
result[f"{app_label}/app/{k}"] = v
return result
def get_analytics_apps_data() -> dict[str, Any]:
result = {}
for app_label, module in get_analytics_apps().items():
for k, v in module.get_analytics_data().items():
result[f"{app_label}/app/{k}"] = v
return result
def get_analytics_models() -> list[AnalyticsMixin]:
def get_subclasses(cls):
for subclass in cls.__subclasses__():
if subclass.__subclasses__():
yield from get_subclasses(subclass)
elif not subclass._meta.abstract:
yield subclass
return list(get_subclasses(AnalyticsMixin))
def get_analytics_models_description() -> dict[str, str]:
result = {}
for model in get_analytics_models():
for k, v in model.get_analytics_description().items():
result[f"{model._meta.app_label}/models/{model._meta.object_name}/{k}"] = v
return result
def get_analytics_models_data() -> dict[str, Any]:
result = {}
for model in get_analytics_models():
for k, v in model.get_analytics_data().items():
result[f"{model._meta.app_label}/models/{model._meta.object_name}/{k}"] = v
return result
def get_analytics_description() -> dict[str, str]:
return {
**get_analytics_apps_description(),
**get_analytics_models_description(),
}
def get_analytics_data(current_tenant: Tenant | None = None, force: bool = False) -> dict[str, Any]:
current_tenant = current_tenant or get_current_tenant()
if not current_tenant.analytics_enabled and not force:
return {}
data = {
**get_analytics_apps_data(),
**get_analytics_models_data(),
}
to_remove = []
for key in data.keys():
if key not in current_tenant.analytics_sources:
to_remove.append(key)
for key in to_remove:
del data[key]
return {
**data,
"install_id_hash": sha256(get_install_id().encode()).hexdigest(),
"tenant_hash": sha256(current_tenant.tenant_uuid.bytes).hexdigest(),
"version": get_full_version(),
}

@ -171,7 +171,7 @@ class Importer:
def default_context(self): def default_context(self):
"""Default context""" """Default context"""
return { return {
"goauthentik.io/enterprise/licensed": LicenseKey.get_total().status().is_valid, "goauthentik.io/enterprise/licensed": LicenseKey.get_total().is_valid(),
"goauthentik.io/rbac/models": rbac_models(), "goauthentik.io/rbac/models": rbac_models(),
} }

@ -30,10 +30,8 @@ from authentik.core.api.utils import (
PassiveSerializer, PassiveSerializer,
) )
from authentik.core.expression.evaluator import PropertyMappingEvaluator from authentik.core.expression.evaluator import PropertyMappingEvaluator
from authentik.core.expression.exceptions import PropertyMappingExpressionException
from authentik.core.models import Group, PropertyMapping, User from authentik.core.models import Group, PropertyMapping, User
from authentik.events.utils import sanitize_item from authentik.events.utils import sanitize_item
from authentik.lib.utils.errors import exception_to_string
from authentik.policies.api.exec import PolicyTestSerializer from authentik.policies.api.exec import PolicyTestSerializer
from authentik.rbac.decorators import permission_required from authentik.rbac.decorators import permission_required
@ -164,15 +162,12 @@ class PropertyMappingViewSet(
response_data = {"successful": True, "result": ""} response_data = {"successful": True, "result": ""}
try: try:
result = mapping.evaluate(dry_run=True, **context) result = mapping.evaluate(**context)
response_data["result"] = dumps( response_data["result"] = dumps(
sanitize_item(result), indent=(4 if format_result else None) sanitize_item(result), indent=(4 if format_result else None)
) )
except PropertyMappingExpressionException as exc:
response_data["result"] = exception_to_string(exc.exc)
response_data["successful"] = False
except Exception as exc: except Exception as exc:
response_data["result"] = exception_to_string(exc) response_data["result"] = str(exc)
response_data["successful"] = False response_data["successful"] = False
response = PropertyMappingTestResultSerializer(response_data) response = PropertyMappingTestResultSerializer(response_data)
return Response(response.data) return Response(response.data)

@ -14,7 +14,6 @@ from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from authentik.core.api.utils import PassiveSerializer from authentik.core.api.utils import PassiveSerializer
from authentik.rbac.filters import ObjectFilter
class DeleteAction(Enum): class DeleteAction(Enum):
@ -54,7 +53,7 @@ class UsedByMixin:
@extend_schema( @extend_schema(
responses={200: UsedBySerializer(many=True)}, responses={200: UsedBySerializer(many=True)},
) )
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter]) @action(detail=True, pagination_class=None, filter_backends=[])
def used_by(self, request: Request, *args, **kwargs) -> Response: def used_by(self, request: Request, *args, **kwargs) -> Response:
"""Get a list of all objects that use this object""" """Get a list of all objects that use this object"""
model: Model = self.get_object() model: Model = self.get_object()

@ -678,10 +678,10 @@ class UserViewSet(UsedByMixin, ModelViewSet):
if not request.tenant.impersonation: if not request.tenant.impersonation:
LOGGER.debug("User attempted to impersonate", user=request.user) LOGGER.debug("User attempted to impersonate", user=request.user)
return Response(status=401) return Response(status=401)
user_to_be = self.get_object() if not request.user.has_perm("impersonate"):
if not request.user.has_perm("impersonate", user_to_be):
LOGGER.debug("User attempted to impersonate without permissions", user=request.user) LOGGER.debug("User attempted to impersonate without permissions", user=request.user)
return Response(status=401) return Response(status=401)
user_to_be = self.get_object()
if user_to_be.pk == self.request.user.pk: if user_to_be.pk == self.request.user.pk:
LOGGER.debug("User attempted to impersonate themselves", user=request.user) LOGGER.debug("User attempted to impersonate themselves", user=request.user)
return Response(status=401) return Response(status=401)

@ -9,11 +9,10 @@ class Command(TenantCommand):
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument("--type", type=str, required=True) parser.add_argument("--type", type=str, required=True)
parser.add_argument("--all", action="store_true", default=False) parser.add_argument("--all", action="store_true")
parser.add_argument("usernames", nargs="*", type=str) parser.add_argument("usernames", nargs="+", type=str)
def handle_per_tenant(self, **options): def handle_per_tenant(self, **options):
print(options)
new_type = UserTypes(options["type"]) new_type = UserTypes(options["type"])
qs = ( qs = (
User.objects.exclude_anonymous() User.objects.exclude_anonymous()
@ -23,9 +22,6 @@ class Command(TenantCommand):
if options["usernames"] and options["all"]: if options["usernames"] and options["all"]:
self.stderr.write("--all and usernames specified, only one can be specified") self.stderr.write("--all and usernames specified, only one can be specified")
return return
if not options["usernames"] and not options["all"]:
self.stderr.write("--all or usernames must be specified")
return
if options["usernames"] and not options["all"]: if options["usernames"] and not options["all"]:
qs = qs.filter(username__in=options["usernames"]) qs = qs.filter(username__in=options["usernames"])
updated = qs.update(type=new_type) updated = qs.update(type=new_type)

@ -23,7 +23,6 @@ from model_utils.managers import InheritanceManager
from rest_framework.serializers import Serializer from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.analytics.models import AnalyticsMixin
from authentik.blueprints.models import ManagedModel from authentik.blueprints.models import ManagedModel
from authentik.core.expression.exceptions import PropertyMappingExpressionException from authentik.core.expression.exceptions import PropertyMappingExpressionException
from authentik.core.types import UILoginButton, UserSettingSerializer from authentik.core.types import UILoginButton, UserSettingSerializer
@ -169,7 +168,7 @@ class GroupQuerySet(CTEQuerySet):
return cte.join(Group, group_uuid=cte.col.group_uuid).with_cte(cte) return cte.join(Group, group_uuid=cte.col.group_uuid).with_cte(cte)
class Group(SerializerModel, AttributesMixin, AnalyticsMixin): class Group(SerializerModel, AttributesMixin):
"""Group model which supports a basic hierarchy and has attributes""" """Group model which supports a basic hierarchy and has attributes"""
group_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) group_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
@ -259,7 +258,7 @@ class UserManager(DjangoUserManager):
return self.get_queryset().exclude_anonymous() return self.get_queryset().exclude_anonymous()
class User(SerializerModel, GuardianUserMixin, AttributesMixin, AbstractUser, AnalyticsMixin): class User(SerializerModel, GuardianUserMixin, AttributesMixin, AbstractUser):
"""authentik User model, based on django's contrib auth user model.""" """authentik User model, based on django's contrib auth user model."""
uuid = models.UUIDField(default=uuid4, editable=False, unique=True) uuid = models.UUIDField(default=uuid4, editable=False, unique=True)
@ -377,7 +376,7 @@ class User(SerializerModel, GuardianUserMixin, AttributesMixin, AbstractUser, An
return get_avatar(self) return get_avatar(self)
class Provider(SerializerModel, AnalyticsMixin): class Provider(SerializerModel):
"""Application-independent Provider instance. For example SAML2 Remote, OAuth2 Application""" """Application-independent Provider instance. For example SAML2 Remote, OAuth2 Application"""
name = models.TextField(unique=True) name = models.TextField(unique=True)
@ -467,11 +466,13 @@ class ApplicationQuerySet(QuerySet):
def with_provider(self) -> "QuerySet[Application]": def with_provider(self) -> "QuerySet[Application]":
qs = self.select_related("provider") qs = self.select_related("provider")
for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider): for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider):
if LOOKUP_SEP in subclass:
continue
qs = qs.select_related(f"provider__{subclass}") qs = qs.select_related(f"provider__{subclass}")
return qs return qs
class Application(SerializerModel, PolicyBindingModel, AnalyticsMixin): class Application(SerializerModel, PolicyBindingModel):
"""Every Application which uses authentik for authentication/identification/authorization """Every Application which uses authentik for authentication/identification/authorization
needs an Application record. Other authentication types can subclass this Model to needs an Application record. Other authentication types can subclass this Model to
add custom fields and other properties""" add custom fields and other properties"""
@ -544,24 +545,15 @@ class Application(SerializerModel, PolicyBindingModel, AnalyticsMixin):
if not self.provider: if not self.provider:
return None return None
candidates = [] for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider):
base_class = Provider # We don't care about recursion, skip nested models
for subclass in base_class.objects.get_queryset()._get_subclasses_recurse(base_class): if LOOKUP_SEP in subclass:
parent = self.provider
for level in subclass.split(LOOKUP_SEP):
try:
parent = getattr(parent, level)
except AttributeError:
break
if parent in candidates:
continue continue
idx = subclass.count(LOOKUP_SEP) try:
if type(parent) is not base_class: return getattr(self.provider, subclass)
idx += 1 except AttributeError:
candidates.insert(idx, parent) pass
if not candidates: return None
return None
return candidates[-1]
def __str__(self): def __str__(self):
return str(self.name) return str(self.name)
@ -604,7 +596,7 @@ class SourceGroupMatchingModes(models.TextChoices):
) )
class Source(ManagedModel, SerializerModel, PolicyBindingModel, AnalyticsMixin): class Source(ManagedModel, SerializerModel, PolicyBindingModel):
"""Base Authentication source, i.e. an OAuth Provider, SAML Remote or LDAP Server""" """Base Authentication source, i.e. an OAuth Provider, SAML Remote or LDAP Server"""
name = models.TextField(help_text=_("Source's display Name.")) name = models.TextField(help_text=_("Source's display Name."))
@ -736,7 +728,7 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel, AnalyticsMixin):
] ]
class UserSourceConnection(SerializerModel, CreatedUpdatedModel, AnalyticsMixin): class UserSourceConnection(SerializerModel, CreatedUpdatedModel):
"""Connection between User and Source.""" """Connection between User and Source."""
user = models.ForeignKey(User, on_delete=models.CASCADE) user = models.ForeignKey(User, on_delete=models.CASCADE)
@ -756,7 +748,7 @@ class UserSourceConnection(SerializerModel, CreatedUpdatedModel, AnalyticsMixin)
unique_together = (("user", "source"),) unique_together = (("user", "source"),)
class GroupSourceConnection(SerializerModel, CreatedUpdatedModel, AnalyticsMixin): class GroupSourceConnection(SerializerModel, CreatedUpdatedModel):
"""Connection between Group and Source.""" """Connection between Group and Source."""
group = models.ForeignKey(Group, on_delete=models.CASCADE) group = models.ForeignKey(Group, on_delete=models.CASCADE)
@ -880,7 +872,7 @@ class Token(SerializerModel, ManagedModel, ExpiringModel):
).save() ).save()
class PropertyMapping(SerializerModel, ManagedModel, AnalyticsMixin): class PropertyMapping(SerializerModel, ManagedModel):
"""User-defined key -> x mapping which can be used by providers to expose extra data.""" """User-defined key -> x mapping which can be used by providers to expose extra data."""
pm_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) pm_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
@ -909,7 +901,7 @@ class PropertyMapping(SerializerModel, ManagedModel, AnalyticsMixin):
except ControlFlowException as exc: except ControlFlowException as exc:
raise exc raise exc
except Exception as exc: except Exception as exc:
raise PropertyMappingExpressionException(exc, self) from exc raise PropertyMappingExpressionException(self, exc) from exc
def __str__(self): def __str__(self):
return f"Property Mapping {self.name}" return f"Property Mapping {self.name}"

@ -9,12 +9,9 @@ from rest_framework.test import APITestCase
from authentik.core.models import Application from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_flow from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.lib.generators import generate_id
from authentik.policies.dummy.models import DummyPolicy from authentik.policies.dummy.models import DummyPolicy
from authentik.policies.models import PolicyBinding from authentik.policies.models import PolicyBinding
from authentik.providers.oauth2.models import OAuth2Provider from authentik.providers.oauth2.models import OAuth2Provider
from authentik.providers.proxy.models import ProxyProvider
from authentik.providers.saml.models import SAMLProvider
class TestApplicationsAPI(APITestCase): class TestApplicationsAPI(APITestCase):
@ -225,31 +222,3 @@ class TestApplicationsAPI(APITestCase):
], ],
}, },
) )
def test_get_provider(self):
"""Ensure that proxy providers (at the time of writing that is the only provider
that inherits from another proxy type (OAuth) instead of inheriting from the root
provider class) is correctly looked up and selected from the database"""
slug = generate_id()
provider = ProxyProvider.objects.create(name=generate_id())
Application.objects.create(
name=generate_id(),
slug=slug,
provider=provider,
)
self.assertEqual(Application.objects.get(slug=slug).get_provider(), provider)
self.assertEqual(
Application.objects.with_provider().get(slug=slug).get_provider(), provider
)
slug = generate_id()
provider = SAMLProvider.objects.create(name=generate_id())
Application.objects.create(
name=generate_id(),
slug=slug,
provider=provider,
)
self.assertEqual(Application.objects.get(slug=slug).get_provider(), provider)
self.assertEqual(
Application.objects.with_provider().get(slug=slug).get_provider(), provider
)

@ -3,10 +3,10 @@
from json import loads from json import loads
from django.urls import reverse from django.urls import reverse
from guardian.shortcuts import assign_perm
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from authentik.core.tests.utils import create_test_admin_user, create_test_user from authentik.core.models import User
from authentik.core.tests.utils import create_test_admin_user
from authentik.tenants.utils import get_current_tenant from authentik.tenants.utils import get_current_tenant
@ -15,7 +15,7 @@ class TestImpersonation(APITestCase):
def setUp(self) -> None: def setUp(self) -> None:
super().setUp() super().setUp()
self.other_user = create_test_user() self.other_user = User.objects.create(username="to-impersonate")
self.user = create_test_admin_user() self.user = create_test_admin_user()
def test_impersonate_simple(self): def test_impersonate_simple(self):
@ -44,26 +44,6 @@ class TestImpersonation(APITestCase):
self.assertEqual(response_body["user"]["username"], self.user.username) self.assertEqual(response_body["user"]["username"], self.user.username)
self.assertNotIn("original", response_body) self.assertNotIn("original", response_body)
def test_impersonate_scoped(self):
"""Test impersonation with scoped permissions"""
new_user = create_test_user()
assign_perm("authentik_core.impersonate", new_user, self.other_user)
assign_perm("authentik_core.view_user", new_user, self.other_user)
self.client.force_login(new_user)
response = self.client.post(
reverse(
"authentik_api:user-impersonate",
kwargs={"pk": self.other_user.pk},
)
)
self.assertEqual(response.status_code, 201)
response = self.client.get(reverse("authentik_api:user-me"))
response_body = loads(response.content.decode())
self.assertEqual(response_body["user"]["username"], self.other_user.username)
self.assertEqual(response_body["original"]["username"], new_user.username)
def test_impersonate_denied(self): def test_impersonate_denied(self):
"""test impersonation without permissions""" """test impersonation without permissions"""
self.client.force_login(self.other_user) self.client.force_login(self.other_user)

@ -35,7 +35,6 @@ from authentik.crypto.builder import CertificateBuilder, PrivateKeyAlg
from authentik.crypto.models import CertificateKeyPair from authentik.crypto.models import CertificateKeyPair
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction
from authentik.rbac.decorators import permission_required from authentik.rbac.decorators import permission_required
from authentik.rbac.filters import ObjectFilter
LOGGER = get_logger() LOGGER = get_logger()
@ -266,7 +265,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
], ],
responses={200: CertificateDataSerializer(many=False)}, responses={200: CertificateDataSerializer(many=False)},
) )
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter]) @action(detail=True, pagination_class=None, filter_backends=[])
def view_certificate(self, request: Request, pk: str) -> Response: def view_certificate(self, request: Request, pk: str) -> Response:
"""Return certificate-key pairs certificate and log access""" """Return certificate-key pairs certificate and log access"""
certificate: CertificateKeyPair = self.get_object() certificate: CertificateKeyPair = self.get_object()
@ -296,7 +295,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
], ],
responses={200: CertificateDataSerializer(many=False)}, responses={200: CertificateDataSerializer(many=False)},
) )
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter]) @action(detail=True, pagination_class=None, filter_backends=[])
def view_private_key(self, request: Request, pk: str) -> Response: def view_private_key(self, request: Request, pk: str) -> Response:
"""Return certificate-key pairs private key and log access""" """Return certificate-key pairs private key and log access"""
certificate: CertificateKeyPair = self.get_object() certificate: CertificateKeyPair = self.get_object()

@ -214,46 +214,6 @@ class TestCrypto(APITestCase):
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
self.assertIn("Content-Disposition", response) self.assertIn("Content-Disposition", response)
def test_certificate_download_denied(self):
"""Test certificate export (download)"""
self.client.logout()
keypair = create_test_cert()
response = self.client.get(
reverse(
"authentik_api:certificatekeypair-view-certificate",
kwargs={"pk": keypair.pk},
)
)
self.assertEqual(403, response.status_code)
response = self.client.get(
reverse(
"authentik_api:certificatekeypair-view-certificate",
kwargs={"pk": keypair.pk},
),
data={"download": True},
)
self.assertEqual(403, response.status_code)
def test_private_key_download_denied(self):
"""Test private_key export (download)"""
self.client.logout()
keypair = create_test_cert()
response = self.client.get(
reverse(
"authentik_api:certificatekeypair-view-private-key",
kwargs={"pk": keypair.pk},
)
)
self.assertEqual(403, response.status_code)
response = self.client.get(
reverse(
"authentik_api:certificatekeypair-view-private-key",
kwargs={"pk": keypair.pk},
),
data={"download": True},
)
self.assertEqual(403, response.status_code)
def test_used_by(self): def test_used_by(self):
"""Test used_by endpoint""" """Test used_by endpoint"""
self.client.force_login(create_test_admin_user()) self.client.force_login(create_test_admin_user())
@ -286,26 +246,6 @@ class TestCrypto(APITestCase):
], ],
) )
def test_used_by_denied(self):
"""Test used_by endpoint"""
self.client.logout()
keypair = create_test_cert()
OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
client_secret=generate_key(),
authorization_flow=create_test_flow(),
redirect_uris="http://localhost",
signing_key=keypair,
)
response = self.client.get(
reverse(
"authentik_api:certificatekeypair-used-by",
kwargs={"pk": keypair.pk},
)
)
self.assertEqual(403, response.status_code)
def test_discovery(self):
"""Test certificate discovery"""
name = generate_id()

View File

@@ -1,11 +1,12 @@
"""Enterprise API Views"""
+from dataclasses import asdict
from datetime import timedelta
from django.utils.timezone import now
from django.utils.translation import gettext as _
from drf_spectacular.types import OpenApiTypes
-from drf_spectacular.utils import OpenApiParameter, extend_schema, inline_serializer
+from drf_spectacular.utils import extend_schema, inline_serializer
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
from rest_framework.fields import CharField, IntegerField
@@ -29,7 +30,7 @@ class EnterpriseRequiredMixin:
def validate(self, attrs: dict) -> dict:
"""Check that a valid license exists"""
-if not LicenseKey.cached_summary().status.is_valid:
+if not LicenseKey.cached_summary().has_license:
raise ValidationError(_("Enterprise is required to create/update this object."))
return super().validate(attrs)
@@ -86,7 +87,7 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
},
)
@action(detail=False, methods=["GET"])
-def install_id(self, request: Request) -> Response:
+def get_install_id(self, request: Request) -> Response:
"""Get install_id"""
return Response(
data={
@@ -99,22 +100,12 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
responses={
200: LicenseSummarySerializer(),
},
-parameters=[
-OpenApiParameter(
-name="cached",
-location=OpenApiParameter.QUERY,
-type=OpenApiTypes.BOOL,
-default=True,
-)
-],
)
@action(detail=False, methods=["GET"], permission_classes=[IsAuthenticated])
def summary(self, request: Request) -> Response:
"""Get the total license status"""
-summary = LicenseKey.cached_summary()
-if request.query_params.get("cached", "true").lower() == "false":
-summary = LicenseKey.get_total().summary()
-response = LicenseSummarySerializer(instance=summary)
+response = LicenseSummarySerializer(data=asdict(LicenseKey.cached_summary()))
+response.is_valid(raise_exception=True)
return Response(response.data)
@permission_required(None, ["authentik_enterprise.view_license"])
@@ -137,7 +128,7 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
forecast_for_months = 12
response = LicenseForecastSerializer(
data={
-"internal_users": LicenseKey.get_internal_user_count(),
+"internal_users": LicenseKey.get_default_user_count(),
"external_users": LicenseKey.get_external_user_count(),
"forecasted_internal_users": (internal_in_last_month * forecast_for_months),
"forecasted_external_users": (external_in_last_month * forecast_for_months),

View File

@@ -25,4 +25,4 @@ class AuthentikEnterpriseConfig(EnterpriseConfig):
"""Actual enterprise check, cached"""
from authentik.enterprise.license import LicenseKey
-return LicenseKey.cached_summary().status.is_valid
+return LicenseKey.cached_summary().valid

View File

@@ -3,37 +3,24 @@
from base64 import b64decode
from binascii import Error
from dataclasses import asdict, dataclass, field
-from datetime import UTC, datetime, timedelta
+from datetime import datetime, timedelta
from enum import Enum
from functools import lru_cache
from time import mktime
from cryptography.exceptions import InvalidSignature
from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate
-from dacite import DaciteError, from_dict
+from dacite import from_dict
from django.core.cache import cache
from django.db.models.query import QuerySet
from django.utils.timezone import now
from jwt import PyJWTError, decode, get_unverified_header
from rest_framework.exceptions import ValidationError
-from rest_framework.fields import (
-ChoiceField,
-DateTimeField,
-IntegerField,
-ListField,
-)
+from rest_framework.fields import BooleanField, DateTimeField, IntegerField
from authentik.core.api.utils import PassiveSerializer
from authentik.core.models import User, UserTypes
-from authentik.enterprise.models import (
-THRESHOLD_READ_ONLY_WEEKS,
-THRESHOLD_WARNING_ADMIN_WEEKS,
-THRESHOLD_WARNING_EXPIRY_WEEKS,
-THRESHOLD_WARNING_USER_WEEKS,
-License,
-LicenseUsage,
-LicenseUsageStatus,
-)
+from authentik.enterprise.models import License, LicenseUsage
from authentik.tenants.utils import get_unique_identifier
CACHE_KEY_ENTERPRISE_LICENSE = "goauthentik.io/enterprise/license"
@@ -55,9 +42,6 @@ def get_license_aud() -> str:
class LicenseFlags(Enum):
"""License flags"""
-TRIAL = "trial"
-NON_PRODUCTION = "non_production"
@dataclass
class LicenseSummary:
@@ -65,9 +49,12 @@ class LicenseSummary:
internal_users: int
external_users: int
-status: LicenseUsageStatus
+valid: bool
+show_admin_warning: bool
+show_user_warning: bool
+read_only: bool
latest_valid: datetime
-license_flags: list[LicenseFlags]
+has_license: bool
class LicenseSummarySerializer(PassiveSerializer):
@@ -75,9 +62,12 @@ class LicenseSummarySerializer(PassiveSerializer):
internal_users = IntegerField(required=True)
external_users = IntegerField(required=True)
-status = ChoiceField(choices=LicenseUsageStatus.choices)
+valid = BooleanField()
+show_admin_warning = BooleanField()
+show_user_warning = BooleanField()
+read_only = BooleanField()
latest_valid = DateTimeField()
-license_flags = ListField(child=ChoiceField(choices=tuple(x.value for x in LicenseFlags)))
+has_license = BooleanField()
@dataclass
@@ -90,10 +80,10 @@ class LicenseKey:
name: str
internal_users: int = 0
external_users: int = 0
-license_flags: list[LicenseFlags] = field(default_factory=list)
+flags: list[LicenseFlags] = field(default_factory=list)
@staticmethod
-def validate(jwt: str, check_expiry=True) -> "LicenseKey":
+def validate(jwt: str) -> "LicenseKey":
"""Validate the license from a given JWT"""
try:
headers = get_unverified_header(jwt)
@@ -117,28 +107,26 @@ class LicenseKey:
our_cert.public_key(),
algorithms=["ES512"],
audience=get_license_aud(),
-options={"verify_exp": check_expiry, "verify_signature": check_expiry},
),
)
except PyJWTError:
-unverified = decode(jwt, options={"verify_signature": False})
-if unverified["aud"] != get_license_aud():
-raise ValidationError("Invalid Install ID in license") from None
raise ValidationError("Unable to verify license") from None
return body
@staticmethod
def get_total() -> "LicenseKey":
"""Get a summarized version of all (not expired) licenses"""
+active_licenses = License.objects.filter(expiry__gte=now())
total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
-for lic in License.objects.all():
+for lic in active_licenses:
total.internal_users += lic.internal_users
total.external_users += lic.external_users
exp_ts = int(mktime(lic.expiry.timetuple()))
if total.exp == 0:
total.exp = exp_ts
-total.exp = max(total.exp, exp_ts)
-total.license_flags.extend(lic.status.license_flags)
+if exp_ts <= total.exp:
+total.exp = exp_ts
+total.flags.extend(lic.status.flags)
return total
@staticmethod
@@ -147,7 +135,7 @@ class LicenseKey:
return User.objects.all().exclude_anonymous().exclude(is_active=False)
@staticmethod
-def get_internal_user_count():
+def get_default_user_count():
"""Get current default user count"""
return LicenseKey.base_user_qs().filter(type=UserTypes.INTERNAL).count()
@@ -156,73 +144,59 @@ class LicenseKey:
"""Get current external user count"""
return LicenseKey.base_user_qs().filter(type=UserTypes.EXTERNAL).count()
-def _last_valid_date(self):
-last_valid_date = (
-LicenseUsage.objects.order_by("-record_date")
-.filter(status=LicenseUsageStatus.VALID)
-.first()
-)
-if not last_valid_date:
-return datetime.fromtimestamp(0, UTC)
-return last_valid_date.record_date
-def status(self) -> LicenseUsageStatus:
-"""Check if the given license body covers all users, and is valid."""
-last_valid = self._last_valid_date()
-if self.exp == 0 and not License.objects.exists():
-return LicenseUsageStatus.UNLICENSED
-_now = now()
-# Check limit-exceeded based status
-internal_users = self.get_internal_user_count()
-external_users = self.get_external_user_count()
-if internal_users > self.internal_users or external_users > self.external_users:
-if last_valid < _now - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS):
-return LicenseUsageStatus.READ_ONLY
-if last_valid < _now - timedelta(weeks=THRESHOLD_WARNING_USER_WEEKS):
-return LicenseUsageStatus.LIMIT_EXCEEDED_USER
-if last_valid < _now - timedelta(weeks=THRESHOLD_WARNING_ADMIN_WEEKS):
-return LicenseUsageStatus.LIMIT_EXCEEDED_ADMIN
-# Check expiry based status
-if datetime.fromtimestamp(self.exp, UTC) < _now:
-if datetime.fromtimestamp(self.exp, UTC) < _now - timedelta(
-weeks=THRESHOLD_READ_ONLY_WEEKS
-):
-return LicenseUsageStatus.READ_ONLY
-return LicenseUsageStatus.EXPIRED
-# Expiry warning
-if datetime.fromtimestamp(self.exp, UTC) <= _now + timedelta(
-weeks=THRESHOLD_WARNING_EXPIRY_WEEKS
-):
-return LicenseUsageStatus.EXPIRY_SOON
-return LicenseUsageStatus.VALID
+def is_valid(self) -> bool:
+"""Check if the given license body covers all users
+Only checks the current count, no historical data is checked"""
+default_users = self.get_default_user_count()
+if default_users > self.internal_users:
+return False
+active_users = self.get_external_user_count()
+if active_users > self.external_users:
+return False
+return True
def record_usage(self):
"""Capture the current validity status and metrics and save them"""
threshold = now() - timedelta(hours=8)
-usage = (
-LicenseUsage.objects.order_by("-record_date").filter(record_date__gte=threshold).first()
-)
-if not usage:
-usage = LicenseUsage.objects.create(
-internal_user_count=self.get_internal_user_count(),
+if not LicenseUsage.objects.filter(record_date__gte=threshold).exists():
+LicenseUsage.objects.create(
+user_count=self.get_default_user_count(),
external_user_count=self.get_external_user_count(),
-status=self.status(),
+within_limits=self.is_valid(),
)
summary = asdict(self.summary())
# Also cache the latest summary for the middleware
cache.set(CACHE_KEY_ENTERPRISE_LICENSE, summary, timeout=CACHE_EXPIRY_ENTERPRISE_LICENSE)
-return usage
+return summary
+@staticmethod
+def last_valid_date() -> datetime:
+"""Get the last date the license was valid"""
+usage: LicenseUsage = (
+LicenseUsage.filter_not_expired(within_limits=True).order_by("-record_date").first()
+)
+if not usage:
+return now()
+return usage.record_date
def summary(self) -> LicenseSummary:
"""Summary of license status"""
-status = self.status()
+has_license = License.objects.all().count() > 0
+last_valid = LicenseKey.last_valid_date()
+show_admin_warning = last_valid < now() - timedelta(weeks=2)
+show_user_warning = last_valid < now() - timedelta(weeks=4)
+read_only = last_valid < now() - timedelta(weeks=6)
latest_valid = datetime.fromtimestamp(self.exp)
return LicenseSummary(
+show_admin_warning=show_admin_warning and has_license,
+show_user_warning=show_user_warning and has_license,
+read_only=read_only and has_license,
latest_valid=latest_valid,
internal_users=self.internal_users,
external_users=self.external_users,
-status=status,
-license_flags=self.license_flags,
+valid=self.is_valid(),
+has_license=has_license,
)
@staticmethod
@@ -231,8 +205,4 @@ class LicenseKey:
summary = cache.get(CACHE_KEY_ENTERPRISE_LICENSE)
if not summary:
return LicenseKey.get_total().summary()
-try:
-return from_dict(LicenseSummary, summary)
-except DaciteError:
-cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)
-return LicenseKey.get_total().summary()
+return from_dict(LicenseSummary, summary)
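On both sides of the hunk above, `cached_summary()` stores the summary as a plain dict (via `asdict`) and rebuilds the dataclass with dacite's `from_dict`; one side additionally drops the cache entry when the stored shape no longer matches the dataclass (`DaciteError`). A small runnable sketch of that round-trip, using an ordinary dict in place of the Django cache and a generic `Summary` dataclass rather than authentik's:

```
from dataclasses import asdict, dataclass
from datetime import datetime

from dacite import from_dict


@dataclass
class Summary:
    internal_users: int
    external_users: int
    latest_valid: datetime


fake_cache: dict[str, dict] = {}

# Writer side (record_usage): store the dataclass as a plain dict.
fake_cache["summary"] = asdict(Summary(3, 7, datetime(2025, 1, 1)))

# Reader side (cached_summary): rebuild the dataclass from the cached dict.
restored = from_dict(Summary, fake_cache["summary"])
print(restored.internal_users, restored.latest_valid)
```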

View File

@@ -8,7 +8,6 @@ from structlog.stdlib import BoundLogger, get_logger
from authentik.enterprise.api import LicenseViewSet
from authentik.enterprise.license import LicenseKey
-from authentik.enterprise.models import LicenseUsageStatus
from authentik.flows.views.executor import FlowExecutorView
from authentik.lib.utils.reflection import class_to_path
@@ -44,7 +43,7 @@ class EnterpriseMiddleware:
cached_status = LicenseKey.cached_summary()
if not cached_status:
return True
-if cached_status.status == LicenseUsageStatus.READ_ONLY:
+if cached_status.read_only:
return False
return True
@@ -54,10 +53,10 @@ class EnterpriseMiddleware:
if request.method.lower() in ["get", "head", "options", "trace"]:
return True
# Always allow requests to manage licenses
-if request.resolver_match._func_path == class_to_path(LicenseViewSet):
+if class_to_path(request.resolver_match.func) == class_to_path(LicenseViewSet):
return True
# Flow executor is mounted as an API path but explicitly allowed
-if request.resolver_match._func_path == class_to_path(FlowExecutorView):
+if class_to_path(request.resolver_match.func) == class_to_path(FlowExecutorView):
return True
# Only apply these restrictions to the API
if "authentik_api" not in request.resolver_match.app_names:

View File

@@ -1,68 +0,0 @@
# Generated by Django 5.0.8 on 2024-08-08 14:15
from django.db import migrations, models
from django.apps.registry import Apps
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
def migrate_license_usage(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
LicenseUsage = apps.get_model("authentik_enterprise", "licenseusage")
db_alias = schema_editor.connection.alias
for usage in LicenseUsage.objects.using(db_alias).all():
usage.status = "valid" if usage.within_limits else "limit_exceeded_admin"
usage.save()
class Migration(migrations.Migration):
dependencies = [
("authentik_enterprise", "0002_rename_users_license_internal_users_and_more"),
]
operations = [
migrations.AddField(
model_name="licenseusage",
name="status",
field=models.TextField(
choices=[
("unlicensed", "Unlicensed"),
("valid", "Valid"),
("expired", "Expired"),
("expiry_soon", "Expiry Soon"),
("limit_exceeded_admin", "Limit Exceeded Admin"),
("limit_exceeded_user", "Limit Exceeded User"),
("read_only", "Read Only"),
],
default=None,
null=True,
),
preserve_default=False,
),
migrations.RunPython(migrate_license_usage),
migrations.RemoveField(
model_name="licenseusage",
name="within_limits",
),
migrations.AlterField(
model_name="licenseusage",
name="status",
field=models.TextField(
choices=[
("unlicensed", "Unlicensed"),
("valid", "Valid"),
("expired", "Expired"),
("expiry_soon", "Expiry Soon"),
("limit_exceeded_admin", "Limit Exceeded Admin"),
("limit_exceeded_user", "Limit Exceeded User"),
("read_only", "Read Only"),
],
),
preserve_default=False,
),
migrations.RenameField(
model_name="licenseusage",
old_name="user_count",
new_name="internal_user_count",
),
]

View File

@@ -17,17 +17,6 @@ if TYPE_CHECKING:
from authentik.enterprise.license import LicenseKey
-def usage_expiry():
-"""Keep license usage records for 3 months"""
-return now() + timedelta(days=30 * 3)
-THRESHOLD_WARNING_ADMIN_WEEKS = 2
-THRESHOLD_WARNING_USER_WEEKS = 4
-THRESHOLD_WARNING_EXPIRY_WEEKS = 2
-THRESHOLD_READ_ONLY_WEEKS = 6
class License(SerializerModel):
"""An authentik enterprise license"""
@@ -50,7 +39,7 @@ class License(SerializerModel):
"""Get parsed license status"""
from authentik.enterprise.license import LicenseKey
-return LicenseKey.validate(self.key, check_expiry=False)
+return LicenseKey.validate(self.key)
class Meta:
indexes = (HashIndex(fields=("key",)),)
@@ -58,23 +47,9 @@ class License(SerializerModel):
verbose_name_plural = _("Licenses")
-class LicenseUsageStatus(models.TextChoices):
-"""License states an instance/tenant can be in"""
-UNLICENSED = "unlicensed"
-VALID = "valid"
-EXPIRED = "expired"
-EXPIRY_SOON = "expiry_soon"
-# User limit exceeded, 2 week threshold, show message in admin interface
-LIMIT_EXCEEDED_ADMIN = "limit_exceeded_admin"
-# User limit exceeded, 4 week threshold, show message in user interface
-LIMIT_EXCEEDED_USER = "limit_exceeded_user"
-READ_ONLY = "read_only"
-@property
-def is_valid(self) -> bool:
-"""Quickly check if a license is valid"""
-return self in [LicenseUsageStatus.VALID, LicenseUsageStatus.EXPIRY_SOON]
+def usage_expiry():
+"""Keep license usage records for 3 months"""
+return now() + timedelta(days=30 * 3)
class LicenseUsage(ExpiringModel):
@@ -84,9 +59,9 @@ class LicenseUsage(ExpiringModel):
usage_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
-internal_user_count = models.BigIntegerField()
+user_count = models.BigIntegerField()
external_user_count = models.BigIntegerField()
-status = models.TextField(choices=LicenseUsageStatus.choices)
+within_limits = models.BooleanField()
record_date = models.DateTimeField(auto_now_add=True)

View File

@@ -13,7 +13,7 @@ class EnterprisePolicyAccessView(PolicyAccessView):
def check_license(self):
"""Check license"""
-if not LicenseKey.get_total().status().is_valid:
+if not LicenseKey.get_total().is_valid():
return PolicyResult(False, _("Enterprise required to access this feature."))
if self.request.user.type != UserTypes.INTERNAL:
return PolicyResult(False, _("Feature only accessible for internal users."))

View File

@@ -6,10 +6,7 @@ from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.enterprise.api import EnterpriseRequiredMixin
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
-from authentik.enterprise.providers.google_workspace.tasks import (
-google_workspace_sync,
-google_workspace_sync_objects,
-)
+from authentik.enterprise.providers.google_workspace.tasks import google_workspace_sync
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
@@ -55,4 +52,3 @@ class GoogleWorkspaceProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixi
search_fields = ["name"]
ordering = ["name"]
sync_single_task = google_workspace_sync
-sync_objects_task = google_workspace_sync_objects

View File

@@ -181,7 +181,7 @@ class GoogleWorkspaceProviderMapping(PropertyMapping):
@property
def component(self) -> str:
-return "ak-property-mapping-provider-google-workspace-form"
+return "ak-property-mapping-google-workspace-form"
@property
def serializer(self) -> type[Serializer]:

View File

@@ -6,10 +6,7 @@ from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.enterprise.api import EnterpriseRequiredMixin
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider
-from authentik.enterprise.providers.microsoft_entra.tasks import (
-microsoft_entra_sync,
-microsoft_entra_sync_objects,
-)
+from authentik.enterprise.providers.microsoft_entra.tasks import microsoft_entra_sync
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
@@ -53,4 +50,3 @@ class MicrosoftEntraProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin
search_fields = ["name"]
ordering = ["name"]
sync_single_task = microsoft_entra_sync
-sync_objects_task = microsoft_entra_sync_objects

View File

@@ -170,7 +170,7 @@ class MicrosoftEntraProviderMapping(PropertyMapping):
@property
def component(self) -> str:
-return "ak-property-mapping-provider-microsoft-entra-form"
+return "ak-property-mapping-microsoft-entra-form"
@property
def serializer(self) -> type[Serializer]:

View File

@@ -1,20 +0,0 @@
# Generated by Django 5.0.8 on 2024-08-12 12:54
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("authentik_providers_rac", "0004_alter_connectiontoken_expires"),
]
operations = [
migrations.AlterModelOptions(
name="racpropertymapping",
options={
"verbose_name": "RAC Provider Property Mapping",
"verbose_name_plural": "RAC Provider Property Mappings",
},
),
]

View File

@@ -125,7 +125,7 @@ class RACPropertyMapping(PropertyMapping):
@property
def component(self) -> str:
-return "ak-property-mapping-provider-rac-form"
+return "ak-property-mapping-rac-form"
@property
def serializer(self) -> type[Serializer]:
@@ -136,8 +136,8 @@ class RACPropertyMapping(PropertyMapping):
return RACPropertyMappingSerializer
class Meta:
-verbose_name = _("RAC Provider Property Mapping")
-verbose_name_plural = _("RAC Provider Property Mappings")
+verbose_name = _("RAC Property Mapping")
+verbose_name_plural = _("RAC Property Mappings")
class ConnectionToken(ExpiringModel):

View File

@@ -44,7 +44,7 @@ websocket_urlpatterns = [
api_urlpatterns = [
("providers/rac", RACProviderViewSet),
-("propertymappings/provider/rac", RACPropertyMappingViewSet),
+("propertymappings/rac", RACPropertyMappingViewSet),
("rac/endpoints", EndpointViewSet),
("rac/connection_tokens", ConnectionTokenViewSet),
]

View File

@@ -3,7 +3,7 @@
from datetime import datetime
from django.core.cache import cache
-from django.db.models.signals import post_delete, post_save, pre_save
+from django.db.models.signals import post_save, pre_save
from django.dispatch import receiver
from django.utils.timezone import get_current_timezone
@@ -27,9 +27,3 @@ def post_save_license(sender: type[License], instance: License, **_):
"""Trigger license usage calculation when license is saved"""
cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)
enterprise_update_usage.delay()
-@receiver(post_delete, sender=License)
-def post_delete_license(sender: type[License], instance: License, **_):
-"""Clear license cache when license is deleted"""
-cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)

View File

@@ -9,26 +9,10 @@ from django.utils.timezone import now
from rest_framework.exceptions import ValidationError
from authentik.enterprise.license import LicenseKey
-from authentik.enterprise.models import (
-THRESHOLD_READ_ONLY_WEEKS,
-THRESHOLD_WARNING_ADMIN_WEEKS,
-THRESHOLD_WARNING_USER_WEEKS,
-License,
-LicenseUsage,
-LicenseUsageStatus,
-)
+from authentik.enterprise.models import License
from authentik.lib.generators import generate_id
-# Valid license expiry
-expiry_valid = int(mktime((now() + timedelta(days=3000)).timetuple()))
-# Valid license expiry, expires soon
-expiry_soon = int(mktime((now() + timedelta(hours=10)).timetuple()))
-# Invalid license expiry, recently expired
-expiry_expired = int(mktime((now() - timedelta(hours=10)).timetuple()))
-# Invalid license expiry, expired longer ago
-expiry_expired_read_only = int(
-mktime((now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)).timetuple())
-)
+_exp = int(mktime((now() + timedelta(days=3000)).timetuple()))
class TestEnterpriseLicense(TestCase):
@@ -39,7 +23,7 @@ class TestEnterpriseLicense(TestCase):
MagicMock(
return_value=LicenseKey(
aud="",
-exp=expiry_valid,
+exp=_exp,
name=generate_id(),
internal_users=100,
external_users=100,
@@ -49,7 +33,7 @@
def test_valid(self):
"""Check license verification"""
lic = License.objects.create(key=generate_id())
-self.assertTrue(lic.status.status().is_valid)
+self.assertTrue(lic.status.is_valid())
self.assertEqual(lic.internal_users, 100)
def test_invalid(self):
@@ -62,7 +46,7 @@
MagicMock(
return_value=LicenseKey(
aud="",
-exp=expiry_valid,
+exp=_exp,
name=generate_id(),
internal_users=100,
external_users=100,
@@ -72,186 +56,11 @@
def test_valid_multiple(self):
"""Check license verification"""
lic = License.objects.create(key=generate_id())
-self.assertTrue(lic.status.status().is_valid)
+self.assertTrue(lic.status.is_valid())
lic2 = License.objects.create(key=generate_id())
-self.assertTrue(lic2.status.status().is_valid)
+self.assertTrue(lic2.status.is_valid())
total = LicenseKey.get_total()
self.assertEqual(total.internal_users, 200)
self.assertEqual(total.external_users, 200)
-self.assertEqual(total.exp, expiry_valid)
-self.assertTrue(total.status().is_valid)
+self.assertEqual(total.exp, _exp)
+self.assertTrue(total.is_valid())
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_limit_exceeded_read_only(self):
"""Check license verification"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)
usage.save(update_fields=["record_date"])
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.READ_ONLY)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_limit_exceeded_user_warning(self):
"""Check license verification"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_WARNING_USER_WEEKS + 1)
usage.save(update_fields=["record_date"])
self.assertEqual(
LicenseKey.get_total().summary().status, LicenseUsageStatus.LIMIT_EXCEEDED_USER
)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_limit_exceeded_admin_warning(self):
"""Check license verification"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_WARNING_ADMIN_WEEKS + 1)
usage.save(update_fields=["record_date"])
self.assertEqual(
LicenseKey.get_total().summary().status, LicenseUsageStatus.LIMIT_EXCEEDED_ADMIN
)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_expired_read_only,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_expiry_read_only(self):
"""Check license verification"""
License.objects.create(key=generate_id())
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.READ_ONLY)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_expired,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_expiry_expired(self):
"""Check license verification"""
License.objects.create(key=generate_id())
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.EXPIRED)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_soon,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_expiry_soon(self):
"""Check license verification"""
License.objects.create(key=generate_id())
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.EXPIRY_SOON)

View File

@@ -1,217 +0,0 @@
"""read only tests"""
from datetime import timedelta
from unittest.mock import MagicMock, patch
from django.urls import reverse
from django.utils.timezone import now
from authentik.core.tests.utils import create_test_admin_user, create_test_flow, create_test_user
from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import (
THRESHOLD_READ_ONLY_WEEKS,
License,
LicenseUsage,
LicenseUsageStatus,
)
from authentik.enterprise.tests.test_license import expiry_valid
from authentik.flows.models import (
FlowDesignation,
FlowStageBinding,
)
from authentik.flows.tests import FlowTestCase
from authentik.lib.generators import generate_id
from authentik.stages.identification.models import IdentificationStage, UserFields
from authentik.stages.password import BACKEND_INBUILT
from authentik.stages.password.models import PasswordStage
from authentik.stages.user_login.models import UserLoginStage
class TestReadOnly(FlowTestCase):
"""Test read_only"""
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_login(self):
"""Test flow, ensure login is still possible with read only mode"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)
usage.save(update_fields=["record_date"])
flow = create_test_flow(
FlowDesignation.AUTHENTICATION,
)
ident_stage = IdentificationStage.objects.create(
name=generate_id(),
user_fields=[UserFields.E_MAIL],
pretend_user_exists=False,
)
FlowStageBinding.objects.create(
target=flow,
stage=ident_stage,
order=0,
)
password_stage = PasswordStage.objects.create(
name=generate_id(), backends=[BACKEND_INBUILT]
)
FlowStageBinding.objects.create(
target=flow,
stage=password_stage,
order=1,
)
login_stage = UserLoginStage.objects.create(
name=generate_id(),
)
FlowStageBinding.objects.create(
target=flow,
stage=login_stage,
order=2,
)
user = create_test_user()
exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
response = self.client.get(exec_url)
self.assertStageResponse(
response,
flow,
component="ak-stage-identification",
password_fields=False,
primary_action="Log in",
sources=[],
show_source_labels=False,
user_fields=[UserFields.E_MAIL],
)
response = self.client.post(exec_url, {"uid_field": user.email}, follow=True)
self.assertStageResponse(response, flow, component="ak-stage-password")
response = self.client.post(exec_url, {"password": user.username}, follow=True)
self.assertStageRedirects(response, reverse("authentik_core:root-redirect"))
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_manage_licenses(self):
"""Test that managing licenses is still possible"""
license = License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)
usage.save(update_fields=["record_date"])
admin = create_test_admin_user()
self.client.force_login(admin)
# Reading is always allowed
response = self.client.get(reverse("authentik_api:license-list"))
self.assertEqual(response.status_code, 200)
# Writing should also be allowed
response = self.client.patch(
reverse("authentik_api:license-detail", kwargs={"pk": license.pk})
)
self.assertEqual(response.status_code, 200)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_manage_flows(self):
"""Test flow"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)
usage.save(update_fields=["record_date"])
admin = create_test_admin_user()
self.client.force_login(admin)
# Read only is still allowed
response = self.client.get(reverse("authentik_api:flow-list"))
self.assertEqual(response.status_code, 200)
flow = create_test_flow()
# Writing is not
response = self.client.patch(
reverse("authentik_api:flow-detail", kwargs={"slug": flow.slug})
)
self.assertJSONEqual(
response.content,
{"detail": "Request denied due to expired/invalid license.", "code": "denied_license"},
)
self.assertEqual(response.status_code, 400)

View File

@@ -69,5 +69,8 @@ class NotificationViewSet(
@action(detail=False, methods=["post"])
def mark_all_seen(self, request: Request) -> Response:
"""Mark all the user's notifications as seen"""
-Notification.objects.filter(user=request.user, seen=False).update(seen=True)
+notifications = Notification.objects.filter(user=request.user)
+for notification in notifications:
+notification.seen = True
+Notification.objects.bulk_update(notifications, ["seen"])
return Response({}, status=204)
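Both variants of `mark_all_seen` above are standard Django ORM: one issues a single filtered `UPDATE` via `QuerySet.update()`, the other loads the rows, mutates them in Python, and writes them back with `bulk_update()`. A hedged sketch of the two patterns against a simplified stand-in model (not the real `Notification` model; assumes a configured Django project):

```
from django.db import models


class Notification(models.Model):
    """Simplified stand-in for the real notification model."""

    user = models.ForeignKey("auth.User", on_delete=models.CASCADE)
    seen = models.BooleanField(default=False)

    class Meta:
        app_label = "example"


def mark_all_seen_with_update(user):
    # One UPDATE statement, touching only rows that are still unseen.
    Notification.objects.filter(user=user, seen=False).update(seen=True)


def mark_all_seen_with_bulk_update(user):
    # Loads the rows, mutates them in Python, then writes them back in bulk.
    notifications = list(Notification.objects.filter(user=user))
    for notification in notifications:
        notification.seen = True
    Notification.objects.bulk_update(notifications, ["seen"])
```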

View File

@@ -1,49 +0,0 @@
# Generated by Django 5.0.9 on 2024-09-25 11:06
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_events", "0007_event_authentik_e_action_9a9dd9_idx_and_more"),
]
operations = [
migrations.AlterField(
model_name="event",
name="action",
field=models.TextField(
choices=[
("login", "Login"),
("login_failed", "Login Failed"),
("logout", "Logout"),
("user_write", "User Write"),
("suspicious_request", "Suspicious Request"),
("password_set", "Password Set"),
("secret_view", "Secret View"),
("secret_rotate", "Secret Rotate"),
("invitation_used", "Invite Used"),
("authorize_application", "Authorize Application"),
("source_linked", "Source Linked"),
("impersonation_started", "Impersonation Started"),
("impersonation_ended", "Impersonation Ended"),
("flow_execution", "Flow Execution"),
("policy_execution", "Policy Execution"),
("policy_exception", "Policy Exception"),
("property_mapping_exception", "Property Mapping Exception"),
("system_task_execution", "System Task Execution"),
("system_task_exception", "System Task Exception"),
("system_exception", "System Exception"),
("configuration_error", "Configuration Error"),
("model_created", "Model Created"),
("model_updated", "Model Updated"),
("model_deleted", "Model Deleted"),
("email_sent", "Email Sent"),
("analytics_sent", "Analytics Sent"),
("update_available", "Update Available"),
("custom_", "Custom Prefix"),
]
),
),
]

View File

@@ -49,7 +49,6 @@ from authentik.policies.models import PolicyBindingModel
from authentik.root.middleware import ClientIPMiddleware
from authentik.stages.email.utils import TemplateEmailMessage
from authentik.tenants.models import Tenant
-from authentik.tenants.utils import get_current_tenant
LOGGER = get_logger()
DISCORD_FIELD_LIMIT = 25
@@ -59,11 +58,7 @@ NOTIFICATION_SUMMARY_LENGTH = 75
def default_event_duration():
"""Default duration an Event is saved.
This is used as a fallback when no brand is available"""
-try:
-tenant = get_current_tenant()
-return now() + timedelta_from_string(tenant.event_retention)
-except Tenant.DoesNotExist:
-return now() + timedelta(days=365)
+return now() + timedelta(days=365)
def default_brand():
@@ -119,7 +114,6 @@ class EventAction(models.TextChoices):
MODEL_DELETED = "model_deleted"
EMAIL_SENT = "email_sent"
-ANALYTICS_SENT = "analytics_sent"
UPDATE_AVAILABLE = "update_available"
CUSTOM_PREFIX = "custom_"
@@ -251,6 +245,12 @@ class Event(SerializerModel, ExpiringModel):
if QS_QUERY in self.context["http_request"]["args"]:
wrapped = self.context["http_request"]["args"][QS_QUERY]
self.context["http_request"]["args"] = cleanse_dict(QueryDict(wrapped))
+if hasattr(request, "tenant"):
+tenant: Tenant = request.tenant
+# Because self.created only gets set on save, we can't use it's value here
+# hence we set self.created to now and then use it
+self.created = now()
+self.expires = self.created + timedelta_from_string(tenant.event_retention)
if hasattr(request, "brand"):
brand: Brand = request.brand
self.brand = sanitize_dict(model_to_dict(brand))
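The added `from_http()` lines derive an event's `expires` from the tenant's retention setting at request time instead of relying solely on the model default; because `created` is normally only populated on save, it is set explicitly first so both timestamps share the same instant. A small runnable sketch of that idea, with a plain `timedelta` standing in for the parsed retention value:

```
from datetime import datetime, timedelta, timezone


class FakeEvent:
    """Bare stand-in for the Event model, holding just the two timestamps."""

    created = None
    expires = None


def stamp(event: FakeEvent, retention: timedelta = timedelta(days=365)) -> None:
    # `created` normally only gets a value on save; setting it explicitly
    # here lets `expires` be derived from the exact same instant.
    event.created = datetime.now(tz=timezone.utc)
    event.expires = event.created + retention


event = FakeEvent()
stamp(event, retention=timedelta(weeks=2))
print(event.expires - event.created)  # 14 days, 0:00:00
```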

View File

@@ -13,7 +13,7 @@ from authentik.events.apps import SYSTEM_TASK_STATUS
from authentik.events.models import Event, EventAction, SystemTask
from authentik.events.tasks import event_notification_handler, gdpr_cleanup
from authentik.flows.models import Stage
-from authentik.flows.planner import PLAN_CONTEXT_OUTPOST, PLAN_CONTEXT_SOURCE, FlowPlan
+from authentik.flows.planner import PLAN_CONTEXT_SOURCE, FlowPlan
from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.root.monitoring import monitoring_set
from authentik.stages.invitation.models import Invitation
@@ -38,9 +38,6 @@ def on_user_logged_in(sender, request: HttpRequest, user: User, **_):
# Save the login method used
kwargs[PLAN_CONTEXT_METHOD] = flow_plan.context[PLAN_CONTEXT_METHOD]
kwargs[PLAN_CONTEXT_METHOD_ARGS] = flow_plan.context.get(PLAN_CONTEXT_METHOD_ARGS, {})
-if PLAN_CONTEXT_OUTPOST in flow_plan.context:
-# Save outpost context
-kwargs[PLAN_CONTEXT_OUTPOST] = flow_plan.context[PLAN_CONTEXT_OUTPOST]
event = Event.new(EventAction.LOGIN, **kwargs).from_http(request, user=user)
request.session[SESSION_LOGIN_EVENT] = event

View File

@@ -6,7 +6,6 @@ from django.db.models import Model
from django.test import TestCase
from authentik.core.models import default_token_key
-from authentik.events.models import default_event_duration
from authentik.lib.utils.reflection import get_apps
@@ -21,7 +20,7 @@ def model_tester_factory(test_model: type[Model]) -> Callable:
allowed = 0
# Token-like objects need to lookup the current tenant to get the default token length
for field in test_model._meta.fields:
-if field.default in [default_token_key, default_event_duration]:
+if field.default == default_token_key:
allowed += 1
with self.assertNumQueries(allowed):
str(test_model())

View File

@@ -2,8 +2,7 @@
from unittest.mock import MagicMock, patch
-from django.urls import reverse
-from rest_framework.test import APITestCase
+from django.test import TestCase
from authentik.core.models import Group, User
from authentik.events.models import (
@@ -11,7 +10,6 @@ from authentik.events.models import (
EventAction,
Notification,
NotificationRule,
-NotificationSeverity,
NotificationTransport,
NotificationWebhookMapping,
TransportMode,
@@ -22,7 +20,7 @@ from authentik.policies.exceptions import PolicyException
from authentik.policies.models import PolicyBinding
-class TestEventsNotifications(APITestCase):
+class TestEventsNotifications(TestCase):
"""Test Event Notifications"""
def setUp(self) -> None:
@@ -133,15 +131,3 @@ class TestEventsNotifications(APITestCase):
Notification.objects.all().delete()
Event.new(EventAction.CUSTOM_PREFIX).save()
self.assertEqual(Notification.objects.first().body, "foo")
def test_api_mark_all_seen(self):
"""Test mark_all_seen"""
self.client.force_login(self.user)
Notification.objects.create(
severity=NotificationSeverity.NOTICE, body="foo", user=self.user, seen=False
)
response = self.client.post(reverse("authentik_api:notification-mark-all-seen"))
self.assertEqual(response.status_code, 204)
self.assertFalse(Notification.objects.filter(body="foo", seen=False).exists())

View File

@@ -37,7 +37,6 @@ from authentik.lib.utils.file import (
)
from authentik.lib.views import bad_request_message
from authentik.rbac.decorators import permission_required
-from authentik.rbac.filters import ObjectFilter
LOGGER = get_logger()
@@ -282,7 +281,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet):
400: OpenApiResponse(description="Flow not applicable"),
},
)
-@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter])
+@action(detail=True, pagination_class=None, filter_backends=[])
def execute(self, request: Request, slug: str):
"""Execute flow for current user"""
# Because we pre-plan the flow here, and not in the planner, we need to manually clear

View File

@@ -23,7 +23,6 @@ from authentik.flows.models import (
in_memory_stage,
)
from authentik.lib.config import CONFIG
-from authentik.outposts.models import Outpost
from authentik.policies.engine import PolicyEngine
from authentik.root.middleware import ClientIPMiddleware
@@ -33,7 +32,6 @@ PLAN_CONTEXT_SSO = "is_sso"
PLAN_CONTEXT_REDIRECT = "redirect"
PLAN_CONTEXT_APPLICATION = "application"
PLAN_CONTEXT_SOURCE = "source"
-PLAN_CONTEXT_OUTPOST = "outpost"
# Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan
# was restored.
PLAN_CONTEXT_IS_RESTORED = "is_restored"
@@ -145,23 +143,10 @@ class FlowPlanner:
and not request.user.is_superuser
):
raise FlowNonApplicableException()
-outpost_user = ClientIPMiddleware.get_outpost_user(request)
if self.flow.authentication == FlowAuthenticationRequirement.REQUIRE_OUTPOST:
+outpost_user = ClientIPMiddleware.get_outpost_user(request)
if not outpost_user:
raise FlowNonApplicableException()
-if outpost_user:
-outpost = Outpost.objects.filter(
-# TODO: Since Outpost and user are not directly connected, we have to look up a user
-# like this. This should ideally by in authentik/outposts/models.py
-pk=outpost_user.username.replace("ak-outpost-", "")
-).first()
-if outpost:
-return {
-PLAN_CONTEXT_OUTPOST: {
-"instance": outpost,
-}
-}
-return {}
def plan(self, request: HttpRequest, default_context: dict[str, Any] | None = None) -> FlowPlan:
"""Check each of the flows' policies, check policies for each stage with PolicyBinding
@@ -174,12 +159,11 @@ class FlowPlanner:
self._logger.debug(
"f(plan): starting planning process",
)
-context = default_context or {}
# Bit of a workaround here, if there is a pending user set in the default context
# we use that user for our cache key
# to make sure they don't get the generic response
-if context and PLAN_CONTEXT_PENDING_USER in context:
-user = context[PLAN_CONTEXT_PENDING_USER]
+if default_context and PLAN_CONTEXT_PENDING_USER in default_context:
+user = default_context[PLAN_CONTEXT_PENDING_USER]
else:
user = request.user
# We only need to check the flow authentication if it's planned without a user
@@ -187,13 +171,14 @@ # or if a flow is restarted due to `invalid_response_action` being set to
# `restart_with_context`, which can only happen if the user was already authorized
# to use the flow
-context.update(self._check_authentication(request))
+self._check_authentication(request)
# First off, check the flow's direct policy bindings
# to make sure the user even has access to the flow
engine = PolicyEngine(self.flow, user, request)
engine.use_cache = self.use_cache
-span.set_data("context", cleanse_dict(context))
-engine.request.context.update(context)
+if default_context:
+span.set_data("default_context", cleanse_dict(default_context))
+engine.request.context.update(default_context)
engine.build()
result = engine.result
if not result.passing:
@@ -210,12 +195,12 @@
key=cached_plan_key,
)
# Reset the context as this isn't factored into caching
-cached_plan.context = context
+cached_plan.context = default_context or {}
return cached_plan
self._logger.debug(
"f(plan): building plan",
)
-plan = self._build_plan(user, request, context)
+plan = self._build_plan(user, request, default_context)
if self.use_cache:
cache.set(cache_key(self.flow, user), plan, CACHE_TIMEOUT)
if not plan.bindings and not self.allow_empty_flows:
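The planner hunks above change `_check_authentication()` from being called purely for its side effects (raising on failure) to returning a dict that is merged into the working plan context before the policy engine runs. A tiny runnable sketch of that merge pattern, with hypothetical names rather than the planner itself:

```
def check_authentication(is_outpost: bool) -> dict:
    # Returns extra context instead of mutating shared state.
    return {"outpost": {"instance": "ak-outpost-example"}} if is_outpost else {}


def plan(default_context: dict | None = None, *, is_outpost: bool = False) -> dict:
    context = default_context or {}
    # Merge the authentication check's findings into the working context.
    context.update(check_authentication(is_outpost))
    return context


print(plan({"pending_user": "akadmin"}, is_outpost=True))
```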

View File

@@ -2,6 +2,7 @@
 import re
 import socket
+from collections.abc import Iterable
 from ipaddress import ip_address, ip_network
 from textwrap import indent
 from types import CodeType
@@ -27,12 +28,6 @@ from authentik.stages.authenticator import devices_for_user

 LOGGER = get_logger()

-ARG_SANITIZE = re.compile(r"[:.-]")
-
-
-def sanitize_arg(arg_name: str) -> str:
-    return re.sub(ARG_SANITIZE, "_", arg_name)
-

 class BaseEvaluator:
     """Validate and evaluate python-based expressions"""
@@ -182,9 +177,9 @@ class BaseEvaluator:
         proc = PolicyProcess(PolicyBinding(policy=policy), request=req, connection=None)
         return proc.profiling_wrapper()

-    def wrap_expression(self, expression: str) -> str:
+    def wrap_expression(self, expression: str, params: Iterable[str]) -> str:
         """Wrap expression in a function, call it, and save the result as `result`"""
-        handler_signature = ",".join(sanitize_arg(x) for x in self._context.keys())
+        handler_signature = ",".join(params)
         full_expression = ""
         full_expression += f"def handler({handler_signature}):\n"
         full_expression += indent(expression, "    ")
@@ -193,8 +188,8 @@ class BaseEvaluator:

     def compile(self, expression: str) -> CodeType:
         """Parse expression. Raises SyntaxError or ValueError if the syntax is incorrect."""
-        expression = self.wrap_expression(expression)
-        return compile(expression, self._filename, "exec")
+        param_keys = self._context.keys()
+        return compile(self.wrap_expression(expression, param_keys), self._filename, "exec")

     def evaluate(self, expression_source: str) -> Any:
         """Parse and evaluate expression. If the syntax is incorrect, a SyntaxError is raised.
@@ -210,7 +205,7 @@ class BaseEvaluator:
             self.handle_error(exc, expression_source)
             raise exc
         try:
-            _locals = {sanitize_arg(x): y for x, y in self._context.items()}
+            _locals = self._context
             # Yes this is an exec, yes it is potentially bad. Since we limit what variables are
             # available here, and these policies can only be edited by admins, this is a risk
             # we're willing to take.
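Both sides of this hunk build the same scaffolding: the policy expression becomes the body of a generated handler() whose parameters are the evaluation context keys, and exec() leaves the return value in `result`. A self-contained sketch of that wrapping, including the left-hand side's sanitize_arg() variant; the context values are made up:

```
import re
from textwrap import indent

ARG_SANITIZE = re.compile(r"[:.-]")

def sanitize_arg(arg_name: str) -> str:
    # Context keys may contain ":", "." or "-", which are invalid in a Python
    # parameter name, so they are rewritten to "_".
    return ARG_SANITIZE.sub("_", arg_name)

def wrap_expression(expression: str, params: list[str]) -> str:
    # Wrap the expression in a function, call it, and save the result as `result`.
    signature = ",".join(params)
    wrapped = f"def handler({signature}):\n"
    wrapped += indent(expression, "    ")
    wrapped += f"\nresult = handler({signature})"
    return wrapped

context = {"request-id": "abc123", "ak_is_sso_flow": True}
params = [sanitize_arg(key) for key in context]
code = compile(wrap_expression("return ak_is_sso_flow", params), "<expression>", "exec")
_locals = {sanitize_arg(key): value for key, value in context.items()}
exec(code, {}, _locals)  # nosec - illustrative only, mirrors the evaluator's own exec
print(_locals["result"])  # True
```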

View File

@@ -1,19 +1,16 @@
-from celery import Task
+from collections.abc import Callable
 from django.utils.text import slugify
 from drf_spectacular.utils import OpenApiResponse, extend_schema
 from guardian.shortcuts import get_objects_for_user
 from rest_framework.decorators import action
-from rest_framework.fields import BooleanField, CharField, ChoiceField
+from rest_framework.fields import BooleanField
 from rest_framework.request import Request
 from rest_framework.response import Response

 from authentik.core.api.utils import ModelSerializer, PassiveSerializer
-from authentik.core.models import Group, User
 from authentik.events.api.tasks import SystemTaskSerializer
-from authentik.events.logs import LogEvent, LogEventSerializer
 from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
-from authentik.lib.utils.reflection import class_to_path
-from authentik.rbac.filters import ObjectFilter


 class SyncStatusSerializer(PassiveSerializer):
@@ -23,29 +20,10 @@ class SyncStatusSerializer(PassiveSerializer):
     tasks = SystemTaskSerializer(many=True, read_only=True)


-class SyncObjectSerializer(PassiveSerializer):
-    """Sync object serializer"""
-
-    sync_object_model = ChoiceField(
-        choices=(
-            (class_to_path(User), "user"),
-            (class_to_path(Group), "group"),
-        )
-    )
-    sync_object_id = CharField()
-
-
-class SyncObjectResultSerializer(PassiveSerializer):
-    """Result of a single object sync"""
-
-    messages = LogEventSerializer(many=True, read_only=True)
-
-
 class OutgoingSyncProviderStatusMixin:
     """Common API Endpoints for Outgoing sync providers"""

-    sync_single_task: type[Task] = None
-    sync_objects_task: type[Task] = None
+    sync_single_task: Callable = None

     @extend_schema(
         responses={
@@ -58,7 +36,7 @@ class OutgoingSyncProviderStatusMixin:
         detail=True,
         pagination_class=None,
         url_path="sync/status",
-        filter_backends=[ObjectFilter],
+        filter_backends=[],
     )
     def sync_status(self, request: Request, pk: int) -> Response:
         """Get provider's sync status"""
@@ -77,30 +55,6 @@ class OutgoingSyncProviderStatusMixin:
         }
         return Response(SyncStatusSerializer(status).data)

-    @extend_schema(
-        request=SyncObjectSerializer,
-        responses={200: SyncObjectResultSerializer()},
-    )
-    @action(
-        methods=["POST"],
-        detail=True,
-        pagination_class=None,
-        url_path="sync/object",
-        filter_backends=[ObjectFilter],
-    )
-    def sync_object(self, request: Request, pk: int) -> Response:
-        """Sync/Re-sync a single user/group object"""
-        provider: OutgoingSyncProvider = self.get_object()
-        params = SyncObjectSerializer(data=request.data)
-        params.is_valid(raise_exception=True)
-        res: list[LogEvent] = self.sync_objects_task.delay(
-            params.validated_data["sync_object_model"],
-            page=1,
-            provider_pk=provider.pk,
-            pk=params.validated_data["sync_object_id"],
-        ).get()
-        return Response(SyncObjectResultSerializer(instance={"messages": res}).data)

 class OutgoingSyncConnectionCreateMixin:
     """Mixin for connection objects that fetches remote data upon creation"""

View File

@@ -105,7 +105,7 @@ class SyncTasks:
                 return
             task.set_status(TaskStatus.SUCCESSFUL, *messages)

-    def sync_objects(self, object_type: str, page: int, provider_pk: int, **filter):
+    def sync_objects(self, object_type: str, page: int, provider_pk: int):
         _object_type = path_to_class(object_type)
         self.logger = get_logger().bind(
             provider_type=class_to_path(self._provider_model),
@@ -120,7 +120,7 @@ class SyncTasks:
             client = provider.client_for_model(_object_type)
         except TransientSyncException:
             return messages
-        paginator = Paginator(provider.get_object_qs(_object_type).filter(**filter), PAGE_SIZE)
+        paginator = Paginator(provider.get_object_qs(_object_type), PAGE_SIZE)
         if client.can_discover:
             self.logger.debug("starting discover")
             client.discover()

View File

@@ -26,6 +26,7 @@ from authentik.outposts.apps import MANAGED_OUTPOST, MANAGED_OUTPOST_NAME
 from authentik.outposts.models import (
     Outpost,
     OutpostConfig,
+    OutpostState,
     OutpostType,
     default_outpost_config,
 )
@@ -139,7 +140,7 @@ class OutpostHealthSerializer(PassiveSerializer):

     def get_fips_enabled(self, obj: dict) -> bool | None:
         """Get FIPS enabled"""
-        if not LicenseKey.get_total().status().is_valid:
+        if not LicenseKey.get_total().is_valid():
             return None
         return obj["fips_enabled"]

@@ -181,6 +182,7 @@ class OutpostViewSet(UsedByMixin, ModelViewSet):
         outpost: Outpost = self.get_object()
         states = []
         for state in outpost.state:
+            state: OutpostState
             states.append(
                 {
                     "uid": state.uid,

View File

@@ -26,7 +26,6 @@ from authentik.outposts.models import (
     KubernetesServiceConnection,
     OutpostServiceConnection,
 )
-from authentik.rbac.filters import ObjectFilter


 class ServiceConnectionSerializer(ModelSerializer, MetaNameSerializer):
@@ -76,7 +75,7 @@ class ServiceConnectionViewSet(
     filterset_fields = ["name"]

     @extend_schema(responses={200: ServiceConnectionStateSerializer(many=False)})
-    @action(detail=True, pagination_class=None, filter_backends=[ObjectFilter])
+    @action(detail=True, pagination_class=None, filter_backends=[])
     def state(self, request: Request, pk: str) -> Response:
         """Get the service connection's state"""
         connection = self.get_object()

View File

@@ -451,7 +451,7 @@ class OutpostState:
             return False
         if self.build_hash != get_build_hash():
             return False
-        return parse(self.version) != OUR_VERSION
+        return parse(self.version) < OUR_VERSION

     @staticmethod
     def for_outpost(outpost: Outpost) -> list["OutpostState"]:

View File

@@ -214,7 +214,7 @@ def outpost_post_save(model_class: str, model_pk: Any):
         if not hasattr(instance, field_name):
             continue
-        LOGGER.debug("triggering outpost update from field", field=field.name)
+        LOGGER.debug("triggering outpost update from from field", field=field.name)
         # Because the Outpost Model has an M2M to Provider,
         # we have to iterate over the entire QS
         for reverse in getattr(instance, field_name).all():

View File

@@ -1,52 +0,0 @@
-# Generated by Django 5.0.9 on 2024-09-25 11:06
-
-from django.db import migrations, models
-
-
-class Migration(migrations.Migration):
-
-    dependencies = [
-        ("authentik_policies_event_matcher", "0023_alter_eventmatcherpolicy_action_and_more"),
-    ]
-
-    operations = [
-        migrations.AlterField(
-            model_name="eventmatcherpolicy",
-            name="action",
-            field=models.TextField(
-                choices=[
-                    ("login", "Login"),
-                    ("login_failed", "Login Failed"),
-                    ("logout", "Logout"),
-                    ("user_write", "User Write"),
-                    ("suspicious_request", "Suspicious Request"),
-                    ("password_set", "Password Set"),
-                    ("secret_view", "Secret View"),
-                    ("secret_rotate", "Secret Rotate"),
-                    ("invitation_used", "Invite Used"),
-                    ("authorize_application", "Authorize Application"),
-                    ("source_linked", "Source Linked"),
-                    ("impersonation_started", "Impersonation Started"),
-                    ("impersonation_ended", "Impersonation Ended"),
-                    ("flow_execution", "Flow Execution"),
-                    ("policy_execution", "Policy Execution"),
-                    ("policy_exception", "Policy Exception"),
-                    ("property_mapping_exception", "Property Mapping Exception"),
-                    ("system_task_execution", "System Task Execution"),
-                    ("system_task_exception", "System Task Exception"),
-                    ("system_exception", "System Exception"),
-                    ("configuration_error", "Configuration Error"),
-                    ("model_created", "Model Created"),
-                    ("model_updated", "Model Updated"),
-                    ("model_deleted", "Model Deleted"),
-                    ("email_sent", "Email Sent"),
-                    ("analytics_sent", "Analytics Sent"),
-                    ("update_available", "Update Available"),
-                    ("custom_", "Custom Prefix"),
-                ],
-                default=None,
-                help_text="Match created events with this action type. When left empty, all action types will be matched.",
-                null=True,
-            ),
-        ),
-    ]

View File

@@ -36,7 +36,7 @@ def update_score(request: HttpRequest, identifier: str, amount: int):
     if not created:
         reputation.score = F("score") + amount
         reputation.save()
-    LOGGER.info("Updated score", amount=amount, for_user=identifier, for_ip=remote_ip)
+    LOGGER.debug("Updated score", amount=amount, for_user=identifier, for_ip=remote_ip)


 @receiver(login_failed)

View File

@@ -2,25 +2,15 @@
 from django.db.models import QuerySet
 from django.db.models.query import Q
-from django.shortcuts import get_object_or_404
 from django_filters.filters import BooleanFilter
 from django_filters.filterset import FilterSet
-from drf_spectacular.types import OpenApiTypes
-from drf_spectacular.utils import OpenApiParameter, extend_schema
-from rest_framework.decorators import action
-from rest_framework.fields import BooleanField, CharField, ListField, SerializerMethodField
+from rest_framework.fields import CharField, ListField, SerializerMethodField
 from rest_framework.mixins import ListModelMixin
-from rest_framework.request import Request
-from rest_framework.response import Response
 from rest_framework.viewsets import GenericViewSet, ModelViewSet

 from authentik.core.api.providers import ProviderSerializer
 from authentik.core.api.used_by import UsedByMixin
-from authentik.core.api.utils import ModelSerializer, PassiveSerializer
-from authentik.core.models import Application
-from authentik.policies.api.exec import PolicyTestResultSerializer
-from authentik.policies.engine import PolicyEngine
-from authentik.policies.types import PolicyResult
+from authentik.core.api.utils import ModelSerializer
 from authentik.providers.ldap.models import LDAPProvider
@@ -33,6 +23,7 @@ class LDAPProviderSerializer(ProviderSerializer):
         model = LDAPProvider
         fields = ProviderSerializer.Meta.fields + [
             "base_dn",
+            "search_group",
             "certificate",
             "tls_server_name",
             "uid_start_number",
@@ -64,6 +55,8 @@ class LDAPProviderFilter(FilterSet):
             "name": ["iexact"],
             "authorization_flow__slug": ["iexact"],
             "base_dn": ["iexact"],
+            "search_group__group_uuid": ["iexact"],
+            "search_group__name": ["iexact"],
             "certificate__kp_uuid": ["iexact"],
             "certificate__name": ["iexact"],
             "tls_server_name": ["iexact"],
@@ -102,6 +95,7 @@ class LDAPOutpostConfigSerializer(ModelSerializer):
             "base_dn",
             "bind_flow_slug",
             "application_slug",
+            "search_group",
             "certificate",
             "tls_server_name",
             "uid_start_number",
@@ -122,33 +116,3 @@ class LDAPOutpostConfigViewSet(ListModelMixin, GenericViewSet):
     ordering = ["name"]
     search_fields = ["name"]
     filterset_fields = ["name"]
-
-    class LDAPCheckAccessSerializer(PassiveSerializer):
-        has_search_permission = BooleanField(required=False)
-        access = PolicyTestResultSerializer()
-
-    @extend_schema(
-        request=None,
-        parameters=[OpenApiParameter("app_slug", OpenApiTypes.STR)],
-        responses={
-            200: LDAPCheckAccessSerializer(),
-        },
-        operation_id="outposts_ldap_access_check",
-    )
-    @action(detail=True)
-    def check_access(self, request: Request, pk) -> Response:
-        """Check access to a single application by slug"""
-        provider = get_object_or_404(LDAPProvider, pk=pk)
-        application = get_object_or_404(Application, slug=request.query_params["app_slug"])
-        engine = PolicyEngine(application, request.user, request)
-        engine.use_cache = False
-        engine.build()
-        result = engine.result
-        access_response = PolicyResult(result.passing)
-        response = self.LDAPCheckAccessSerializer(
-            instance={
-                "has_search_permission": request.user.has_perm("search_full_directory", provider),
-                "access": access_response,
-            }
-        )
-        return Response(response.data)

View File

@@ -1,66 +0,0 @@
-# Generated by Django 5.0.7 on 2024-07-25 14:59
-
-from django.apps.registry import Apps
-from django.db.backends.base.schema import BaseDatabaseSchemaEditor
-from django.db import migrations
-
-
-def migrate_search_group(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
-    from authentik.core.models import User
-    from django.apps import apps as real_apps
-    from django.contrib.auth.management import create_permissions
-    from guardian.shortcuts import UserObjectPermission
-
-    db_alias = schema_editor.connection.alias
-
-    # Permissions are only created _after_ migrations are run
-    # - https://github.com/django/django/blob/43cdfa8b20e567a801b7d0a09ec67ddd062d5ea4/django/contrib/auth/apps.py#L19
-    # - https://stackoverflow.com/a/72029063/1870445
-    create_permissions(real_apps.get_app_config("authentik_providers_ldap"), using=db_alias)
-
-    LDAPProvider = apps.get_model("authentik_providers_ldap", "ldapprovider")
-    Permission = apps.get_model("auth", "Permission")
-    UserObjectPermission = apps.get_model("guardian", "UserObjectPermission")
-    ContentType = apps.get_model("contenttypes", "ContentType")
-
-    new_prem = Permission.objects.using(db_alias).get(codename="search_full_directory")
-    ct = ContentType.objects.using(db_alias).get(
-        app_label="authentik_providers_ldap",
-        model="ldapprovider",
-    )
-
-    for provider in LDAPProvider.objects.using(db_alias).all():
-        if not provider.search_group:
-            continue
-        for user in provider.search_group.users.using(db_alias).all():
-            UserObjectPermission.objects.using(db_alias).create(
-                user=user,
-                permission=new_prem,
-                object_pk=provider.pk,
-                content_type=ct,
-            )
-
-
-class Migration(migrations.Migration):
-
-    dependencies = [
-        ("authentik_providers_ldap", "0003_ldapprovider_mfa_support_and_more"),
-        ("guardian", "0002_generic_permissions_index"),
-    ]
-
-    operations = [
-        migrations.AlterModelOptions(
-            name="ldapprovider",
-            options={
-                "permissions": [("search_full_directory", "Search full LDAP directory")],
-                "verbose_name": "LDAP Provider",
-                "verbose_name_plural": "LDAP Providers",
-            },
-        ),
-        migrations.RunPython(migrate_search_group),
-        migrations.RemoveField(
-            model_name="ldapprovider",
-            name="search_group",
-        ),
-    ]

View File

@@ -7,7 +7,7 @@ from django.templatetags.static import static
 from django.utils.translation import gettext_lazy as _
 from rest_framework.serializers import Serializer

-from authentik.core.models import BackchannelProvider
+from authentik.core.models import BackchannelProvider, Group
 from authentik.crypto.models import CertificateKeyPair
 from authentik.outposts.models import OutpostModel
@@ -27,6 +28,17 @@ class LDAPProvider(OutpostModel, BackchannelProvider):
         help_text=_("DN under which objects are accessible."),
     )

+    search_group = models.ForeignKey(
+        Group,
+        null=True,
+        default=None,
+        on_delete=models.SET_DEFAULT,
+        help_text=_(
+            "Users in this group can do search queries. "
+            "If not set, every user can execute search queries."
+        ),
+    )
+
     tls_server_name = models.TextField(
         default="",
         blank=True,
@@ -102,6 +113,3 @@ class LDAPProvider(OutpostModel, BackchannelProvider):
     class Meta:
         verbose_name = _("LDAP Provider")
         verbose_name_plural = _("LDAP Providers")
-        permissions = [
-            ("search_full_directory", _("Search full LDAP directory")),
-        ]

View File

@@ -105,7 +105,7 @@ class ScopeMapping(PropertyMapping):

     @property
     def component(self) -> str:
-        return "ak-property-mapping-provider-scope-form"
+        return "ak-property-mapping-scope-form"

     @property
     def serializer(self) -> type[Serializer]:

View File

@@ -62,7 +62,7 @@ urlpatterns = [

 api_urlpatterns = [
     ("providers/oauth2", OAuth2ProviderViewSet),
-    ("propertymappings/provider/scope", ScopeMappingViewSet),
+    ("propertymappings/scope", ScopeMappingViewSet),
     ("oauth2/authorization_codes", AuthorizationCodeViewSet),
     ("oauth2/refresh_tokens", RefreshTokenViewSet),
     ("oauth2/access_tokens", AccessTokenViewSet),

View File

@@ -433,21 +433,20 @@ class TokenParams:
             app = Application.objects.filter(provider=self.provider).first()
             if not app or not app.provider:
                 raise TokenError("invalid_grant")
-            with audit_ignore():
-                self.user, _ = User.objects.update_or_create(
-                    # trim username to ensure the entire username is max 150 chars
-                    # (22 chars being the length of the "template")
-                    username=f"ak-{self.provider.name[:150-22]}-client_credentials",
-                    defaults={
-                        "attributes": {
-                            USER_ATTRIBUTE_GENERATED: True,
-                        },
-                        "last_login": timezone.now(),
-                        "name": f"Autogenerated user from application {app.name} (client credentials)",
-                        "path": f"{USER_PATH_SYSTEM_PREFIX}/apps/{app.slug}",
-                        "type": UserTypes.SERVICE_ACCOUNT,
-                    },
-                )
+            self.user, _ = User.objects.update_or_create(
+                # trim username to ensure the entire username is max 150 chars
+                # (22 chars being the length of the "template")
+                username=f"ak-{self.provider.name[:150-22]}-client_credentials",
+                defaults={
+                    "attributes": {
+                        USER_ATTRIBUTE_GENERATED: True,
+                    },
+                    "last_login": timezone.now(),
+                    "name": f"Autogenerated user from application {app.name} (client credentials)",
+                    "path": f"{USER_PATH_SYSTEM_PREFIX}/apps/{app.slug}",
+                    "type": UserTypes.SERVICE_ACCOUNT,
+                },
+            )
             self.__check_policy_access(app, request)

             Event.new(
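The username template in this hunk reserves 22 characters ("ak-" plus "-client_credentials"), which is why the provider name is sliced to 150 - 22: Django usernames are capped at 150 characters. A tiny check of that arithmetic; the provider name is made up:

```
MAX_USERNAME = 150
TEMPLATE_OVERHEAD = len("ak-") + len("-client_credentials")  # 22 characters

provider_name = "x" * 200  # deliberately longer than the limit
username = f"ak-{provider_name[:MAX_USERNAME - TEMPLATE_OVERHEAD]}-client_credentials"
assert len(username) == MAX_USERNAME
print(len(username))  # 150
```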

View File

@@ -28,7 +28,7 @@ class ProxyDockerController(DockerController):
         labels = super()._get_labels()
         labels["traefik.enable"] = "true"
         labels[f"traefik.http.routers.{traefik_name}-router.rule"] = (
-            f"({' || '.join([f'Host({host})' for host in hosts])})"
+            f"({' || '.join([f'Host(`{host}`)' for host in hosts])})"
             f" && PathPrefix(`/outpost.goauthentik.io`)"
         )
         labels[f"traefik.http.routers.{traefik_name}-router.tls"] = "true"

View File

@@ -154,7 +154,6 @@ class RadiusOutpostConfigViewSet(ListModelMixin, GenericViewSet):
         responses={
             200: RadiusCheckAccessSerializer(),
         },
-        operation_id="outposts_radius_access_check",
     )
     @action(detail=True)
     def check_access(self, request: Request, pk) -> Response:

View File

@ -1,20 +0,0 @@
# Generated by Django 5.0.8 on 2024-08-12 12:54
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("authentik_providers_radius", "0003_radiusproviderpropertymapping"),
]
operations = [
migrations.AlterModelOptions(
name="radiusproviderpropertymapping",
options={
"verbose_name": "Radius Provider Property Mapping",
"verbose_name_plural": "Radius Provider Property Mappings",
},
),
]

View File

@@ -70,7 +70,7 @@ class RadiusProviderPropertyMapping(PropertyMapping):

     @property
     def component(self) -> str:
-        return "ak-property-mapping-provider-radius-form"
+        return "ak-property-mapping-radius-form"

     @property
     def serializer(self) -> type[Serializer]:
@@ -81,8 +81,8 @@ class RadiusProviderPropertyMapping(PropertyMapping):
         return RadiusProviderPropertyMappingSerializer

     def __str__(self):
-        return f"Radius Provider Property Mapping {self.name}"
+        return f"Radius Property Mapping {self.name}"

     class Meta:
-        verbose_name = _("Radius Provider Property Mapping")
-        verbose_name_plural = _("Radius Provider Property Mappings")
+        verbose_name = _("Radius Property Mapping")
+        verbose_name_plural = _("Radius Property Mappings")

View File

@@ -7,7 +7,7 @@ from authentik.providers.radius.api.providers import (
 )

 api_urlpatterns = [
-    ("propertymappings/provider/radius", RadiusProviderPropertyMappingViewSet),
+    ("propertymappings/radius", RadiusProviderPropertyMappingViewSet),
     ("outposts/radius", RadiusOutpostConfigViewSet, "radiusprovideroutpost"),
     ("providers/radius", RadiusProviderViewSet),
 ]

View File

@@ -133,17 +133,6 @@ class SAMLProviderSerializer(ProviderSerializer):
         except Provider.application.RelatedObjectDoesNotExist:
             return "-"

-    def validate(self, attrs: dict):
-        if attrs.get("signing_kp"):
-            if not attrs.get("sign_assertion") and not attrs.get("sign_response"):
-                raise ValidationError(
-                    _(
-                        "With a signing keypair selected, at least one of 'Sign assertion' "
-                        "and 'Sign Response' must be selected."
-                    )
-                )
-        return super().validate(attrs)
-
     class Meta:
         model = SAMLProvider
         fields = ProviderSerializer.Meta.fields + [
@@ -159,9 +148,6 @@ class SAMLProviderSerializer(ProviderSerializer):
             "signature_algorithm",
             "signing_kp",
             "verification_kp",
-            "encryption_kp",
-            "sign_assertion",
-            "sign_response",
             "sp_binding",
             "default_relay_state",
             "url_download_metadata",

View File

@ -1,20 +0,0 @@
# Generated by Django 5.0.8 on 2024-08-12 12:54
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("authentik_providers_saml", "0014_alter_samlprovider_digest_algorithm_and_more"),
]
operations = [
migrations.AlterModelOptions(
name="samlpropertymapping",
options={
"verbose_name": "SAML Provider Property Mapping",
"verbose_name_plural": "SAML Provider Property Mappings",
},
),
]

View File

@ -1,39 +0,0 @@
# Generated by Django 5.0.8 on 2024-08-15 14:52
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_crypto", "0004_alter_certificatekeypair_name"),
("authentik_providers_saml", "0015_alter_samlpropertymapping_options"),
]
operations = [
migrations.AddField(
model_name="samlprovider",
name="encryption_kp",
field=models.ForeignKey(
blank=True,
default=None,
help_text="When selected, incoming assertions are encrypted by the IdP using the public key of the encryption keypair. The assertion is decrypted by the SP using the the private key.",
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="+",
to="authentik_crypto.certificatekeypair",
verbose_name="Encryption Keypair",
),
),
migrations.AddField(
model_name="samlprovider",
name="sign_assertion",
field=models.BooleanField(default=True),
),
migrations.AddField(
model_name="samlprovider",
name="sign_response",
field=models.BooleanField(default=False),
),
]

View File

@@ -144,28 +144,11 @@ class SAMLProvider(Provider):
         on_delete=models.SET_NULL,
         verbose_name=_("Signing Keypair"),
     )
-    encryption_kp = models.ForeignKey(
-        CertificateKeyPair,
-        default=None,
-        null=True,
-        blank=True,
-        help_text=_(
-            "When selected, incoming assertions are encrypted by the IdP using the public "
-            "key of the encryption keypair. The assertion is decrypted by the SP using the "
-            "the private key."
-        ),
-        on_delete=models.SET_NULL,
-        verbose_name=_("Encryption Keypair"),
-        related_name="+",
-    )

     default_relay_state = models.TextField(
         default="", blank=True, help_text=_("Default relay_state value for IDP-initiated logins")
     )

-    sign_assertion = models.BooleanField(default=True)
-    sign_response = models.BooleanField(default=False)
-
     @property
     def launch_url(self) -> str | None:
         """Use IDP-Initiated SAML flow as launch URL"""
@@ -208,7 +191,7 @@ class SAMLPropertyMapping(PropertyMapping):

     @property
     def component(self) -> str:
-        return "ak-property-mapping-provider-saml-form"
+        return "ak-property-mapping-saml-form"

     @property
     def serializer(self) -> type[Serializer]:
@@ -221,8 +204,8 @@ class SAMLPropertyMapping(PropertyMapping):
         return f"{self.name} ({name})"

     class Meta:
-        verbose_name = _("SAML Provider Property Mapping")
-        verbose_name_plural = _("SAML Provider Property Mappings")
+        verbose_name = _("SAML Property Mapping")
+        verbose_name_plural = _("SAML Property Mappings")


 class SAMLProviderImportModel(CreatableType, Provider):

View File

@@ -18,11 +18,7 @@ from authentik.providers.saml.processors.authn_request_parser import AuthNRequestParser
 from authentik.providers.saml.utils import get_random_id
 from authentik.providers.saml.utils.time import get_time_string
 from authentik.sources.ldap.auth import LDAP_DISTINGUISHED_NAME
-from authentik.sources.saml.exceptions import (
-    InvalidEncryption,
-    InvalidSignature,
-    UnsupportedNameIDFormat,
-)
+from authentik.sources.saml.exceptions import InvalidSignature, UnsupportedNameIDFormat
 from authentik.sources.saml.processors.constants import (
     DIGEST_ALGORITHM_TRANSLATION_MAP,
     NS_MAP,
@@ -260,17 +256,9 @@ class AssertionProcessor:
             assertion,
             xmlsec.constants.TransformExclC14N,
             sign_algorithm_transform,
-            ns=xmlsec.constants.DSigNs,
+            ns="ds",  # type: ignore
         )
         assertion.append(signature)

-        if self.provider.encryption_kp:
-            encryption = xmlsec.template.encrypted_data_create(
-                assertion,
-                xmlsec.constants.TransformAes128Cbc,
-                self._assertion_id,
-                ns=xmlsec.constants.DSigNs,
-            )
-            assertion.append(encryption)
         assertion.append(self.get_assertion_subject())
         assertion.append(self.get_assertion_conditions())
@@ -298,86 +286,41 @@ class AssertionProcessor:
         response.append(self.get_assertion())
         return response

-    def _sign(self, element: Element):
-        """Sign an XML element based on the providers' configured signing settings"""
-        digest_algorithm_transform = DIGEST_ALGORITHM_TRANSLATION_MAP.get(
-            self.provider.digest_algorithm, xmlsec.constants.TransformSha1
-        )
-        xmlsec.tree.add_ids(element, ["ID"])
-        signature_node = xmlsec.tree.find_node(element, xmlsec.constants.NodeSignature)
-        ref = xmlsec.template.add_reference(
-            signature_node,
-            digest_algorithm_transform,
-            uri="#" + self._assertion_id,
-        )
-        xmlsec.template.add_transform(ref, xmlsec.constants.TransformEnveloped)
-        xmlsec.template.add_transform(ref, xmlsec.constants.TransformExclC14N)
-        key_info = xmlsec.template.ensure_key_info(signature_node)
-        xmlsec.template.add_x509_data(key_info)
-        ctx = xmlsec.SignatureContext()
-        key = xmlsec.Key.from_memory(
-            self.provider.signing_kp.key_data,
-            xmlsec.constants.KeyDataFormatPem,
-            None,
-        )
-        key.load_cert_from_memory(
-            self.provider.signing_kp.certificate_data,
-            xmlsec.constants.KeyDataFormatCertPem,
-        )
-        ctx.key = key
-        try:
-            ctx.sign(signature_node)
-        except xmlsec.Error as exc:
-            raise InvalidSignature() from exc
-
-    def _encrypt(self, element: Element, parent: Element):
-        """Encrypt SAMLResponse EncryptedAssertion Element"""
-        manager = xmlsec.KeysManager()
-        key = xmlsec.Key.from_memory(
-            self.provider.encryption_kp.key_data,
-            xmlsec.constants.KeyDataFormatPem,
-        )
-        key.load_cert_from_memory(
-            self.provider.encryption_kp.certificate_data,
-            xmlsec.constants.KeyDataFormatCertPem,
-        )
-        manager.add_key(key)
-        encryption_context = xmlsec.EncryptionContext(manager)
-        encryption_context.key = xmlsec.Key.generate(
-            xmlsec.constants.KeyDataAes, 128, xmlsec.constants.KeyDataTypeSession
-        )
-
-        container = SubElement(parent, f"{{{NS_SAML_ASSERTION}}}EncryptedAssertion")
-        enc_data = xmlsec.template.encrypted_data_create(
-            container, xmlsec.Transform.AES128, type=xmlsec.EncryptionType.ELEMENT, ns="xenc"
-        )
-        xmlsec.template.encrypted_data_ensure_cipher_value(enc_data)
-        key_info = xmlsec.template.encrypted_data_ensure_key_info(enc_data, ns="ds")
-        enc_key = xmlsec.template.add_encrypted_key(key_info, xmlsec.Transform.RSA_OAEP)
-        xmlsec.template.encrypted_data_ensure_cipher_value(enc_key)
-
-        try:
-            enc_data = encryption_context.encrypt_xml(enc_data, element)
-        except xmlsec.Error as exc:
-            raise InvalidEncryption() from exc
-
-        parent.remove(enc_data)
-        container.append(enc_data)
-
     def build_response(self) -> str:
         """Build string XML Response and sign if signing is enabled."""
         root_response = self.get_response()
         if self.provider.signing_kp:
-            if self.provider.sign_assertion:
-                assertion = root_response.xpath("//saml:Assertion", namespaces=NS_MAP)[0]
-                self._sign(assertion)
-            if self.provider.sign_response:
-                response = root_response.xpath("//samlp:Response", namespaces=NS_MAP)[0]
-                self._sign(response)
-        if self.provider.encryption_kp:
-            assertion = root_response.xpath("//saml:Assertion", namespaces=NS_MAP)[0]
-            self._encrypt(assertion, root_response)
+            digest_algorithm_transform = DIGEST_ALGORITHM_TRANSLATION_MAP.get(
+                self.provider.digest_algorithm, xmlsec.constants.TransformSha1
+            )
+            assertion = root_response.xpath("//saml:Assertion", namespaces=NS_MAP)[0]
+            xmlsec.tree.add_ids(assertion, ["ID"])
+            signature_node = xmlsec.tree.find_node(assertion, xmlsec.constants.NodeSignature)
+            ref = xmlsec.template.add_reference(
+                signature_node,
+                digest_algorithm_transform,
+                uri="#" + self._assertion_id,
+            )
+            xmlsec.template.add_transform(ref, xmlsec.constants.TransformEnveloped)
+            xmlsec.template.add_transform(ref, xmlsec.constants.TransformExclC14N)
+            key_info = xmlsec.template.ensure_key_info(signature_node)
+            xmlsec.template.add_x509_data(key_info)
+            ctx = xmlsec.SignatureContext()
+            key = xmlsec.Key.from_memory(
+                self.provider.signing_kp.key_data,
+                xmlsec.constants.KeyDataFormatPem,
+                None,
+            )
+            key.load_cert_from_memory(
+                self.provider.signing_kp.certificate_data,
+                xmlsec.constants.KeyDataFormatCertPem,
+            )
+            ctx.key = key
+            try:
+                ctx.sign(signature_node)
+            except xmlsec.Error as exc:
+                raise InvalidSignature() from exc

         return etree.tostring(root_response).decode("utf-8")  # nosec

View File

@@ -126,7 +126,7 @@ class MetadataProcessor:
             entity_descriptor,
             xmlsec.constants.TransformExclC14N,
             sign_algorithm_transform,
-            ns=xmlsec.constants.DSigNs,
+            ns="ds",  # type: ignore
         )
         entity_descriptor.append(signature)

View File

@@ -8,7 +8,7 @@ from rest_framework.test import APITestCase

 from authentik.blueprints.tests import apply_blueprint
 from authentik.core.models import Application
-from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
+from authentik.core.tests.utils import create_test_admin_user, create_test_flow
 from authentik.flows.models import FlowDesignation
 from authentik.lib.generators import generate_id
 from authentik.lib.tests.utils import load_fixture
@@ -29,52 +29,12 @@ class TestSAMLProviderAPI(APITestCase):
             name=generate_id(),
             authorization_flow=create_test_flow(),
         )
-        response = self.client.get(
-            reverse("authentik_api:samlprovider-detail", kwargs={"pk": provider.pk}),
-        )
-        self.assertEqual(200, response.status_code)
         Application.objects.create(name=generate_id(), provider=provider, slug=generate_id())
         response = self.client.get(
             reverse("authentik_api:samlprovider-detail", kwargs={"pk": provider.pk}),
         )
         self.assertEqual(200, response.status_code)

-    def test_create_validate_signing_kp(self):
-        """Test create"""
-        cert = create_test_cert()
-        response = self.client.post(
-            reverse("authentik_api:samlprovider-list"),
-            data={
-                "name": generate_id(),
-                "authorization_flow": create_test_flow().pk,
-                "acs_url": "http://localhost",
-                "signing_kp": cert.pk,
-            },
-        )
-        self.assertEqual(400, response.status_code)
-        self.assertJSONEqual(
-            response.content,
-            {
-                "non_field_errors": [
-                    (
-                        "With a signing keypair selected, at least one "
-                        "of 'Sign assertion' and 'Sign Response' must be selected."
-                    )
-                ]
-            },
-        )
-        response = self.client.post(
-            reverse("authentik_api:samlprovider-list"),
-            data={
-                "name": generate_id(),
-                "authorization_flow": create_test_flow().pk,
-                "acs_url": "http://localhost",
-                "signing_kp": cert.pk,
-                "sign_assertion": True,
-            },
-        )
-        self.assertEqual(201, response.status_code)
-
     def test_metadata(self):
         """Test metadata export (normal)"""
         self.client.logout()

View File

@@ -78,12 +78,12 @@ class TestAuthNRequest(TestCase):

     @apply_blueprint("system/providers-saml.yaml")
     def setUp(self):
-        self.cert = create_test_cert()
+        cert = create_test_cert()
         self.provider: SAMLProvider = SAMLProvider.objects.create(
             authorization_flow=create_test_flow(),
             acs_url="http://testserver/source/saml/provider/acs/",
-            signing_kp=self.cert,
-            verification_kp=self.cert,
+            signing_kp=cert,
+            verification_kp=cert,
         )
         self.provider.property_mappings.set(SAMLPropertyMapping.objects.all())
         self.provider.save()
@@ -91,8 +91,8 @@ class TestAuthNRequest(TestCase):
             slug="provider",
             issuer="authentik",
             pre_authentication_flow=create_test_flow(),
-            signing_kp=self.cert,
-            verification_kp=self.cert,
+            signing_kp=cert,
+            verification_kp=cert,
         )

     def test_signed_valid(self):
@@ -112,34 +112,7 @@ class TestAuthNRequest(TestCase):
         self.assertEqual(parsed_request.id, request_proc.request_id)
         self.assertEqual(parsed_request.relay_state, "test_state")

-    def test_request_encrypt(self):
-        """Test full SAML Request/Response flow, fully encrypted"""
-        self.provider.encryption_kp = self.cert
-        self.provider.save()
-        self.source.encryption_kp = self.cert
-        self.source.save()
-        http_request = get_request("/")
-        # First create an AuthNRequest
-        request_proc = RequestProcessor(self.source, http_request, "test_state")
-        request = request_proc.build_auth_n()
-        # To get an assertion we need a parsed request (parsed by provider)
-        parsed_request = AuthNRequestParser(self.provider).parse(
-            b64encode(request.encode()).decode(), "test_state"
-        )
-        # Now create a response and convert it to string (provider)
-        response_proc = AssertionProcessor(self.provider, http_request, parsed_request)
-        response = response_proc.build_response()
-        # Now parse the response (source)
-        http_request.POST = QueryDict(mutable=True)
-        http_request.POST["SAMLResponse"] = b64encode(response.encode()).decode()
-        response_parser = ResponseProcessor(self.source, http_request)
-        response_parser.parse()
-
-    def test_request_signed(self):
+    def test_request_full_signed(self):
         """Test full SAML Request/Response flow, fully signed"""

         http_request = get_request("/")
@@ -162,32 +135,6 @@ class TestAuthNRequest(TestCase):
         response_parser = ResponseProcessor(self.source, http_request)
         response_parser.parse()

-    def test_request_signed_both(self):
-        """Test full SAML Request/Response flow, fully signed"""
-        self.provider.sign_assertion = True
-        self.provider.sign_response = True
-        self.provider.save()
-        http_request = get_request("/")
-        # First create an AuthNRequest
-        request_proc = RequestProcessor(self.source, http_request, "test_state")
-        request = request_proc.build_auth_n()
-        # To get an assertion we need a parsed request (parsed by provider)
-        parsed_request = AuthNRequestParser(self.provider).parse(
-            b64encode(request.encode()).decode(), "test_state"
-        )
-        # Now create a response and convert it to string (provider)
-        response_proc = AssertionProcessor(self.provider, http_request, parsed_request)
-        response = response_proc.build_response()
-        # Now parse the response (source)
-        http_request.POST = QueryDict(mutable=True)
-        http_request.POST["SAMLResponse"] = b64encode(response.encode()).decode()
-        response_parser = ResponseProcessor(self.source, http_request)
-        response_parser.parse()
-
     def test_request_id_invalid(self):
         """Test generated AuthNRequest with invalid request ID"""
         http_request = get_request("/")

View File

@@ -54,11 +54,7 @@ class TestServiceProviderMetadataParser(TestCase):
         request = self.factory.get("/")
         metadata = lxml_from_string(MetadataProcessor(provider, request).build_entity_descriptor())

-        schema = etree.XMLSchema(
-            etree.parse(
-                source="schemas/saml-schema-metadata-2.0.xsd", parser=etree.XMLParser()
-            )  # nosec
-        )
+        schema = etree.XMLSchema(etree.parse("schemas/saml-schema-metadata-2.0.xsd"))  # nosec
         self.assertTrue(schema.validate(metadata))

     def test_schema_want_authn_requests_signed(self):

View File

@@ -47,9 +47,7 @@ class TestSchema(TestCase):

         metadata = lxml_from_string(request)

-        schema = etree.XMLSchema(
-            etree.parse("schemas/saml-schema-protocol-2.0.xsd", parser=etree.XMLParser())  # nosec
-        )
+        schema = etree.XMLSchema(etree.parse("schemas/saml-schema-protocol-2.0.xsd"))  # nosec
         self.assertTrue(schema.validate(metadata))

     def test_response_schema(self):
@@ -70,7 +68,5 @@ class TestSchema(TestCase):

         metadata = lxml_from_string(response)

-        schema = etree.XMLSchema(
-            etree.parse("schemas/saml-schema-protocol-2.0.xsd", parser=etree.XMLParser())  # nosec
-        )
+        schema = etree.XMLSchema(etree.parse("schemas/saml-schema-protocol-2.0.xsd"))  # nosec
         self.assertTrue(schema.validate(metadata))
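Both test hunks reduce to the same lxml pattern: parse the bundled SAML XSD and validate the generated XML against it. A minimal helper showing that pattern; the paths follow the diff, and any readable XSD would work the same way:

```
from lxml import etree

def validates_against_schema(xml_bytes: bytes, xsd_path: str) -> bool:
    # Parse the schema once, then validate the document element against it.
    schema = etree.XMLSchema(etree.parse(xsd_path))  # nosec - local, trusted schema file
    return schema.validate(etree.fromstring(xml_bytes))

# e.g. validates_against_schema(response_xml, "schemas/saml-schema-protocol-2.0.xsd")
```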

View File

@@ -44,6 +44,6 @@ urlpatterns = [
 ]

 api_urlpatterns = [
-    ("propertymappings/provider/saml", SAMLPropertyMappingViewSet),
+    ("propertymappings/saml", SAMLPropertyMappingViewSet),
     ("providers/saml", SAMLProviderViewSet),
 ]

View File

@@ -6,7 +6,7 @@ from authentik.core.api.providers import ProviderSerializer
 from authentik.core.api.used_by import UsedByMixin
 from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
 from authentik.providers.scim.models import SCIMProvider
-from authentik.providers.scim.tasks import scim_sync, scim_sync_objects
+from authentik.providers.scim.tasks import scim_sync


 class SCIMProviderSerializer(ProviderSerializer):
@@ -42,4 +42,3 @@ class SCIMProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin, ModelViewSet):
     search_fields = ["name", "url"]
     ordering = ["name", "url"]
     sync_single_task = scim_sync
-    sync_objects_task = scim_sync_objects
sync_objects_task = scim_sync_objects

View File

@@ -1,7 +1,5 @@
 """Group client"""

-from itertools import batched
-
 from pydantic import ValidationError
 from pydanticscim.group import GroupMember
 from pydanticscim.responses import PatchOp, PatchOperation
@@ -58,22 +56,17 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
         if not scim_group.externalId:
             scim_group.externalId = str(obj.pk)

-        if not self._config.patch.supported:
-            users = list(obj.users.order_by("id").values_list("id", flat=True))
-            connections = SCIMProviderUser.objects.filter(
-                provider=self.provider, user__pk__in=users
-            )
-            members = []
-            for user in connections:
-                members.append(
-                    GroupMember(
-                        value=user.scim_id,
-                    )
-                )
-            if members:
-                scim_group.members = members
-            else:
-                del scim_group.members
+        users = list(obj.users.order_by("id").values_list("id", flat=True))
+        connections = SCIMProviderUser.objects.filter(provider=self.provider, user__pk__in=users)
+        members = []
+        for user in connections:
+            members.append(
+                GroupMember(
+                    value=user.scim_id,
+                )
+            )
+        if members:
+            scim_group.members = members
         return scim_group

     def delete(self, obj: Group):
@@ -100,19 +93,16 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
         scim_id = response.get("id")
         if not scim_id or scim_id == "":
             raise StopSync("SCIM Response with missing or invalid `id`")
-        connection = SCIMProviderGroup.objects.create(
+        return SCIMProviderGroup.objects.create(
             provider=self.provider, group=group, scim_id=scim_id
         )
-        users = list(group.users.order_by("id").values_list("id", flat=True))
-        self._patch_add_users(group, users)
-        return connection

     def update(self, group: Group, connection: SCIMProviderGroup):
         """Update existing group"""
         scim_group = self.to_schema(group, connection)
         scim_group.id = connection.scim_id
         try:
-            self._request(
+            return self._request(
                 "PUT",
                 f"/Groups/{connection.scim_id}",
                 json=scim_group.model_dump(
@@ -120,8 +110,6 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
                     exclude_unset=True,
                 ),
             )
-            users = list(group.users.order_by("id").values_list("id", flat=True))
-            return self._patch_add_users(group, users)
         except NotFoundSyncException:
             # Resource missing is handled by self.write, which will re-create the group
             raise
@@ -164,18 +152,14 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
         group_id: str,
         *ops: PatchOperation,
     ):
-        chunk_size = self._config.bulk.maxOperations
-        if chunk_size < 1:
-            chunk_size = len(ops)
-        for chunk in batched(ops, chunk_size):
-            req = PatchRequest(Operations=list(chunk))
-            self._request(
-                "PATCH",
-                f"/Groups/{group_id}",
-                json=req.model_dump(
-                    mode="json",
-                ),
-            )
+        req = PatchRequest(Operations=ops)
+        self._request(
+            "PATCH",
+            f"/Groups/{group_id}",
+            json=req.model_dump(
+                mode="json",
+            ),
+        )

     def _patch_add_users(self, group: Group, users_set: set[int]):
         """Add users in users_set to group"""
@@ -196,14 +180,11 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
             return
         self._patch(
             scim_group.scim_id,
-            *[
-                PatchOperation(
-                    op=PatchOp.add,
-                    path="members",
-                    value=[{"value": x}],
-                )
-                for x in user_ids
-            ],
+            PatchOperation(
+                op=PatchOp.add,
+                path="members",
+                value=[{"value": x} for x in user_ids],
+            ),
        )

     def _patch_remove_users(self, group: Group, users_set: set[int]):
@@ -225,12 +206,9 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
             return
         self._patch(
             scim_group.scim_id,
-            *[
-                PatchOperation(
-                    op=PatchOp.remove,
-                    path="members",
-                    value=[{"value": x}],
-                )
-                for x in user_ids
-            ],
+            PatchOperation(
+                op=PatchOp.remove,
+                path="members",
+                value=[{"value": x} for x in user_ids],
+            ),
         )
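The left-hand side of the _patch hunk above splits PATCH operations into batches no larger than the service provider's advertised bulk.maxOperations, using Python 3.12's itertools.batched. A hedged sketch of that chunking; send_patch() stands in for the client's own `_request("PATCH", ...)` call and is purely illustrative:

```
from itertools import batched  # Python 3.12+

def patch_in_chunks(group_id: str, ops: list[dict], max_operations: int, send_patch) -> None:
    # Never exceed the provider's bulk limit per request; a non-positive limit
    # means "send everything in one PATCH".
    chunk_size = max_operations if max_operations > 0 else max(len(ops), 1)
    for chunk in batched(ops, chunk_size):
        send_patch(f"/Groups/{group_id}", {"Operations": list(chunk)})

# Example: five member operations against a provider that allows two per request.
patch_in_chunks(
    "grp-1",
    [{"op": "add", "path": "members", "value": [{"value": str(i)}]} for i in range(5)],
    2,
    lambda path, body: print(path, len(body["Operations"])),
)
# /Groups/grp-1 2
# /Groups/grp-1 2
# /Groups/grp-1 1
```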

View File

@@ -1,11 +1,9 @@
 """Custom SCIM schemas"""

-from pydantic import Field
 from pydanticscim.group import Group as BaseGroup
 from pydanticscim.responses import PatchRequest as BasePatchRequest
 from pydanticscim.responses import SCIMError as BaseSCIMError
-from pydanticscim.service_provider import Bulk as BaseBulk
-from pydanticscim.service_provider import ChangePassword, Filter, Patch, Sort
+from pydanticscim.service_provider import Bulk, ChangePassword, Filter, Patch, Sort
 from pydanticscim.service_provider import (
     ServiceProviderConfiguration as BaseServiceProviderConfiguration,
 )
@@ -31,16 +29,10 @@ class Group(BaseGroup):
     meta: dict | None = None


-class Bulk(BaseBulk):
-    maxOperations: int = Field()
-
-
 class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
     """ServiceProviderConfig with fallback"""

     _is_fallback: bool | None = False
-    bulk: Bulk = Field(..., description="A complex type that specifies bulk configuration options.")

     @property
     def is_fallback(self) -> bool:
@@ -53,7 +45,7 @@ class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
         """Get default configuration, which doesn't support any optional features as fallback"""
         return ServiceProviderConfiguration(
             patch=Patch(supported=False),
-            bulk=Bulk(supported=False, maxOperations=0),
+            bulk=Bulk(supported=False),
             filter=Filter(supported=False),
             changePassword=ChangePassword(supported=False),
             sort=Sort(supported=False),

Some files were not shown because too many files have changed in this diff.