Compare commits
	
		
			12 Commits
		
	
	
		
			version/20
			...
			core/b2c-i
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 195091ed3b | |||
| 4de3f1f4b8 | |||
| af4f1b3421 | |||
| 77b816ad51 | |||
| b28dd485a0 | |||
| 4701389745 | |||
| 0d0097e956 | |||
| b42eb0706d | |||
| 3afe386e18 | |||
| 34dd9c0b63 | |||
| b2f2fd241d | |||
| 828f477548 | 
| @ -1,5 +1,5 @@ | |||||||
| [bumpversion] | [bumpversion] | ||||||
| current_version = 2024.4.3 | current_version = 2024.2.2 | ||||||
| tag = True | tag = True | ||||||
| commit = True | commit = True | ||||||
| parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? | parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? | ||||||
| @ -21,8 +21,6 @@ optional_value = final | |||||||
|  |  | ||||||
| [bumpversion:file:schema.yml] | [bumpversion:file:schema.yml] | ||||||
|  |  | ||||||
| [bumpversion:file:blueprints/schema.json] |  | ||||||
|  |  | ||||||
| [bumpversion:file:authentik/__init__.py] | [bumpversion:file:authentik/__init__.py] | ||||||
|  |  | ||||||
| [bumpversion:file:internal/constants/constants.go] | [bumpversion:file:internal/constants/constants.go] | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								.github/FUNDING.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/FUNDING.yml
									
									
									
									
										vendored
									
									
								
							| @ -1 +1 @@ | |||||||
| custom: https://goauthentik.io/pricing/ | github: [BeryJu] | ||||||
|  | |||||||
| @ -12,7 +12,7 @@ should_build = str(os.environ.get("DOCKER_USERNAME", None) is not None).lower() | |||||||
| branch_name = os.environ["GITHUB_REF"] | branch_name = os.environ["GITHUB_REF"] | ||||||
| if os.environ.get("GITHUB_HEAD_REF", "") != "": | if os.environ.get("GITHUB_HEAD_REF", "") != "": | ||||||
|     branch_name = os.environ["GITHUB_HEAD_REF"] |     branch_name = os.environ["GITHUB_HEAD_REF"] | ||||||
| safe_branch_name = branch_name.replace("refs/heads/", "").replace("/", "-").replace("'", "-") | safe_branch_name = branch_name.replace("refs/heads/", "").replace("/", "-") | ||||||
|  |  | ||||||
| image_names = os.getenv("IMAGE_NAME").split(",") | image_names = os.getenv("IMAGE_NAME").split(",") | ||||||
| image_arch = os.getenv("IMAGE_ARCH") or None | image_arch = os.getenv("IMAGE_ARCH") or None | ||||||
|  | |||||||
							
								
								
									
										8
									
								
								.github/actions/setup/action.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/actions/setup/action.yml
									
									
									
									
										vendored
									
									
								
							| @ -16,25 +16,25 @@ runs: | |||||||
|         sudo apt-get update |         sudo apt-get update | ||||||
|         sudo apt-get install --no-install-recommends -y libpq-dev openssl libxmlsec1-dev pkg-config gettext |         sudo apt-get install --no-install-recommends -y libpq-dev openssl libxmlsec1-dev pkg-config gettext | ||||||
|     - name: Setup python and restore poetry |     - name: Setup python and restore poetry | ||||||
|       uses: actions/setup-python@v5 |       uses: actions/setup-python@v4 | ||||||
|       with: |       with: | ||||||
|         python-version-file: "pyproject.toml" |         python-version-file: "pyproject.toml" | ||||||
|         cache: "poetry" |         cache: "poetry" | ||||||
|     - name: Setup node |     - name: Setup node | ||||||
|       uses: actions/setup-node@v4 |       uses: actions/setup-node@v3 | ||||||
|       with: |       with: | ||||||
|         node-version-file: web/package.json |         node-version-file: web/package.json | ||||||
|         cache: "npm" |         cache: "npm" | ||||||
|         cache-dependency-path: web/package-lock.json |         cache-dependency-path: web/package-lock.json | ||||||
|     - name: Setup go |     - name: Setup go | ||||||
|       uses: actions/setup-go@v5 |       uses: actions/setup-go@v4 | ||||||
|       with: |       with: | ||||||
|         go-version-file: "go.mod" |         go-version-file: "go.mod" | ||||||
|     - name: Setup dependencies |     - name: Setup dependencies | ||||||
|       shell: bash |       shell: bash | ||||||
|       run: | |       run: | | ||||||
|         export PSQL_TAG=${{ inputs.postgresql_version }} |         export PSQL_TAG=${{ inputs.postgresql_version }} | ||||||
|         docker compose -f .github/actions/setup/docker-compose.yml up -d |         docker-compose -f .github/actions/setup/docker-compose.yml up -d | ||||||
|         poetry install |         poetry install | ||||||
|         cd web && npm ci |         cd web && npm ci | ||||||
|     - name: Generate config |     - name: Generate config | ||||||
|  | |||||||
							
								
								
									
										65
									
								
								.github/workflows/api-py-publish.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										65
									
								
								.github/workflows/api-py-publish.yml
									
									
									
									
										vendored
									
									
								
							| @ -1,65 +0,0 @@ | |||||||
| name: authentik-api-py-publish |  | ||||||
| on: |  | ||||||
|   push: |  | ||||||
|     branches: [main] |  | ||||||
|     paths: |  | ||||||
|       - "schema.yml" |  | ||||||
|   workflow_dispatch: |  | ||||||
| jobs: |  | ||||||
|   build: |  | ||||||
|     runs-on: ubuntu-latest |  | ||||||
|     permissions: |  | ||||||
|       id-token: write |  | ||||||
|     steps: |  | ||||||
|       - id: generate_token |  | ||||||
|         uses: tibdex/github-app-token@v2 |  | ||||||
|         with: |  | ||||||
|           app_id: ${{ secrets.GH_APP_ID }} |  | ||||||
|           private_key: ${{ secrets.GH_APP_PRIVATE_KEY }} |  | ||||||
|       - uses: actions/checkout@v4 |  | ||||||
|         with: |  | ||||||
|           token: ${{ steps.generate_token.outputs.token }} |  | ||||||
|       - name: Install poetry & deps |  | ||||||
|         shell: bash |  | ||||||
|         run: | |  | ||||||
|           pipx install poetry || true |  | ||||||
|           sudo apt-get update |  | ||||||
|           sudo apt-get install --no-install-recommends -y libpq-dev openssl libxmlsec1-dev pkg-config gettext |  | ||||||
|       - name: Setup python and restore poetry |  | ||||||
|         uses: actions/setup-python@v5 |  | ||||||
|         with: |  | ||||||
|           python-version-file: "pyproject.toml" |  | ||||||
|           cache: "poetry" |  | ||||||
|       - name: Generate API Client |  | ||||||
|         run: make gen-client-py |  | ||||||
|       - name: Publish package |  | ||||||
|         working-directory: gen-py-api/ |  | ||||||
|         run: | |  | ||||||
|           poetry build |  | ||||||
|       - name: Publish package to PyPI |  | ||||||
|         uses: pypa/gh-action-pypi-publish@release/v1 |  | ||||||
|         with: |  | ||||||
|           packages-dir: gen-py-api/dist/ |  | ||||||
|       # We can't easily upgrade the API client being used due to poetry being poetry |  | ||||||
|       # so we'll have to rely on dependabot |  | ||||||
|       # - name: Upgrade / |  | ||||||
|       #   run: | |  | ||||||
|       #     export VERSION=$(cd gen-py-api && poetry version -s) |  | ||||||
|       #     poetry add "authentik_client=$VERSION" --allow-prereleases --lock |  | ||||||
|       # - uses: peter-evans/create-pull-request@v6 |  | ||||||
|       #   id: cpr |  | ||||||
|       #   with: |  | ||||||
|       #     token: ${{ steps.generate_token.outputs.token }} |  | ||||||
|       #     branch: update-root-api-client |  | ||||||
|       #     commit-message: "root: bump API Client version" |  | ||||||
|       #     title: "root: bump API Client version" |  | ||||||
|       #     body: "root: bump API Client version" |  | ||||||
|       #     delete-branch: true |  | ||||||
|       #     signoff: true |  | ||||||
|       #     # ID from https://api.github.com/users/authentik-automation[bot] |  | ||||||
|       #     author: authentik-automation[bot] <135050075+authentik-automation[bot]@users.noreply.github.com> |  | ||||||
|       # - uses: peter-evans/enable-pull-request-automerge@v3 |  | ||||||
|       #   with: |  | ||||||
|       #     token: ${{ steps.generate_token.outputs.token }} |  | ||||||
|       #     pull-request-number: ${{ steps.cpr.outputs.pull-request-number }} |  | ||||||
|       #     merge-method: squash |  | ||||||
							
								
								
									
										4
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							| @ -160,8 +160,6 @@ jobs: | |||||||
|             glob: tests/e2e/test_provider_ldap* tests/e2e/test_source_ldap* |             glob: tests/e2e/test_provider_ldap* tests/e2e/test_source_ldap* | ||||||
|           - name: radius |           - name: radius | ||||||
|             glob: tests/e2e/test_provider_radius* |             glob: tests/e2e/test_provider_radius* | ||||||
|           - name: scim |  | ||||||
|             glob: tests/e2e/test_source_scim* |  | ||||||
|           - name: flows |           - name: flows | ||||||
|             glob: tests/e2e/test_flows* |             glob: tests/e2e/test_flows* | ||||||
|     steps: |     steps: | ||||||
| @ -170,7 +168,7 @@ jobs: | |||||||
|         uses: ./.github/actions/setup |         uses: ./.github/actions/setup | ||||||
|       - name: Setup e2e env (chrome, etc) |       - name: Setup e2e env (chrome, etc) | ||||||
|         run: | |         run: | | ||||||
|           docker compose -f tests/e2e/docker-compose.yml up -d |           docker-compose -f tests/e2e/docker-compose.yml up -d | ||||||
|       - id: cache-web |       - id: cache-web | ||||||
|         uses: actions/cache@v4 |         uses: actions/cache@v4 | ||||||
|         with: |         with: | ||||||
|  | |||||||
							
								
								
									
										8
									
								
								.github/workflows/ci-web.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/ci-web.yml
									
									
									
									
										vendored
									
									
								
							| @ -34,13 +34,6 @@ jobs: | |||||||
|       - name: Eslint |       - name: Eslint | ||||||
|         working-directory: ${{ matrix.project }}/ |         working-directory: ${{ matrix.project }}/ | ||||||
|         run: npm run lint |         run: npm run lint | ||||||
|   lint-lockfile: |  | ||||||
|     runs-on: ubuntu-latest |  | ||||||
|     steps: |  | ||||||
|       - uses: actions/checkout@v4 |  | ||||||
|       - working-directory: web/ |  | ||||||
|         run: | |  | ||||||
|           [ -z "$(jq -r '.packages | to_entries[] | select((.key | startswith("node_modules")) and (.value | has("resolved") | not)) | .key' < package-lock.json)" ] |  | ||||||
|   lint-build: |   lint-build: | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
|     steps: |     steps: | ||||||
| @ -102,7 +95,6 @@ jobs: | |||||||
|         run: npm run lit-analyse |         run: npm run lit-analyse | ||||||
|   ci-web-mark: |   ci-web-mark: | ||||||
|     needs: |     needs: | ||||||
|       - lint-lockfile |  | ||||||
|       - lint-eslint |       - lint-eslint | ||||||
|       - lint-prettier |       - lint-prettier | ||||||
|       - lint-lit-analyse |       - lint-lit-analyse | ||||||
|  | |||||||
							
								
								
									
										8
									
								
								.github/workflows/ci-website.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/ci-website.yml
									
									
									
									
										vendored
									
									
								
							| @ -12,13 +12,6 @@ on: | |||||||
|       - version-* |       - version-* | ||||||
|  |  | ||||||
| jobs: | jobs: | ||||||
|   lint-lockfile: |  | ||||||
|     runs-on: ubuntu-latest |  | ||||||
|     steps: |  | ||||||
|       - uses: actions/checkout@v4 |  | ||||||
|       - working-directory: website/ |  | ||||||
|         run: | |  | ||||||
|           [ -z "$(jq -r '.packages | to_entries[] | select((.key | startswith("node_modules")) and (.value | has("resolved") | not)) | .key' < package-lock.json)" ] |  | ||||||
|   lint-prettier: |   lint-prettier: | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
|     steps: |     steps: | ||||||
| @ -69,7 +62,6 @@ jobs: | |||||||
|         run: npm run ${{ matrix.job }} |         run: npm run ${{ matrix.job }} | ||||||
|   ci-website-mark: |   ci-website-mark: | ||||||
|     needs: |     needs: | ||||||
|       - lint-lockfile |  | ||||||
|       - lint-prettier |       - lint-prettier | ||||||
|       - test |       - test | ||||||
|       - build |       - build | ||||||
|  | |||||||
							
								
								
									
										43
									
								
								.github/workflows/gen-update-webauthn-mds.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										43
									
								
								.github/workflows/gen-update-webauthn-mds.yml
									
									
									
									
										vendored
									
									
								
							| @ -1,43 +0,0 @@ | |||||||
| name: authentik-gen-update-webauthn-mds |  | ||||||
| on: |  | ||||||
|   workflow_dispatch: |  | ||||||
|   schedule: |  | ||||||
|     - cron: '30 1 1,15 * *' |  | ||||||
|  |  | ||||||
| env: |  | ||||||
|   POSTGRES_DB: authentik |  | ||||||
|   POSTGRES_USER: authentik |  | ||||||
|   POSTGRES_PASSWORD: "EK-5jnKfjrGRm<77" |  | ||||||
|  |  | ||||||
| jobs: |  | ||||||
|   build: |  | ||||||
|     runs-on: ubuntu-latest |  | ||||||
|     steps: |  | ||||||
|       - id: generate_token |  | ||||||
|         uses: tibdex/github-app-token@v2 |  | ||||||
|         with: |  | ||||||
|           app_id: ${{ secrets.GH_APP_ID }} |  | ||||||
|           private_key: ${{ secrets.GH_APP_PRIVATE_KEY }} |  | ||||||
|       - uses: actions/checkout@v4 |  | ||||||
|         with: |  | ||||||
|           token: ${{ steps.generate_token.outputs.token }} |  | ||||||
|       - name: Setup authentik env |  | ||||||
|         uses: ./.github/actions/setup |  | ||||||
|       - run: poetry run ak update_webauthn_mds |  | ||||||
|       - uses: peter-evans/create-pull-request@v6 |  | ||||||
|         id: cpr |  | ||||||
|         with: |  | ||||||
|           token: ${{ steps.generate_token.outputs.token }} |  | ||||||
|           branch: update-fido-mds-client |  | ||||||
|           commit-message: "stages/authenticator_webauthn: Update FIDO MDS3 & Passkey aaguid blobs" |  | ||||||
|           title: "stages/authenticator_webauthn: Update FIDO MDS3 & Passkey aaguid blobs" |  | ||||||
|           body: "stages/authenticator_webauthn: Update FIDO MDS3 & Passkey aaguid blobs" |  | ||||||
|           delete-branch: true |  | ||||||
|           signoff: true |  | ||||||
|           # ID from https://api.github.com/users/authentik-automation[bot] |  | ||||||
|           author: authentik-automation[bot] <135050075+authentik-automation[bot]@users.noreply.github.com> |  | ||||||
|       - uses: peter-evans/enable-pull-request-automerge@v3 |  | ||||||
|         with: |  | ||||||
|           token: ${{ steps.generate_token.outputs.token }} |  | ||||||
|           pull-request-number: ${{ steps.cpr.outputs.pull-request-number }} |  | ||||||
|           merge-method: squash |  | ||||||
							
								
								
									
										8
									
								
								.github/workflows/release-publish.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/release-publish.yml
									
									
									
									
										vendored
									
									
								
							| @ -157,10 +157,10 @@ jobs: | |||||||
|         run: | |         run: | | ||||||
|           echo "PG_PASS=$(openssl rand -base64 32)" >> .env |           echo "PG_PASS=$(openssl rand -base64 32)" >> .env | ||||||
|           echo "AUTHENTIK_SECRET_KEY=$(openssl rand -base64 32)" >> .env |           echo "AUTHENTIK_SECRET_KEY=$(openssl rand -base64 32)" >> .env | ||||||
|           docker compose pull -q |           docker-compose pull -q | ||||||
|           docker compose up --no-start |           docker-compose up --no-start | ||||||
|           docker compose start postgresql redis |           docker-compose start postgresql redis | ||||||
|           docker compose run -u root server test-all |           docker-compose run -u root server test-all | ||||||
|   sentry-release: |   sentry-release: | ||||||
|     needs: |     needs: | ||||||
|       - build-server |       - build-server | ||||||
|  | |||||||
							
								
								
									
										6
									
								
								.github/workflows/release-tag.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								.github/workflows/release-tag.yml
									
									
									
									
										vendored
									
									
								
							| @ -21,9 +21,9 @@ jobs: | |||||||
|           docker build -t testing:latest . |           docker build -t testing:latest . | ||||||
|           echo "AUTHENTIK_IMAGE=testing" >> .env |           echo "AUTHENTIK_IMAGE=testing" >> .env | ||||||
|           echo "AUTHENTIK_TAG=latest" >> .env |           echo "AUTHENTIK_TAG=latest" >> .env | ||||||
|           docker compose up --no-start |           docker-compose up --no-start | ||||||
|           docker compose start postgresql redis |           docker-compose start postgresql redis | ||||||
|           docker compose run -u root server test-all |           docker-compose run -u root server test-all | ||||||
|       - id: generate_token |       - id: generate_token | ||||||
|         uses: tibdex/github-app-token@v2 |         uses: tibdex/github-app-token@v2 | ||||||
|         with: |         with: | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								.github/workflows/repo-stale.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/repo-stale.yml
									
									
									
									
										vendored
									
									
								
							| @ -23,7 +23,7 @@ jobs: | |||||||
|           repo-token: ${{ steps.generate_token.outputs.token }} |           repo-token: ${{ steps.generate_token.outputs.token }} | ||||||
|           days-before-stale: 60 |           days-before-stale: 60 | ||||||
|           days-before-close: 7 |           days-before-close: 7 | ||||||
|           exempt-issue-labels: pinned,security,pr_wanted,enhancement,bug/confirmed,enhancement/confirmed,question,status/reviewing |           exempt-issue-labels: pinned,security,pr_wanted,enhancement,bug/confirmed,enhancement/confirmed,question | ||||||
|           stale-issue-label: wontfix |           stale-issue-label: wontfix | ||||||
|           stale-issue-message: > |           stale-issue-message: > | ||||||
|             This issue has been automatically marked as stale because it has not had |             This issue has been automatically marked as stale because it has not had | ||||||
|  | |||||||
| @ -1,4 +1,4 @@ | |||||||
| name: authentik-api-ts-publish | name: authentik-web-api-publish | ||||||
| on: | on: | ||||||
|   push: |   push: | ||||||
|     branches: [main] |     branches: [main] | ||||||
							
								
								
									
										10
									
								
								Dockerfile
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								Dockerfile
									
									
									
									
									
								
							| @ -38,7 +38,7 @@ COPY ./gen-ts-api /work/web/node_modules/@goauthentik/api | |||||||
| RUN npm run build | RUN npm run build | ||||||
|  |  | ||||||
| # Stage 3: Build go proxy | # Stage 3: Build go proxy | ||||||
| FROM --platform=${BUILDPLATFORM} docker.io/golang:1.22.2-bookworm AS go-builder | FROM --platform=${BUILDPLATFORM} docker.io/golang:1.22.1-bookworm AS go-builder | ||||||
|  |  | ||||||
| ARG TARGETOS | ARG TARGETOS | ||||||
| ARG TARGETARCH | ARG TARGETARCH | ||||||
| @ -70,10 +70,10 @@ RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \ | |||||||
|     GOARM="${TARGETVARIANT#v}" go build -o /go/authentik ./cmd/server |     GOARM="${TARGETVARIANT#v}" go build -o /go/authentik ./cmd/server | ||||||
|  |  | ||||||
| # Stage 4: MaxMind GeoIP | # Stage 4: MaxMind GeoIP | ||||||
| FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.0.1 as geoip | FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v6.1 as geoip | ||||||
|  |  | ||||||
| ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City GeoLite2-ASN" | ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City GeoLite2-ASN" | ||||||
| ENV GEOIPUPDATE_VERBOSE="1" | ENV GEOIPUPDATE_VERBOSE="true" | ||||||
| ENV GEOIPUPDATE_ACCOUNT_ID_FILE="/run/secrets/GEOIPUPDATE_ACCOUNT_ID" | ENV GEOIPUPDATE_ACCOUNT_ID_FILE="/run/secrets/GEOIPUPDATE_ACCOUNT_ID" | ||||||
| ENV GEOIPUPDATE_LICENSE_KEY_FILE="/run/secrets/GEOIPUPDATE_LICENSE_KEY" | ENV GEOIPUPDATE_LICENSE_KEY_FILE="/run/secrets/GEOIPUPDATE_LICENSE_KEY" | ||||||
|  |  | ||||||
| @ -84,7 +84,7 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \ | |||||||
|     /bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0" |     /bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0" | ||||||
|  |  | ||||||
| # Stage 5: Python dependencies | # Stage 5: Python dependencies | ||||||
| FROM docker.io/python:3.12.3-slim-bookworm AS python-deps | FROM docker.io/python:3.12.2-slim-bookworm AS python-deps | ||||||
|  |  | ||||||
| WORKDIR /ak-root/poetry | WORKDIR /ak-root/poetry | ||||||
|  |  | ||||||
| @ -110,7 +110,7 @@ RUN --mount=type=bind,target=./pyproject.toml,src=./pyproject.toml \ | |||||||
|         poetry install --only=main --no-ansi --no-interaction --no-root" |         poetry install --only=main --no-ansi --no-interaction --no-root" | ||||||
|  |  | ||||||
| # Stage 6: Run | # Stage 6: Run | ||||||
| FROM docker.io/python:3.12.3-slim-bookworm AS final-image | FROM docker.io/python:3.12.2-slim-bookworm AS final-image | ||||||
|  |  | ||||||
| ARG GIT_BUILD_HASH | ARG GIT_BUILD_HASH | ||||||
| ARG VERSION | ARG VERSION | ||||||
|  | |||||||
							
								
								
									
										30
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										30
									
								
								Makefile
									
									
									
									
									
								
							| @ -9,7 +9,6 @@ PY_SOURCES = authentik tests scripts lifecycle .github | |||||||
| DOCKER_IMAGE ?= "authentik:test" | DOCKER_IMAGE ?= "authentik:test" | ||||||
|  |  | ||||||
| GEN_API_TS = "gen-ts-api" | GEN_API_TS = "gen-ts-api" | ||||||
| GEN_API_PY = "gen-py-api" |  | ||||||
| GEN_API_GO = "gen-go-api" | GEN_API_GO = "gen-go-api" | ||||||
|  |  | ||||||
| pg_user := $(shell python -m authentik.lib.config postgresql.user 2>/dev/null) | pg_user := $(shell python -m authentik.lib.config postgresql.user 2>/dev/null) | ||||||
| @ -48,10 +47,10 @@ test-go: | |||||||
| test-docker:  ## Run all tests in a docker-compose | test-docker:  ## Run all tests in a docker-compose | ||||||
| 	echo "PG_PASS=$(openssl rand -base64 32)" >> .env | 	echo "PG_PASS=$(openssl rand -base64 32)" >> .env | ||||||
| 	echo "AUTHENTIK_SECRET_KEY=$(openssl rand -base64 32)" >> .env | 	echo "AUTHENTIK_SECRET_KEY=$(openssl rand -base64 32)" >> .env | ||||||
| 	docker compose pull -q | 	docker-compose pull -q | ||||||
| 	docker compose up --no-start | 	docker-compose up --no-start | ||||||
| 	docker compose start postgresql redis | 	docker-compose start postgresql redis | ||||||
| 	docker compose run -u root server test-all | 	docker-compose run -u root server test-all | ||||||
| 	rm -f .env | 	rm -f .env | ||||||
|  |  | ||||||
| test: ## Run the server tests and produce a coverage report (locally) | test: ## Run the server tests and produce a coverage report (locally) | ||||||
| @ -65,7 +64,7 @@ lint-fix:  ## Lint and automatically fix errors in the python source code. Repor | |||||||
| 	codespell -w $(CODESPELL_ARGS) | 	codespell -w $(CODESPELL_ARGS) | ||||||
|  |  | ||||||
| lint: ## Lint the python and golang sources | lint: ## Lint the python and golang sources | ||||||
| 	bandit -r $(PY_SOURCES) -x web/node_modules -x tests/wdio/node_modules -x website/node_modules | 	bandit -r $(PY_SOURCES) -x node_modules | ||||||
| 	golangci-lint run -v | 	golangci-lint run -v | ||||||
|  |  | ||||||
| core-install: | core-install: | ||||||
| @ -138,10 +137,7 @@ gen-clean-ts:  ## Remove generated API client for Typescript | |||||||
| gen-clean-go:  ## Remove generated API client for Go | gen-clean-go:  ## Remove generated API client for Go | ||||||
| 	rm -rf ./${GEN_API_GO}/ | 	rm -rf ./${GEN_API_GO}/ | ||||||
|  |  | ||||||
| gen-clean-py:  ## Remove generated API client for Python | gen-clean: gen-clean-ts gen-clean-go  ## Remove generated API clients | ||||||
| 	rm -rf ./${GEN_API_PY}/ |  | ||||||
|  |  | ||||||
| gen-clean: gen-clean-ts gen-clean-go gen-clean-py  ## Remove generated API clients |  | ||||||
|  |  | ||||||
| gen-client-ts: gen-clean-ts  ## Build and install the authentik API for Typescript into the authentik UI Application | gen-client-ts: gen-clean-ts  ## Build and install the authentik API for Typescript into the authentik UI Application | ||||||
| 	docker run \ | 	docker run \ | ||||||
| @ -159,20 +155,6 @@ gen-client-ts: gen-clean-ts  ## Build and install the authentik API for Typescri | |||||||
| 	cd ./${GEN_API_TS} && npm i | 	cd ./${GEN_API_TS} && npm i | ||||||
| 	\cp -rf ./${GEN_API_TS}/* web/node_modules/@goauthentik/api | 	\cp -rf ./${GEN_API_TS}/* web/node_modules/@goauthentik/api | ||||||
|  |  | ||||||
| gen-client-py: gen-clean-py ## Build and install the authentik API for Python |  | ||||||
| 	docker run \ |  | ||||||
| 		--rm -v ${PWD}:/local \ |  | ||||||
| 		--user ${UID}:${GID} \ |  | ||||||
| 		docker.io/openapitools/openapi-generator-cli:v7.4.0 generate \ |  | ||||||
| 		-i /local/schema.yml \ |  | ||||||
| 		-g python \ |  | ||||||
| 		-o /local/${GEN_API_PY} \ |  | ||||||
| 		-c /local/scripts/api-py-config.yaml \ |  | ||||||
| 		--additional-properties=packageVersion=${NPM_VERSION} \ |  | ||||||
| 		--git-repo-id authentik \ |  | ||||||
| 		--git-user-id goauthentik |  | ||||||
| 	pip install ./${GEN_API_PY} |  | ||||||
|  |  | ||||||
| gen-client-go: gen-clean-go  ## Build and install the authentik API for Golang | gen-client-go: gen-clean-go  ## Build and install the authentik API for Golang | ||||||
| 	mkdir -p ./${GEN_API_GO} ./${GEN_API_GO}/templates | 	mkdir -p ./${GEN_API_GO} ./${GEN_API_GO}/templates | ||||||
| 	wget https://raw.githubusercontent.com/goauthentik/client-go/main/config.yaml -O ./${GEN_API_GO}/config.yaml | 	wget https://raw.githubusercontent.com/goauthentik/client-go/main/config.yaml -O ./${GEN_API_GO}/config.yaml | ||||||
|  | |||||||
| @ -25,10 +25,10 @@ For bigger setups, there is a Helm Chart [here](https://github.com/goauthentik/h | |||||||
|  |  | ||||||
| ## Screenshots | ## Screenshots | ||||||
|  |  | ||||||
| | Light                                                       | Dark                                                       | | | Light                                                  | Dark                                                  | | ||||||
| | ----------------------------------------------------------- | ---------------------------------------------------------- | | | ------------------------------------------------------ | ----------------------------------------------------- | | ||||||
| |   |   | | |   |   | | ||||||
| |  |  | | |  |  | | ||||||
|  |  | ||||||
| ## Development | ## Development | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
										20
									
								
								SECURITY.md
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								SECURITY.md
									
									
									
									
									
								
							| @ -18,10 +18,10 @@ Even if the issue is not a CVE, we still greatly appreciate your help in hardeni | |||||||
|  |  | ||||||
| (.x being the latest patch release for each version) | (.x being the latest patch release for each version) | ||||||
|  |  | ||||||
| | Version   | Supported | | | Version | Supported | | ||||||
| | --------- | --------- | | | --- | --- | | ||||||
| | 2023.10.x | ✅        | | | 2023.6.x | ✅ | | ||||||
| | 2024.2.x  | ✅        | | | 2023.8.x | ✅ | | ||||||
|  |  | ||||||
| ## Reporting a Vulnerability | ## Reporting a Vulnerability | ||||||
|  |  | ||||||
| @ -31,12 +31,12 @@ To report a vulnerability, send an email to [security@goauthentik.io](mailto:se | |||||||
|  |  | ||||||
| authentik reserves the right to reclassify CVSS as necessary. To determine severity, we will use the CVSS calculator from NVD (https://nvd.nist.gov/vuln-metrics/cvss/v3-calculator). The calculated CVSS score will then be translated into one of the following categories: | authentik reserves the right to reclassify CVSS as necessary. To determine severity, we will use the CVSS calculator from NVD (https://nvd.nist.gov/vuln-metrics/cvss/v3-calculator). The calculated CVSS score will then be translated into one of the following categories: | ||||||
|  |  | ||||||
| | Score      | Severity | | | Score | Severity | | ||||||
| | ---------- | -------- | | | --- | --- | | ||||||
| | 0.0        | None     | | | 0.0 | None | | ||||||
| | 0.1 – 3.9  | Low      | | | 0.1 – 3.9 | Low | | ||||||
| | 4.0 – 6.9  | Medium   | | | 4.0 – 6.9 | Medium | | ||||||
| | 7.0 – 8.9  | High     | | | 7.0 – 8.9 | High | | ||||||
| | 9.0 – 10.0 | Critical | | | 9.0 – 10.0 | Critical | | ||||||
|  |  | ||||||
| ## Disclosure process | ## Disclosure process | ||||||
|  | |||||||
| @ -2,7 +2,7 @@ | |||||||
|  |  | ||||||
| from os import environ | from os import environ | ||||||
|  |  | ||||||
| __version__ = "2024.4.3" | __version__ = "2024.2.2" | ||||||
| ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -10,3 +10,26 @@ class AuthentikAPIConfig(AppConfig): | |||||||
|     label = "authentik_api" |     label = "authentik_api" | ||||||
|     mountpoint = "api/" |     mountpoint = "api/" | ||||||
|     verbose_name = "authentik API" |     verbose_name = "authentik API" | ||||||
|  |  | ||||||
|  |     def ready(self) -> None: | ||||||
|  |         from drf_spectacular.extensions import OpenApiAuthenticationExtension | ||||||
|  |  | ||||||
|  |         from authentik.api.authentication import TokenAuthentication | ||||||
|  |  | ||||||
|  |         # Class is defined here as it needs to be created early enough that drf-spectacular will | ||||||
|  |         # find it, but also won't cause any import issues | ||||||
|  |  | ||||||
|  |         class TokenSchema(OpenApiAuthenticationExtension): | ||||||
|  |             """Auth schema""" | ||||||
|  |  | ||||||
|  |             target_class = TokenAuthentication | ||||||
|  |             name = "authentik" | ||||||
|  |  | ||||||
|  |             def get_security_definition(self, auto_schema): | ||||||
|  |                 """Auth schema""" | ||||||
|  |                 return { | ||||||
|  |                     "type": "apiKey", | ||||||
|  |                     "in": "header", | ||||||
|  |                     "name": "Authorization", | ||||||
|  |                     "scheme": "bearer", | ||||||
|  |                 } | ||||||
|  | |||||||
| @ -4,7 +4,6 @@ from hmac import compare_digest | |||||||
| from typing import Any | from typing import Any | ||||||
|  |  | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
| from drf_spectacular.extensions import OpenApiAuthenticationExtension |  | ||||||
| from rest_framework.authentication import BaseAuthentication, get_authorization_header | from rest_framework.authentication import BaseAuthentication, get_authorization_header | ||||||
| from rest_framework.exceptions import AuthenticationFailed | from rest_framework.exceptions import AuthenticationFailed | ||||||
| from rest_framework.request import Request | from rest_framework.request import Request | ||||||
| @ -103,14 +102,3 @@ class TokenAuthentication(BaseAuthentication): | |||||||
|             return None |             return None | ||||||
|  |  | ||||||
|         return (user, None)  # pragma: no cover |         return (user, None)  # pragma: no cover | ||||||
|  |  | ||||||
|  |  | ||||||
| class TokenSchema(OpenApiAuthenticationExtension): |  | ||||||
|     """Auth schema""" |  | ||||||
|  |  | ||||||
|     target_class = TokenAuthentication |  | ||||||
|     name = "authentik" |  | ||||||
|  |  | ||||||
|     def get_security_definition(self, auto_schema): |  | ||||||
|         """Auth schema""" |  | ||||||
|         return {"type": "http", "scheme": "bearer"} |  | ||||||
|  | |||||||
| @ -12,7 +12,6 @@ from drf_spectacular.settings import spectacular_settings | |||||||
| from drf_spectacular.types import OpenApiTypes | from drf_spectacular.types import OpenApiTypes | ||||||
| from rest_framework.settings import api_settings | from rest_framework.settings import api_settings | ||||||
|  |  | ||||||
| from authentik.api.apps import AuthentikAPIConfig |  | ||||||
| from authentik.api.pagination import PAGINATION_COMPONENT_NAME, PAGINATION_SCHEMA | from authentik.api.pagination import PAGINATION_COMPONENT_NAME, PAGINATION_SCHEMA | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -102,12 +101,3 @@ def postprocess_schema_responses(result, generator: SchemaGenerator, **kwargs): | |||||||
|             comp = result["components"]["schemas"][component] |             comp = result["components"]["schemas"][component] | ||||||
|             comp["additionalProperties"] = {} |             comp["additionalProperties"] = {} | ||||||
|     return result |     return result | ||||||
|  |  | ||||||
|  |  | ||||||
| def preprocess_schema_exclude_non_api(endpoints, **kwargs): |  | ||||||
|     """Filter out all API Views which are not mounted under /api""" |  | ||||||
|     return [ |  | ||||||
|         (path, path_regex, method, callback) |  | ||||||
|         for path, path_regex, method, callback in endpoints |  | ||||||
|         if path.startswith("/" + AuthentikAPIConfig.mountpoint) |  | ||||||
|     ] |  | ||||||
|  | |||||||
| @ -8,8 +8,6 @@ from django.apps import AppConfig | |||||||
| from django.db import DatabaseError, InternalError, ProgrammingError | from django.db import DatabaseError, InternalError, ProgrammingError | ||||||
| from structlog.stdlib import BoundLogger, get_logger | from structlog.stdlib import BoundLogger, get_logger | ||||||
|  |  | ||||||
| from authentik.root.signals import startup |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ManagedAppConfig(AppConfig): | class ManagedAppConfig(AppConfig): | ||||||
|     """Basic reconciliation logic for apps""" |     """Basic reconciliation logic for apps""" | ||||||
| @ -25,12 +23,9 @@ class ManagedAppConfig(AppConfig): | |||||||
|  |  | ||||||
|     def ready(self) -> None: |     def ready(self) -> None: | ||||||
|         self.import_related() |         self.import_related() | ||||||
|         startup.connect(self._on_startup_callback, dispatch_uid=self.label) |  | ||||||
|         return super().ready() |  | ||||||
|  |  | ||||||
|     def _on_startup_callback(self, sender, **_): |  | ||||||
|         self._reconcile_global() |         self._reconcile_global() | ||||||
|         self._reconcile_tenant() |         self._reconcile_tenant() | ||||||
|  |         return super().ready() | ||||||
|  |  | ||||||
|     def import_related(self): |     def import_related(self): | ||||||
|         """Automatically import related modules which rely on just being imported |         """Automatically import related modules which rely on just being imported | ||||||
|  | |||||||
| @ -4,14 +4,12 @@ from json import dumps | |||||||
| from typing import Any | from typing import Any | ||||||
|  |  | ||||||
| from django.core.management.base import BaseCommand, no_translations | from django.core.management.base import BaseCommand, no_translations | ||||||
| from django.db.models import Model, fields | from django.db.models import Model | ||||||
| from drf_jsonschema_serializer.convert import converter, field_to_converter | from drf_jsonschema_serializer.convert import field_to_converter | ||||||
| from rest_framework.fields import Field, JSONField, UUIDField | from rest_framework.fields import Field, JSONField, UUIDField | ||||||
| from rest_framework.relations import PrimaryKeyRelatedField |  | ||||||
| from rest_framework.serializers import Serializer | from rest_framework.serializers import Serializer | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik import __version__ |  | ||||||
| from authentik.blueprints.v1.common import BlueprintEntryDesiredState | from authentik.blueprints.v1.common import BlueprintEntryDesiredState | ||||||
| from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT, is_model_allowed | from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT, is_model_allowed | ||||||
| from authentik.blueprints.v1.meta.registry import BaseMetaModel, registry | from authentik.blueprints.v1.meta.registry import BaseMetaModel, registry | ||||||
| @ -20,23 +18,6 @@ from authentik.lib.models import SerializerModel | |||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
| @converter |  | ||||||
| class PrimaryKeyRelatedFieldConverter: |  | ||||||
|     """Custom primary key field converter which is aware of non-integer based PKs |  | ||||||
|  |  | ||||||
|     This is not an exhaustive fix for other non-int PKs, however in authentik we either |  | ||||||
|     use UUIDs or ints""" |  | ||||||
|  |  | ||||||
|     field_class = PrimaryKeyRelatedField |  | ||||||
|  |  | ||||||
|     def convert(self, field: PrimaryKeyRelatedField): |  | ||||||
|         model: Model = field.queryset.model |  | ||||||
|         pk_field = model._meta.pk |  | ||||||
|         if isinstance(pk_field, fields.UUIDField): |  | ||||||
|             return {"type": "string", "format": "uuid"} |  | ||||||
|         return {"type": "integer"} |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Command(BaseCommand): | class Command(BaseCommand): | ||||||
|     """Generate JSON Schema for blueprints""" |     """Generate JSON Schema for blueprints""" | ||||||
|  |  | ||||||
| @ -48,7 +29,7 @@ class Command(BaseCommand): | |||||||
|             "$schema": "http://json-schema.org/draft-07/schema", |             "$schema": "http://json-schema.org/draft-07/schema", | ||||||
|             "$id": "https://goauthentik.io/blueprints/schema.json", |             "$id": "https://goauthentik.io/blueprints/schema.json", | ||||||
|             "type": "object", |             "type": "object", | ||||||
|             "title": f"authentik {__version__} Blueprint schema", |             "title": "authentik Blueprint schema", | ||||||
|             "required": ["version", "entries"], |             "required": ["version", "entries"], | ||||||
|             "properties": { |             "properties": { | ||||||
|                 "version": { |                 "version": { | ||||||
|  | |||||||
| @ -39,7 +39,7 @@ def reconcile_app(app_name: str): | |||||||
|         def wrapper(*args, **kwargs): |         def wrapper(*args, **kwargs): | ||||||
|             config = apps.get_app_config(app_name) |             config = apps.get_app_config(app_name) | ||||||
|             if isinstance(config, ManagedAppConfig): |             if isinstance(config, ManagedAppConfig): | ||||||
|                 config._on_startup_callback(None) |                 config.ready() | ||||||
|             return func(*args, **kwargs) |             return func(*args, **kwargs) | ||||||
|  |  | ||||||
|         return wrapper |         return wrapper | ||||||
|  | |||||||
| @ -556,11 +556,7 @@ class BlueprintDumper(SafeDumper): | |||||||
|  |  | ||||||
|             def factory(items): |             def factory(items): | ||||||
|                 final_dict = dict(items) |                 final_dict = dict(items) | ||||||
|                 # Remove internal state variables |  | ||||||
|                 final_dict.pop("_state", None) |                 final_dict.pop("_state", None) | ||||||
|                 # Future-proof to only remove the ID if we don't set a value |  | ||||||
|                 if "id" in final_dict and final_dict.get("id") is None: |  | ||||||
|                     final_dict.pop("id") |  | ||||||
|                 return final_dict |                 return final_dict | ||||||
|  |  | ||||||
|             data = asdict(data, dict_factory=factory) |             data = asdict(data, dict_factory=factory) | ||||||
|  | |||||||
| @ -19,6 +19,8 @@ from guardian.models import UserObjectPermission | |||||||
| from rest_framework.exceptions import ValidationError | from rest_framework.exceptions import ValidationError | ||||||
| from rest_framework.serializers import BaseSerializer, Serializer | from rest_framework.serializers import BaseSerializer, Serializer | ||||||
| from structlog.stdlib import BoundLogger, get_logger | from structlog.stdlib import BoundLogger, get_logger | ||||||
|  | from structlog.testing import capture_logs | ||||||
|  | from structlog.types import EventDict | ||||||
| from yaml import load | from yaml import load | ||||||
|  |  | ||||||
| from authentik.blueprints.v1.common import ( | from authentik.blueprints.v1.common import ( | ||||||
| @ -40,7 +42,6 @@ from authentik.core.models import ( | |||||||
| from authentik.enterprise.license import LicenseKey | from authentik.enterprise.license import LicenseKey | ||||||
| from authentik.enterprise.models import LicenseUsage | from authentik.enterprise.models import LicenseUsage | ||||||
| from authentik.enterprise.providers.rac.models import ConnectionToken | from authentik.enterprise.providers.rac.models import ConnectionToken | ||||||
| from authentik.events.logs import LogEvent, capture_logs |  | ||||||
| from authentik.events.models import SystemTask | from authentik.events.models import SystemTask | ||||||
| from authentik.events.utils import cleanse_dict | from authentik.events.utils import cleanse_dict | ||||||
| from authentik.flows.models import FlowToken, Stage | from authentik.flows.models import FlowToken, Stage | ||||||
| @ -51,8 +52,6 @@ from authentik.policies.models import Policy, PolicyBindingModel | |||||||
| from authentik.policies.reputation.models import Reputation | from authentik.policies.reputation.models import Reputation | ||||||
| from authentik.providers.oauth2.models import AccessToken, AuthorizationCode, RefreshToken | from authentik.providers.oauth2.models import AccessToken, AuthorizationCode, RefreshToken | ||||||
| from authentik.providers.scim.models import SCIMGroup, SCIMUser | from authentik.providers.scim.models import SCIMGroup, SCIMUser | ||||||
| from authentik.sources.scim.models import SCIMSourceGroup, SCIMSourceUser |  | ||||||
| from authentik.stages.authenticator_webauthn.models import WebAuthnDeviceType |  | ||||||
| from authentik.tenants.models import Tenant | from authentik.tenants.models import Tenant | ||||||
|  |  | ||||||
| # Context set when the serializer is created in a blueprint context | # Context set when the serializer is created in a blueprint context | ||||||
| @ -97,9 +96,6 @@ def excluded_models() -> list[type[Model]]: | |||||||
|         AccessToken, |         AccessToken, | ||||||
|         RefreshToken, |         RefreshToken, | ||||||
|         Reputation, |         Reputation, | ||||||
|         WebAuthnDeviceType, |  | ||||||
|         SCIMSourceUser, |  | ||||||
|         SCIMSourceGroup, |  | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -165,7 +161,7 @@ class Importer: | |||||||
|  |  | ||||||
|         def updater(value) -> Any: |         def updater(value) -> Any: | ||||||
|             if value in self.__pk_map: |             if value in self.__pk_map: | ||||||
|                 self.logger.debug("Updating reference in entry", value=value) |                 self.logger.debug("updating reference in entry", value=value) | ||||||
|                 return self.__pk_map[value] |                 return self.__pk_map[value] | ||||||
|             return value |             return value | ||||||
|  |  | ||||||
| @ -254,7 +250,7 @@ class Importer: | |||||||
|         model_instance = existing_models.first() |         model_instance = existing_models.first() | ||||||
|         if not isinstance(model(), BaseMetaModel) and model_instance: |         if not isinstance(model(), BaseMetaModel) and model_instance: | ||||||
|             self.logger.debug( |             self.logger.debug( | ||||||
|                 "Initialise serializer with instance", |                 "initialise serializer with instance", | ||||||
|                 model=model, |                 model=model, | ||||||
|                 instance=model_instance, |                 instance=model_instance, | ||||||
|                 pk=model_instance.pk, |                 pk=model_instance.pk, | ||||||
| @ -264,14 +260,14 @@ class Importer: | |||||||
|         elif model_instance and entry.state == BlueprintEntryDesiredState.MUST_CREATED: |         elif model_instance and entry.state == BlueprintEntryDesiredState.MUST_CREATED: | ||||||
|             raise EntryInvalidError.from_entry( |             raise EntryInvalidError.from_entry( | ||||||
|                 ( |                 ( | ||||||
|                     f"State is set to {BlueprintEntryDesiredState.MUST_CREATED} " |                     f"state is set to {BlueprintEntryDesiredState.MUST_CREATED} " | ||||||
|                     "and object exists already", |                     "and object exists already", | ||||||
|                 ), |                 ), | ||||||
|                 entry, |                 entry, | ||||||
|             ) |             ) | ||||||
|         else: |         else: | ||||||
|             self.logger.debug( |             self.logger.debug( | ||||||
|                 "Initialised new serializer instance", |                 "initialised new serializer instance", | ||||||
|                 model=model, |                 model=model, | ||||||
|                 **cleanse_dict(updated_identifiers), |                 **cleanse_dict(updated_identifiers), | ||||||
|             ) |             ) | ||||||
| @ -328,7 +324,7 @@ class Importer: | |||||||
|                 model: type[SerializerModel] = registry.get_model(model_app_label, model_name) |                 model: type[SerializerModel] = registry.get_model(model_app_label, model_name) | ||||||
|             except LookupError: |             except LookupError: | ||||||
|                 self.logger.warning( |                 self.logger.warning( | ||||||
|                     "App or Model does not exist", app=model_app_label, model=model_name |                     "app or model does not exist", app=model_app_label, model=model_name | ||||||
|                 ) |                 ) | ||||||
|                 return False |                 return False | ||||||
|             # Validate each single entry |             # Validate each single entry | ||||||
| @ -340,7 +336,7 @@ class Importer: | |||||||
|                 if entry.get_state(self._import) == BlueprintEntryDesiredState.ABSENT: |                 if entry.get_state(self._import) == BlueprintEntryDesiredState.ABSENT: | ||||||
|                     serializer = exc.serializer |                     serializer = exc.serializer | ||||||
|                 else: |                 else: | ||||||
|                     self.logger.warning(f"Entry invalid: {exc}", entry=entry, error=exc) |                     self.logger.warning(f"entry invalid: {exc}", entry=entry, error=exc) | ||||||
|                     if raise_errors: |                     if raise_errors: | ||||||
|                         raise exc |                         raise exc | ||||||
|                     return False |                     return False | ||||||
| @ -360,14 +356,14 @@ class Importer: | |||||||
|                     and state == BlueprintEntryDesiredState.CREATED |                     and state == BlueprintEntryDesiredState.CREATED | ||||||
|                 ): |                 ): | ||||||
|                     self.logger.debug( |                     self.logger.debug( | ||||||
|                         "Instance exists, skipping", |                         "instance exists, skipping", | ||||||
|                         model=model, |                         model=model, | ||||||
|                         instance=instance, |                         instance=instance, | ||||||
|                         pk=instance.pk, |                         pk=instance.pk, | ||||||
|                     ) |                     ) | ||||||
|                 else: |                 else: | ||||||
|                     instance = serializer.save() |                     instance = serializer.save() | ||||||
|                     self.logger.debug("Updated model", model=instance) |                     self.logger.debug("updated model", model=instance) | ||||||
|                 if "pk" in entry.identifiers: |                 if "pk" in entry.identifiers: | ||||||
|                     self.__pk_map[entry.identifiers["pk"]] = instance.pk |                     self.__pk_map[entry.identifiers["pk"]] = instance.pk | ||||||
|                 entry._state = BlueprintEntryState(instance) |                 entry._state = BlueprintEntryState(instance) | ||||||
| @ -375,12 +371,12 @@ class Importer: | |||||||
|                 instance: Model | None = serializer.instance |                 instance: Model | None = serializer.instance | ||||||
|                 if instance.pk: |                 if instance.pk: | ||||||
|                     instance.delete() |                     instance.delete() | ||||||
|                     self.logger.debug("Deleted model", mode=instance) |                     self.logger.debug("deleted model", mode=instance) | ||||||
|                     continue |                     continue | ||||||
|                 self.logger.debug("Entry to delete with no instance, skipping") |                 self.logger.debug("entry to delete with no instance, skipping") | ||||||
|         return True |         return True | ||||||
|  |  | ||||||
|     def validate(self, raise_validation_errors=False) -> tuple[bool, list[LogEvent]]: |     def validate(self, raise_validation_errors=False) -> tuple[bool, list[EventDict]]: | ||||||
|         """Validate loaded blueprint export, ensure all models are allowed |         """Validate loaded blueprint export, ensure all models are allowed | ||||||
|         and serializers have no errors""" |         and serializers have no errors""" | ||||||
|         self.logger.debug("Starting blueprint import validation") |         self.logger.debug("Starting blueprint import validation") | ||||||
| @ -394,7 +390,9 @@ class Importer: | |||||||
|         ): |         ): | ||||||
|             successful = self._apply_models(raise_errors=raise_validation_errors) |             successful = self._apply_models(raise_errors=raise_validation_errors) | ||||||
|             if not successful: |             if not successful: | ||||||
|                 self.logger.warning("Blueprint validation failed") |                 self.logger.debug("Blueprint validation failed") | ||||||
|  |         for log in logs: | ||||||
|  |             getattr(self.logger, log.get("log_level"))(**log) | ||||||
|         self.logger.debug("Finished blueprint import validation") |         self.logger.debug("Finished blueprint import validation") | ||||||
|         self._import = orig_import |         self._import = orig_import | ||||||
|         return successful, logs |         return successful, logs | ||||||
|  | |||||||
| @ -30,7 +30,6 @@ from authentik.blueprints.v1.common import BlueprintLoader, BlueprintMetadata, E | |||||||
| from authentik.blueprints.v1.importer import Importer | from authentik.blueprints.v1.importer import Importer | ||||||
| from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_INSTANTIATE | from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_INSTANTIATE | ||||||
| from authentik.blueprints.v1.oci import OCI_PREFIX | from authentik.blueprints.v1.oci import OCI_PREFIX | ||||||
| from authentik.events.logs import capture_logs |  | ||||||
| from authentik.events.models import TaskStatus | from authentik.events.models import TaskStatus | ||||||
| from authentik.events.system_tasks import SystemTask, prefill_task | from authentik.events.system_tasks import SystemTask, prefill_task | ||||||
| from authentik.events.utils import sanitize_dict | from authentik.events.utils import sanitize_dict | ||||||
| @ -212,15 +211,14 @@ def apply_blueprint(self: SystemTask, instance_pk: str): | |||||||
|         if not valid: |         if not valid: | ||||||
|             instance.status = BlueprintInstanceStatus.ERROR |             instance.status = BlueprintInstanceStatus.ERROR | ||||||
|             instance.save() |             instance.save() | ||||||
|             self.set_status(TaskStatus.ERROR, *logs) |             self.set_status(TaskStatus.ERROR, *[x["event"] for x in logs]) | ||||||
|  |             return | ||||||
|  |         applied = importer.apply() | ||||||
|  |         if not applied: | ||||||
|  |             instance.status = BlueprintInstanceStatus.ERROR | ||||||
|  |             instance.save() | ||||||
|  |             self.set_status(TaskStatus.ERROR, "Failed to apply") | ||||||
|             return |             return | ||||||
|         with capture_logs() as logs: |  | ||||||
|             applied = importer.apply() |  | ||||||
|             if not applied: |  | ||||||
|                 instance.status = BlueprintInstanceStatus.ERROR |  | ||||||
|                 instance.save() |  | ||||||
|                 self.set_status(TaskStatus.ERROR, *logs) |  | ||||||
|                 return |  | ||||||
|         instance.status = BlueprintInstanceStatus.SUCCESSFUL |         instance.status = BlueprintInstanceStatus.SUCCESSFUL | ||||||
|         instance.last_applied_hash = file_hash |         instance.last_applied_hash = file_hash | ||||||
|         instance.last_applied = now() |         instance.last_applied = now() | ||||||
|  | |||||||
| @ -46,6 +46,7 @@ class BrandSerializer(ModelSerializer): | |||||||
|         fields = [ |         fields = [ | ||||||
|             "brand_uuid", |             "brand_uuid", | ||||||
|             "domain", |             "domain", | ||||||
|  |             "origin", | ||||||
|             "default", |             "default", | ||||||
|             "branding_title", |             "branding_title", | ||||||
|             "branding_logo", |             "branding_logo", | ||||||
| @ -56,6 +57,7 @@ class BrandSerializer(ModelSerializer): | |||||||
|             "flow_unenrollment", |             "flow_unenrollment", | ||||||
|             "flow_user_settings", |             "flow_user_settings", | ||||||
|             "flow_device_code", |             "flow_device_code", | ||||||
|  |             "default_application", | ||||||
|             "web_certificate", |             "web_certificate", | ||||||
|             "attributes", |             "attributes", | ||||||
|         ] |         ] | ||||||
|  | |||||||
| @ -1,12 +1,17 @@ | |||||||
| """Inject brand into current request""" | """Inject brand into current request""" | ||||||
|  |  | ||||||
| from collections.abc import Callable | from collections.abc import Callable | ||||||
|  | from typing import TYPE_CHECKING | ||||||
|  |  | ||||||
| from django.http.request import HttpRequest | from django.http.request import HttpRequest | ||||||
| from django.http.response import HttpResponse | from django.http.response import HttpResponse | ||||||
| from django.utils.translation import activate | from django.utils.translation import activate | ||||||
|  |  | ||||||
| from authentik.brands.utils import get_brand_for_request | from authentik.brands.utils import get_brand_for_request | ||||||
|  | from authentik.lib.config import CONFIG | ||||||
|  |  | ||||||
|  | if TYPE_CHECKING: | ||||||
|  |     from authentik.brands.models import Brand | ||||||
|  |  | ||||||
|  |  | ||||||
| class BrandMiddleware: | class BrandMiddleware: | ||||||
| @ -25,3 +30,41 @@ class BrandMiddleware: | |||||||
|             if locale != "": |             if locale != "": | ||||||
|                 activate(locale) |                 activate(locale) | ||||||
|         return self.get_response(request) |         return self.get_response(request) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class BrandHeaderMiddleware: | ||||||
|  |     """Add headers from currently active brand""" | ||||||
|  |  | ||||||
|  |     get_response: Callable[[HttpRequest], HttpResponse] | ||||||
|  |     default_csp_elements: dict[str, list[str]] = {} | ||||||
|  |  | ||||||
|  |     def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]): | ||||||
|  |         self.get_response = get_response | ||||||
|  |         self.default_csp_elements = { | ||||||
|  |             "style-src": ["'self'", "'unsafe-inline'"],  # Required due to Lit/ShadowDOM | ||||||
|  |             "script-src": ["'self'", "'unsafe-inline'"],  # Required for generated scripts | ||||||
|  |             "img-src": ["https:", "http:", "data:"], | ||||||
|  |             "default-src": ["'self'"], | ||||||
|  |             "object-src": ["'none'"], | ||||||
|  |             "connect-src": ["'self'"], | ||||||
|  |         } | ||||||
|  |         if CONFIG.get_bool("error_reporting.enabled"): | ||||||
|  |             self.default_csp_elements["connect-src"].append( | ||||||
|  |                 # Required for sentry (TODO: Dynamic) | ||||||
|  |                 "https://authentik.error-reporting.a7k.io" | ||||||
|  |             ) | ||||||
|  |             if CONFIG.get_bool("debug"): | ||||||
|  |                 # Also allow spotlight sidecar connection | ||||||
|  |                 self.default_csp_elements["connect-src"].append("http://localhost:8969") | ||||||
|  |  | ||||||
|  |     def get_csp(self, request: HttpRequest) -> str: | ||||||
|  |         brand: "Brand" = request.brand | ||||||
|  |         elements = self.default_csp_elements.copy() | ||||||
|  |         if brand.origin != "": | ||||||
|  |             elements["frame-ancestors"] = [brand.origin] | ||||||
|  |         return ";".join(f"{attr} {" ".join(value)}" for attr, value in elements.items()) | ||||||
|  |  | ||||||
|  |     def __call__(self, request: HttpRequest) -> HttpResponse: | ||||||
|  |         response = self.get_response(request) | ||||||
|  |         response.headers["Content-Security-Policy"] = self.get_csp(request) | ||||||
|  |         return response | ||||||
|  | |||||||
| @ -1,21 +0,0 @@ | |||||||
| # Generated by Django 5.0.4 on 2024-04-18 18:56 |  | ||||||
|  |  | ||||||
| from django.db import migrations, models |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Migration(migrations.Migration): |  | ||||||
|  |  | ||||||
|     dependencies = [ |  | ||||||
|         ("authentik_brands", "0005_tenantuuid_to_branduuid"), |  | ||||||
|     ] |  | ||||||
|  |  | ||||||
|     operations = [ |  | ||||||
|         migrations.AddIndex( |  | ||||||
|             model_name="brand", |  | ||||||
|             index=models.Index(fields=["domain"], name="authentik_b_domain_b9b24a_idx"), |  | ||||||
|         ), |  | ||||||
|         migrations.AddIndex( |  | ||||||
|             model_name="brand", |  | ||||||
|             index=models.Index(fields=["default"], name="authentik_b_default_3ccf12_idx"), |  | ||||||
|         ), |  | ||||||
|     ] |  | ||||||
| @ -0,0 +1,26 @@ | |||||||
|  | # Generated by Django 5.0.3 on 2024-03-21 15:42 | ||||||
|  |  | ||||||
|  | import django.db.models.deletion | ||||||
|  | from django.db import migrations, models | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Migration(migrations.Migration): | ||||||
|  |  | ||||||
|  |     dependencies = [ | ||||||
|  |         ("authentik_brands", "0005_tenantuuid_to_branduuid"), | ||||||
|  |         ("authentik_core", "0033_alter_user_options"), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     operations = [ | ||||||
|  |         migrations.AddField( | ||||||
|  |             model_name="brand", | ||||||
|  |             name="default_application", | ||||||
|  |             field=models.ForeignKey( | ||||||
|  |                 default=None, | ||||||
|  |                 help_text="When set, external users will be redirected to this application after authenticating.", | ||||||
|  |                 null=True, | ||||||
|  |                 on_delete=django.db.models.deletion.SET_DEFAULT, | ||||||
|  |                 to="authentik_core.application", | ||||||
|  |             ), | ||||||
|  |         ), | ||||||
|  |     ] | ||||||
							
								
								
									
										21
									
								
								authentik/brands/migrations/0007_brand_origin.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								authentik/brands/migrations/0007_brand_origin.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,21 @@ | |||||||
|  | # Generated by Django 5.0.3 on 2024-03-26 14:17 | ||||||
|  |  | ||||||
|  | from django.db import migrations, models | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Migration(migrations.Migration): | ||||||
|  |  | ||||||
|  |     dependencies = [ | ||||||
|  |         ("authentik_brands", "0006_brand_default_application"), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     operations = [ | ||||||
|  |         migrations.AddField( | ||||||
|  |             model_name="brand", | ||||||
|  |             name="origin", | ||||||
|  |             field=models.TextField( | ||||||
|  |                 blank=True, | ||||||
|  |                 help_text="Origin domain that activates this brand. Can be left empty to not allow any origins.", | ||||||
|  |             ), | ||||||
|  |         ), | ||||||
|  |     ] | ||||||
| @ -23,6 +23,12 @@ class Brand(SerializerModel): | |||||||
|             "Domain that activates this brand. Can be a superset, i.e. `a.b` for `aa.b` and `ba.b`" |             "Domain that activates this brand. Can be a superset, i.e. `a.b` for `aa.b` and `ba.b`" | ||||||
|         ) |         ) | ||||||
|     ) |     ) | ||||||
|  |     origin = models.TextField( | ||||||
|  |         help_text=_( | ||||||
|  |             "Origin domain that activates this brand. Can be left empty to not allow any origins." | ||||||
|  |         ), | ||||||
|  |         blank=True, | ||||||
|  |     ) | ||||||
|     default = models.BooleanField( |     default = models.BooleanField( | ||||||
|         default=False, |         default=False, | ||||||
|     ) |     ) | ||||||
| @ -51,6 +57,16 @@ class Brand(SerializerModel): | |||||||
|         Flow, null=True, on_delete=models.SET_NULL, related_name="brand_device_code" |         Flow, null=True, on_delete=models.SET_NULL, related_name="brand_device_code" | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |     default_application = models.ForeignKey( | ||||||
|  |         "authentik_core.Application", | ||||||
|  |         null=True, | ||||||
|  |         default=None, | ||||||
|  |         on_delete=models.SET_DEFAULT, | ||||||
|  |         help_text=_( | ||||||
|  |             "When set, external users will be redirected to this application after authenticating." | ||||||
|  |         ), | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     web_certificate = models.ForeignKey( |     web_certificate = models.ForeignKey( | ||||||
|         CertificateKeyPair, |         CertificateKeyPair, | ||||||
|         null=True, |         null=True, | ||||||
| @ -84,7 +100,3 @@ class Brand(SerializerModel): | |||||||
|     class Meta: |     class Meta: | ||||||
|         verbose_name = _("Brand") |         verbose_name = _("Brand") | ||||||
|         verbose_name_plural = _("Brands") |         verbose_name_plural = _("Brands") | ||||||
|         indexes = [ |  | ||||||
|             models.Index(fields=["domain"]), |  | ||||||
|             models.Index(fields=["default"]), |  | ||||||
|         ] |  | ||||||
|  | |||||||
| @ -1,11 +1,15 @@ | |||||||
| """Brand utilities""" | """Brand utilities""" | ||||||
|  |  | ||||||
| from typing import Any | from typing import Any | ||||||
|  | from urllib.parse import urlparse | ||||||
|  |  | ||||||
| from django.db.models import F, Q | from django.db.models import F, Q | ||||||
| from django.db.models import Value as V | from django.db.models import Value as V | ||||||
|  | from django.http import HttpResponse | ||||||
| from django.http.request import HttpRequest | from django.http.request import HttpRequest | ||||||
|  | from django.utils.cache import patch_vary_headers | ||||||
| from sentry_sdk.hub import Hub | from sentry_sdk.hub import Hub | ||||||
|  | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik import get_full_version | from authentik import get_full_version | ||||||
| from authentik.brands.models import Brand | from authentik.brands.models import Brand | ||||||
| @ -13,13 +17,17 @@ from authentik.tenants.models import Tenant | |||||||
|  |  | ||||||
| _q_default = Q(default=True) | _q_default = Q(default=True) | ||||||
| DEFAULT_BRAND = Brand(domain="fallback") | DEFAULT_BRAND = Brand(domain="fallback") | ||||||
|  | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_brand_for_request(request: HttpRequest) -> Brand: | def get_brand_for_request(request: HttpRequest) -> Brand: | ||||||
|     """Get brand object for current request""" |     """Get brand object for current request""" | ||||||
|  |     query = Q(host_domain__iendswith=F("domain")) | ||||||
|  |     if "Origin" in request.headers: | ||||||
|  |         query &= Q(Q(origin=request.headers.get("Origin", "")) | Q(origin="")) | ||||||
|     db_brands = ( |     db_brands = ( | ||||||
|         Brand.objects.annotate(host_domain=V(request.get_host())) |         Brand.objects.annotate(host_domain=V(request.get_host())) | ||||||
|         .filter(Q(host_domain__iendswith=F("domain")) | _q_default) |         .filter(Q(query) | _q_default) | ||||||
|         .order_by("default") |         .order_by("default") | ||||||
|     ) |     ) | ||||||
|     brands = list(db_brands.all()) |     brands = list(db_brands.all()) | ||||||
| @ -42,3 +50,46 @@ def context_processor(request: HttpRequest) -> dict[str, Any]: | |||||||
|         "sentry_trace": trace, |         "sentry_trace": trace, | ||||||
|         "version": get_full_version(), |         "version": get_full_version(), | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def cors_allow(request: HttpRequest, response: HttpResponse, *allowed_origins: str): | ||||||
|  |     """Add headers to permit CORS requests from allowed_origins, with or without credentials, | ||||||
|  |     with any headers.""" | ||||||
|  |     origin = request.META.get("HTTP_ORIGIN") | ||||||
|  |     if not origin: | ||||||
|  |         return response | ||||||
|  |  | ||||||
|  |     # OPTIONS requests don't have an authorization header -> hence | ||||||
|  |     # we can't extract the provider this request is for | ||||||
|  |     # so for options requests we allow the calling origin without checking | ||||||
|  |     allowed = request.method == "OPTIONS" | ||||||
|  |     received_origin = urlparse(origin) | ||||||
|  |     for allowed_origin in allowed_origins: | ||||||
|  |         url = urlparse(allowed_origin) | ||||||
|  |         if ( | ||||||
|  |             received_origin.scheme == url.scheme | ||||||
|  |             and received_origin.hostname == url.hostname | ||||||
|  |             and received_origin.port == url.port | ||||||
|  |         ): | ||||||
|  |             allowed = True | ||||||
|  |     if not allowed: | ||||||
|  |         LOGGER.warning( | ||||||
|  |             "CORS: Origin is not an allowed origin", | ||||||
|  |             requested=received_origin, | ||||||
|  |             allowed=allowed_origins, | ||||||
|  |         ) | ||||||
|  |         return response | ||||||
|  |  | ||||||
|  |     # From the CORS spec: The string "*" cannot be used for a resource that supports credentials. | ||||||
|  |     response["Access-Control-Allow-Origin"] = origin | ||||||
|  |     patch_vary_headers(response, ["Origin"]) | ||||||
|  |     response["Access-Control-Allow-Credentials"] = "true" | ||||||
|  |  | ||||||
|  |     if request.method == "OPTIONS": | ||||||
|  |         if "HTTP_ACCESS_CONTROL_REQUEST_HEADERS" in request.META: | ||||||
|  |             response["Access-Control-Allow-Headers"] = request.META[ | ||||||
|  |                 "HTTP_ACCESS_CONTROL_REQUEST_HEADERS" | ||||||
|  |             ] | ||||||
|  |         response["Access-Control-Allow-Methods"] = "GET, POST, OPTIONS" | ||||||
|  |  | ||||||
|  |     return response | ||||||
|  | |||||||
| @ -20,15 +20,15 @@ from rest_framework.response import Response | |||||||
| from rest_framework.serializers import ModelSerializer | from rest_framework.serializers import ModelSerializer | ||||||
| from rest_framework.viewsets import ModelViewSet | from rest_framework.viewsets import ModelViewSet | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  | from structlog.testing import capture_logs | ||||||
|  |  | ||||||
| from authentik.admin.api.metrics import CoordinateSerializer | from authentik.admin.api.metrics import CoordinateSerializer | ||||||
| from authentik.api.pagination import Pagination |  | ||||||
| from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT | from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT | ||||||
| from authentik.core.api.providers import ProviderSerializer | from authentik.core.api.providers import ProviderSerializer | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.core.models import Application, User | from authentik.core.models import Application, User | ||||||
| from authentik.events.logs import LogEventSerializer, capture_logs |  | ||||||
| from authentik.events.models import EventAction | from authentik.events.models import EventAction | ||||||
|  | from authentik.events.utils import sanitize_dict | ||||||
| from authentik.lib.utils.file import ( | from authentik.lib.utils.file import ( | ||||||
|     FilePathSerializer, |     FilePathSerializer, | ||||||
|     FileUploadSerializer, |     FileUploadSerializer, | ||||||
| @ -44,12 +44,9 @@ from authentik.rbac.filters import ObjectFilter | |||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
| def user_app_cache_key(user_pk: str, page_number: int | None = None) -> str: | def user_app_cache_key(user_pk: str) -> str: | ||||||
|     """Cache key where application list for user is saved""" |     """Cache key where application list for user is saved""" | ||||||
|     key = f"{CACHE_PREFIX}/app_access/{user_pk}" |     return f"{CACHE_PREFIX}/app_access/{user_pk}" | ||||||
|     if page_number: |  | ||||||
|         key += f"/{page_number}" |  | ||||||
|     return key |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ApplicationSerializer(ModelSerializer): | class ApplicationSerializer(ModelSerializer): | ||||||
| @ -185,9 +182,9 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet): | |||||||
|         if request.user.is_superuser: |         if request.user.is_superuser: | ||||||
|             log_messages = [] |             log_messages = [] | ||||||
|             for log in logs: |             for log in logs: | ||||||
|                 if log.attributes.get("process", "") == "PolicyProcess": |                 if log.get("process", "") == "PolicyProcess": | ||||||
|                     continue |                     continue | ||||||
|                 log_messages.append(LogEventSerializer(log).data) |                 log_messages.append(sanitize_dict(log)) | ||||||
|             result.log_messages = log_messages |             result.log_messages = log_messages | ||||||
|             response = PolicyTestResultSerializer(result) |             response = PolicyTestResultSerializer(result) | ||||||
|         return Response(response.data) |         return Response(response.data) | ||||||
| @ -217,8 +214,7 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet): | |||||||
|             return super().list(request) |             return super().list(request) | ||||||
|  |  | ||||||
|         queryset = self._filter_queryset_for_list(self.get_queryset()) |         queryset = self._filter_queryset_for_list(self.get_queryset()) | ||||||
|         paginator: Pagination = self.paginator |         paginated_apps = self.paginate_queryset(queryset) | ||||||
|         paginated_apps = paginator.paginate_queryset(queryset, request) |  | ||||||
|  |  | ||||||
|         if "for_user" in request.query_params: |         if "for_user" in request.query_params: | ||||||
|             try: |             try: | ||||||
| @ -240,14 +236,12 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet): | |||||||
|         if not should_cache: |         if not should_cache: | ||||||
|             allowed_applications = self._get_allowed_applications(paginated_apps) |             allowed_applications = self._get_allowed_applications(paginated_apps) | ||||||
|         if should_cache: |         if should_cache: | ||||||
|             allowed_applications = cache.get( |             allowed_applications = cache.get(user_app_cache_key(self.request.user.pk)) | ||||||
|                 user_app_cache_key(self.request.user.pk, paginator.page.number) |  | ||||||
|             ) |  | ||||||
|             if not allowed_applications: |             if not allowed_applications: | ||||||
|                 LOGGER.debug("Caching allowed application list", page=paginator.page.number) |                 LOGGER.debug("Caching allowed application list") | ||||||
|                 allowed_applications = self._get_allowed_applications(paginated_apps) |                 allowed_applications = self._get_allowed_applications(paginated_apps) | ||||||
|                 cache.set( |                 cache.set( | ||||||
|                     user_app_cache_key(self.request.user.pk, paginator.page.number), |                     user_app_cache_key(self.request.user.pk), | ||||||
|                     allowed_applications, |                     allowed_applications, | ||||||
|                     timeout=86400, |                     timeout=86400, | ||||||
|                 ) |                 ) | ||||||
|  | |||||||
| @ -5,15 +5,10 @@ from json import loads | |||||||
| from django.http import Http404 | from django.http import Http404 | ||||||
| from django_filters.filters import CharFilter, ModelMultipleChoiceFilter | from django_filters.filters import CharFilter, ModelMultipleChoiceFilter | ||||||
| from django_filters.filterset import FilterSet | from django_filters.filterset import FilterSet | ||||||
| from drf_spectacular.utils import ( | from drf_spectacular.utils import OpenApiResponse, extend_schema | ||||||
|     OpenApiParameter, |  | ||||||
|     OpenApiResponse, |  | ||||||
|     extend_schema, |  | ||||||
|     extend_schema_field, |  | ||||||
| ) |  | ||||||
| from guardian.shortcuts import get_objects_for_user | from guardian.shortcuts import get_objects_for_user | ||||||
| from rest_framework.decorators import action | from rest_framework.decorators import action | ||||||
| from rest_framework.fields import CharField, IntegerField, SerializerMethodField | from rest_framework.fields import CharField, IntegerField | ||||||
| from rest_framework.request import Request | from rest_framework.request import Request | ||||||
| from rest_framework.response import Response | from rest_framework.response import Response | ||||||
| from rest_framework.serializers import ListSerializer, ModelSerializer, ValidationError | from rest_framework.serializers import ListSerializer, ModelSerializer, ValidationError | ||||||
| @ -50,7 +45,9 @@ class GroupSerializer(ModelSerializer): | |||||||
|     """Group Serializer""" |     """Group Serializer""" | ||||||
|  |  | ||||||
|     attributes = JSONDictField(required=False) |     attributes = JSONDictField(required=False) | ||||||
|     users_obj = SerializerMethodField(allow_null=True) |     users_obj = ListSerializer( | ||||||
|  |         child=GroupMemberSerializer(), read_only=True, source="users", required=False | ||||||
|  |     ) | ||||||
|     roles_obj = ListSerializer( |     roles_obj = ListSerializer( | ||||||
|         child=RoleSerializer(), |         child=RoleSerializer(), | ||||||
|         read_only=True, |         read_only=True, | ||||||
| @ -61,19 +58,6 @@ class GroupSerializer(ModelSerializer): | |||||||
|  |  | ||||||
|     num_pk = IntegerField(read_only=True) |     num_pk = IntegerField(read_only=True) | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def _should_include_users(self) -> bool: |  | ||||||
|         request: Request = self.context.get("request", None) |  | ||||||
|         if not request: |  | ||||||
|             return True |  | ||||||
|         return str(request.query_params.get("include_users", "true")).lower() == "true" |  | ||||||
|  |  | ||||||
|     @extend_schema_field(GroupMemberSerializer(many=True)) |  | ||||||
|     def get_users_obj(self, instance: Group) -> list[GroupMemberSerializer] | None: |  | ||||||
|         if not self._should_include_users: |  | ||||||
|             return None |  | ||||||
|         return GroupMemberSerializer(instance.users, many=True).data |  | ||||||
|  |  | ||||||
|     def validate_parent(self, parent: Group | None): |     def validate_parent(self, parent: Group | None): | ||||||
|         """Validate group parent (if set), ensuring the parent isn't itself""" |         """Validate group parent (if set), ensuring the parent isn't itself""" | ||||||
|         if not self.instance or not parent: |         if not self.instance or not parent: | ||||||
| @ -146,35 +130,22 @@ class GroupFilter(FilterSet): | |||||||
|         fields = ["name", "is_superuser", "members_by_pk", "attributes", "members_by_username"] |         fields = ["name", "is_superuser", "members_by_pk", "attributes", "members_by_username"] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class UserAccountSerializer(PassiveSerializer): | ||||||
|  |     """Account adding/removing operations""" | ||||||
|  |  | ||||||
|  |     pk = IntegerField(required=True) | ||||||
|  |  | ||||||
|  |  | ||||||
| class GroupViewSet(UsedByMixin, ModelViewSet): | class GroupViewSet(UsedByMixin, ModelViewSet): | ||||||
|     """Group Viewset""" |     """Group Viewset""" | ||||||
|  |  | ||||||
|     class UserAccountSerializer(PassiveSerializer): |     queryset = Group.objects.all().select_related("parent").prefetch_related("users") | ||||||
|         """Account adding/removing operations""" |  | ||||||
|  |  | ||||||
|         pk = IntegerField(required=True) |  | ||||||
|  |  | ||||||
|     queryset = Group.objects.none() |  | ||||||
|     serializer_class = GroupSerializer |     serializer_class = GroupSerializer | ||||||
|     search_fields = ["name", "is_superuser"] |     search_fields = ["name", "is_superuser"] | ||||||
|     filterset_class = GroupFilter |     filterset_class = GroupFilter | ||||||
|     ordering = ["name"] |     ordering = ["name"] | ||||||
|  |  | ||||||
|     def get_queryset(self): |     @permission_required(None, ["authentik_core.add_user"]) | ||||||
|         base_qs = Group.objects.all().select_related("parent").prefetch_related("roles") |  | ||||||
|         if self.serializer_class(context={"request": self.request})._should_include_users: |  | ||||||
|             base_qs = base_qs.prefetch_related("users") |  | ||||||
|         return base_qs |  | ||||||
|  |  | ||||||
|     @extend_schema( |  | ||||||
|         parameters=[ |  | ||||||
|             OpenApiParameter("include_users", bool, default=True), |  | ||||||
|         ] |  | ||||||
|     ) |  | ||||||
|     def list(self, request, *args, **kwargs): |  | ||||||
|         return super().list(request, *args, **kwargs) |  | ||||||
|  |  | ||||||
|     @permission_required("authentik_core.add_user_to_group") |  | ||||||
|     @extend_schema( |     @extend_schema( | ||||||
|         request=UserAccountSerializer, |         request=UserAccountSerializer, | ||||||
|         responses={ |         responses={ | ||||||
| @ -182,13 +153,7 @@ class GroupViewSet(UsedByMixin, ModelViewSet): | |||||||
|             404: OpenApiResponse(description="User not found"), |             404: OpenApiResponse(description="User not found"), | ||||||
|         }, |         }, | ||||||
|     ) |     ) | ||||||
|     @action( |     @action(detail=True, methods=["POST"], pagination_class=None, filter_backends=[]) | ||||||
|         detail=True, |  | ||||||
|         methods=["POST"], |  | ||||||
|         pagination_class=None, |  | ||||||
|         filter_backends=[], |  | ||||||
|         permission_classes=[], |  | ||||||
|     ) |  | ||||||
|     def add_user(self, request: Request, pk: str) -> Response: |     def add_user(self, request: Request, pk: str) -> Response: | ||||||
|         """Add user to group""" |         """Add user to group""" | ||||||
|         group: Group = self.get_object() |         group: Group = self.get_object() | ||||||
| @ -204,7 +169,7 @@ class GroupViewSet(UsedByMixin, ModelViewSet): | |||||||
|         group.users.add(user) |         group.users.add(user) | ||||||
|         return Response(status=204) |         return Response(status=204) | ||||||
|  |  | ||||||
|     @permission_required("authentik_core.remove_user_from_group") |     @permission_required(None, ["authentik_core.add_user"]) | ||||||
|     @extend_schema( |     @extend_schema( | ||||||
|         request=UserAccountSerializer, |         request=UserAccountSerializer, | ||||||
|         responses={ |         responses={ | ||||||
| @ -212,13 +177,7 @@ class GroupViewSet(UsedByMixin, ModelViewSet): | |||||||
|             404: OpenApiResponse(description="User not found"), |             404: OpenApiResponse(description="User not found"), | ||||||
|         }, |         }, | ||||||
|     ) |     ) | ||||||
|     @action( |     @action(detail=True, methods=["POST"], pagination_class=None, filter_backends=[]) | ||||||
|         detail=True, |  | ||||||
|         methods=["POST"], |  | ||||||
|         pagination_class=None, |  | ||||||
|         filter_backends=[], |  | ||||||
|         permission_classes=[], |  | ||||||
|     ) |  | ||||||
|     def remove_user(self, request: Request, pk: str) -> Response: |     def remove_user(self, request: Request, pk: str) -> Response: | ||||||
|         """Add user to group""" |         """Add user to group""" | ||||||
|         group: Group = self.get_object() |         group: Group = self.get_object() | ||||||
|  | |||||||
| @ -2,7 +2,6 @@ | |||||||
|  |  | ||||||
| from typing import Any | from typing import Any | ||||||
|  |  | ||||||
| from django.utils.timezone import now |  | ||||||
| from django_filters.rest_framework import DjangoFilterBackend | from django_filters.rest_framework import DjangoFilterBackend | ||||||
| from drf_spectacular.utils import OpenApiResponse, extend_schema, inline_serializer | from drf_spectacular.utils import OpenApiResponse, extend_schema, inline_serializer | ||||||
| from guardian.shortcuts import assign_perm, get_anonymous_user | from guardian.shortcuts import assign_perm, get_anonymous_user | ||||||
| @ -21,17 +20,9 @@ from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT | |||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.core.api.users import UserSerializer | from authentik.core.api.users import UserSerializer | ||||||
| from authentik.core.api.utils import PassiveSerializer | from authentik.core.api.utils import PassiveSerializer | ||||||
| from authentik.core.models import ( | from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents | ||||||
|     USER_ATTRIBUTE_TOKEN_EXPIRING, |  | ||||||
|     USER_ATTRIBUTE_TOKEN_MAXIMUM_LIFETIME, |  | ||||||
|     Token, |  | ||||||
|     TokenIntents, |  | ||||||
|     User, |  | ||||||
|     default_token_duration, |  | ||||||
| ) |  | ||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
| from authentik.events.utils import model_to_dict | from authentik.events.utils import model_to_dict | ||||||
| from authentik.lib.utils.time import timedelta_from_string |  | ||||||
| from authentik.rbac.decorators import permission_required | from authentik.rbac.decorators import permission_required | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -45,13 +36,6 @@ class TokenSerializer(ManagedSerializer, ModelSerializer): | |||||||
|         if SERIALIZER_CONTEXT_BLUEPRINT in self.context: |         if SERIALIZER_CONTEXT_BLUEPRINT in self.context: | ||||||
|             self.fields["key"] = CharField(required=False) |             self.fields["key"] = CharField(required=False) | ||||||
|  |  | ||||||
|     def validate_user(self, user: User): |  | ||||||
|         """Ensure user of token cannot be changed""" |  | ||||||
|         if self.instance and self.instance.user_id: |  | ||||||
|             if user.pk != self.instance.user_id: |  | ||||||
|                 raise ValidationError("User cannot be changed") |  | ||||||
|         return user |  | ||||||
|  |  | ||||||
|     def validate(self, attrs: dict[Any, str]) -> dict[Any, str]: |     def validate(self, attrs: dict[Any, str]) -> dict[Any, str]: | ||||||
|         """Ensure only API or App password tokens are created.""" |         """Ensure only API or App password tokens are created.""" | ||||||
|         request: Request = self.context.get("request") |         request: Request = self.context.get("request") | ||||||
| @ -65,32 +49,6 @@ class TokenSerializer(ManagedSerializer, ModelSerializer): | |||||||
|         attrs.setdefault("intent", TokenIntents.INTENT_API) |         attrs.setdefault("intent", TokenIntents.INTENT_API) | ||||||
|         if attrs.get("intent") not in [TokenIntents.INTENT_API, TokenIntents.INTENT_APP_PASSWORD]: |         if attrs.get("intent") not in [TokenIntents.INTENT_API, TokenIntents.INTENT_APP_PASSWORD]: | ||||||
|             raise ValidationError({"intent": f"Invalid intent {attrs.get('intent')}"}) |             raise ValidationError({"intent": f"Invalid intent {attrs.get('intent')}"}) | ||||||
|  |  | ||||||
|         if attrs.get("intent") == TokenIntents.INTENT_APP_PASSWORD: |  | ||||||
|             # user IS in attrs |  | ||||||
|             user: User = attrs.get("user") |  | ||||||
|             max_token_lifetime = user.group_attributes(request).get( |  | ||||||
|                 USER_ATTRIBUTE_TOKEN_MAXIMUM_LIFETIME, |  | ||||||
|             ) |  | ||||||
|             max_token_lifetime_dt = default_token_duration() |  | ||||||
|             if max_token_lifetime is not None: |  | ||||||
|                 try: |  | ||||||
|                     max_token_lifetime_dt = now() + timedelta_from_string(max_token_lifetime) |  | ||||||
|                 except ValueError: |  | ||||||
|                     pass |  | ||||||
|  |  | ||||||
|             if "expires" in attrs and attrs.get("expires") > max_token_lifetime_dt: |  | ||||||
|                 raise ValidationError( |  | ||||||
|                     { |  | ||||||
|                         "expires": ( |  | ||||||
|                             f"Token expires exceeds maximum lifetime ({max_token_lifetime_dt} UTC)." |  | ||||||
|                         ) |  | ||||||
|                     } |  | ||||||
|                 ) |  | ||||||
|         elif attrs.get("intent") == TokenIntents.INTENT_API: |  | ||||||
|             # For API tokens, expires cannot be overridden |  | ||||||
|             attrs["expires"] = default_token_duration() |  | ||||||
|  |  | ||||||
|         return attrs |         return attrs | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|  | |||||||
| @ -85,7 +85,7 @@ class UserGroupSerializer(ModelSerializer): | |||||||
|     """Simplified Group Serializer for user's groups""" |     """Simplified Group Serializer for user's groups""" | ||||||
|  |  | ||||||
|     attributes = JSONDictField(required=False) |     attributes = JSONDictField(required=False) | ||||||
|     parent_name = CharField(source="parent.name", read_only=True, allow_null=True) |     parent_name = CharField(source="parent.name", read_only=True) | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         model = Group |         model = Group | ||||||
| @ -113,26 +113,13 @@ class UserSerializer(ModelSerializer): | |||||||
|         queryset=Group.objects.all().order_by("name"), |         queryset=Group.objects.all().order_by("name"), | ||||||
|         default=list, |         default=list, | ||||||
|     ) |     ) | ||||||
|     groups_obj = SerializerMethodField(allow_null=True) |     groups_obj = ListSerializer(child=UserGroupSerializer(), read_only=True, source="ak_groups") | ||||||
|     uid = CharField(read_only=True) |     uid = CharField(read_only=True) | ||||||
|     username = CharField( |     username = CharField( | ||||||
|         max_length=150, |         max_length=150, | ||||||
|         validators=[UniqueValidator(queryset=User.objects.all().order_by("username"))], |         validators=[UniqueValidator(queryset=User.objects.all().order_by("username"))], | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def _should_include_groups(self) -> bool: |  | ||||||
|         request: Request = self.context.get("request", None) |  | ||||||
|         if not request: |  | ||||||
|             return True |  | ||||||
|         return str(request.query_params.get("include_groups", "true")).lower() == "true" |  | ||||||
|  |  | ||||||
|     @extend_schema_field(UserGroupSerializer(many=True)) |  | ||||||
|     def get_groups_obj(self, instance: User) -> list[UserGroupSerializer] | None: |  | ||||||
|         if not self._should_include_groups: |  | ||||||
|             return None |  | ||||||
|         return UserGroupSerializer(instance.ak_groups, many=True).data |  | ||||||
|  |  | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|         super().__init__(*args, **kwargs) |         super().__init__(*args, **kwargs) | ||||||
|         if SERIALIZER_CONTEXT_BLUEPRINT in self.context: |         if SERIALIZER_CONTEXT_BLUEPRINT in self.context: | ||||||
| @ -407,19 +394,8 @@ class UserViewSet(UsedByMixin, ModelViewSet): | |||||||
|     search_fields = ["username", "name", "is_active", "email", "uuid"] |     search_fields = ["username", "name", "is_active", "email", "uuid"] | ||||||
|     filterset_class = UsersFilter |     filterset_class = UsersFilter | ||||||
|  |  | ||||||
|     def get_queryset(self): |     def get_queryset(self):  # pragma: no cover | ||||||
|         base_qs = User.objects.all().exclude_anonymous() |         return User.objects.all().exclude_anonymous().prefetch_related("ak_groups") | ||||||
|         if self.serializer_class(context={"request": self.request})._should_include_groups: |  | ||||||
|             base_qs = base_qs.prefetch_related("ak_groups") |  | ||||||
|         return base_qs |  | ||||||
|  |  | ||||||
|     @extend_schema( |  | ||||||
|         parameters=[ |  | ||||||
|             OpenApiParameter("include_groups", bool, default=True), |  | ||||||
|         ] |  | ||||||
|     ) |  | ||||||
|     def list(self, request, *args, **kwargs): |  | ||||||
|         return super().list(request, *args, **kwargs) |  | ||||||
|  |  | ||||||
|     def _create_recovery_link(self) -> tuple[str, Token]: |     def _create_recovery_link(self) -> tuple[str, Token]: | ||||||
|         """Create a recovery link (when the current brand has a recovery flow set), |         """Create a recovery link (when the current brand has a recovery flow set), | ||||||
|  | |||||||
| @ -1,34 +1,10 @@ | |||||||
| """custom runserver command""" | """custom runserver command""" | ||||||
|  |  | ||||||
| from typing import TextIO |  | ||||||
|  |  | ||||||
| from daphne.management.commands.runserver import Command as RunServer | from daphne.management.commands.runserver import Command as RunServer | ||||||
| from daphne.server import Server |  | ||||||
|  |  | ||||||
| from authentik.root.signals import post_startup, pre_startup, startup |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class SignalServer(Server): |  | ||||||
|     """Server which signals back to authentik when it finished starting up""" |  | ||||||
|  |  | ||||||
|     def __init__(self, *args, **kwargs): |  | ||||||
|         super().__init__(*args, **kwargs) |  | ||||||
|  |  | ||||||
|         def ready_callable(): |  | ||||||
|             pre_startup.send(sender=self) |  | ||||||
|             startup.send(sender=self) |  | ||||||
|             post_startup.send(sender=self) |  | ||||||
|  |  | ||||||
|         self.ready_callable = ready_callable |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Command(RunServer): | class Command(RunServer): | ||||||
|     """custom runserver command, which doesn't show the misleading django startup message""" |     """custom runserver command, which doesn't show the misleading django startup message""" | ||||||
|  |  | ||||||
|     server_cls = SignalServer |     def on_bind(self, server_port): | ||||||
|  |         pass | ||||||
|     def __init__(self, *args, **kwargs): |  | ||||||
|         super().__init__(*args, **kwargs) |  | ||||||
|         # Redirect standard stdout banner from Daphne into the void |  | ||||||
|         # as there are a couple more steps that happen before startup is fully done |  | ||||||
|         self.stdout = TextIO() |  | ||||||
|  | |||||||
| @ -5,7 +5,6 @@ from django.db import migrations, models | |||||||
| from django.db.backends.base.schema import BaseDatabaseSchemaEditor | from django.db.backends.base.schema import BaseDatabaseSchemaEditor | ||||||
|  |  | ||||||
| import authentik.core.models | import authentik.core.models | ||||||
| from authentik.lib.generators import generate_id |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def set_default_token_key(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | def set_default_token_key(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||||
| @ -17,10 +16,6 @@ def set_default_token_key(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | |||||||
|         token.save() |         token.save() | ||||||
|  |  | ||||||
|  |  | ||||||
| def default_token_key(): |  | ||||||
|     return generate_id(60) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Migration(migrations.Migration): | class Migration(migrations.Migration): | ||||||
|     replaces = [ |     replaces = [ | ||||||
|         ("authentik_core", "0012_auto_20201003_1737"), |         ("authentik_core", "0012_auto_20201003_1737"), | ||||||
| @ -67,7 +62,7 @@ class Migration(migrations.Migration): | |||||||
|         migrations.AddField( |         migrations.AddField( | ||||||
|             model_name="token", |             model_name="token", | ||||||
|             name="key", |             name="key", | ||||||
|             field=models.TextField(default=default_token_key), |             field=models.TextField(default=authentik.core.models.default_token_key), | ||||||
|         ), |         ), | ||||||
|         migrations.AlterUniqueTogether( |         migrations.AlterUniqueTogether( | ||||||
|             name="token", |             name="token", | ||||||
|  | |||||||
| @ -1,31 +0,0 @@ | |||||||
| # Generated by Django 5.0.2 on 2024-02-29 10:15 |  | ||||||
|  |  | ||||||
| from django.db import migrations, models |  | ||||||
|  |  | ||||||
| import authentik.core.models |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Migration(migrations.Migration): |  | ||||||
|  |  | ||||||
|     dependencies = [ |  | ||||||
|         ("authentik_core", "0033_alter_user_options"), |  | ||||||
|         ("authentik_tenants", "0002_tenant_default_token_duration_and_more"), |  | ||||||
|     ] |  | ||||||
|  |  | ||||||
|     operations = [ |  | ||||||
|         migrations.AlterField( |  | ||||||
|             model_name="authenticatedsession", |  | ||||||
|             name="expires", |  | ||||||
|             field=models.DateTimeField(default=None, null=True), |  | ||||||
|         ), |  | ||||||
|         migrations.AlterField( |  | ||||||
|             model_name="token", |  | ||||||
|             name="expires", |  | ||||||
|             field=models.DateTimeField(default=None, null=True), |  | ||||||
|         ), |  | ||||||
|         migrations.AlterField( |  | ||||||
|             model_name="token", |  | ||||||
|             name="key", |  | ||||||
|             field=models.TextField(default=authentik.core.models.default_token_key), |  | ||||||
|         ), |  | ||||||
|     ] |  | ||||||
| @ -1,52 +0,0 @@ | |||||||
| # Generated by Django 5.0.4 on 2024-04-15 11:28 |  | ||||||
|  |  | ||||||
| from django.db import migrations, models |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Migration(migrations.Migration): |  | ||||||
|  |  | ||||||
|     dependencies = [ |  | ||||||
|         ("auth", "0012_alter_user_first_name_max_length"), |  | ||||||
|         ("authentik_core", "0034_alter_authenticatedsession_expires_and_more"), |  | ||||||
|         ("authentik_rbac", "0003_alter_systempermission_options"), |  | ||||||
|     ] |  | ||||||
|  |  | ||||||
|     operations = [ |  | ||||||
|         migrations.AlterModelOptions( |  | ||||||
|             name="group", |  | ||||||
|             options={ |  | ||||||
|                 "permissions": [ |  | ||||||
|                     ("add_user_to_group", "Add user to group"), |  | ||||||
|                     ("remove_user_from_group", "Remove user from group"), |  | ||||||
|                 ], |  | ||||||
|                 "verbose_name": "Group", |  | ||||||
|                 "verbose_name_plural": "Groups", |  | ||||||
|             }, |  | ||||||
|         ), |  | ||||||
|         migrations.AddIndex( |  | ||||||
|             model_name="group", |  | ||||||
|             index=models.Index(fields=["name"], name="authentik_c_name_9ba8e4_idx"), |  | ||||||
|         ), |  | ||||||
|         migrations.AddIndex( |  | ||||||
|             model_name="user", |  | ||||||
|             index=models.Index(fields=["last_login"], name="authentik_c_last_lo_f0179a_idx"), |  | ||||||
|         ), |  | ||||||
|         migrations.AddIndex( |  | ||||||
|             model_name="user", |  | ||||||
|             index=models.Index( |  | ||||||
|                 fields=["password_change_date"], name="authentik_c_passwor_eec915_idx" |  | ||||||
|             ), |  | ||||||
|         ), |  | ||||||
|         migrations.AddIndex( |  | ||||||
|             model_name="user", |  | ||||||
|             index=models.Index(fields=["uuid"], name="authentik_c_uuid_3dae2f_idx"), |  | ||||||
|         ), |  | ||||||
|         migrations.AddIndex( |  | ||||||
|             model_name="user", |  | ||||||
|             index=models.Index(fields=["path"], name="authentik_c_path_b1f502_idx"), |  | ||||||
|         ), |  | ||||||
|         migrations.AddIndex( |  | ||||||
|             model_name="user", |  | ||||||
|             index=models.Index(fields=["type"], name="authentik_c_type_ecf60d_idx"), |  | ||||||
|         ), |  | ||||||
|     ] |  | ||||||
| @ -1,6 +1,6 @@ | |||||||
| """authentik core models""" | """authentik core models""" | ||||||
|  |  | ||||||
| from datetime import datetime | from datetime import timedelta | ||||||
| from hashlib import sha256 | from hashlib import sha256 | ||||||
| from typing import Any, Optional, Self | from typing import Any, Optional, Self | ||||||
| from uuid import uuid4 | from uuid import uuid4 | ||||||
| @ -25,16 +25,15 @@ from authentik.blueprints.models import ManagedModel | |||||||
| from authentik.core.exceptions import PropertyMappingExpressionException | from authentik.core.exceptions import PropertyMappingExpressionException | ||||||
| from authentik.core.types import UILoginButton, UserSettingSerializer | from authentik.core.types import UILoginButton, UserSettingSerializer | ||||||
| from authentik.lib.avatars import get_avatar | from authentik.lib.avatars import get_avatar | ||||||
|  | from authentik.lib.config import CONFIG | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.lib.models import ( | from authentik.lib.models import ( | ||||||
|     CreatedUpdatedModel, |     CreatedUpdatedModel, | ||||||
|     DomainlessFormattedURLValidator, |     DomainlessFormattedURLValidator, | ||||||
|     SerializerModel, |     SerializerModel, | ||||||
| ) | ) | ||||||
| from authentik.lib.utils.time import timedelta_from_string |  | ||||||
| from authentik.policies.models import PolicyBindingModel | from authentik.policies.models import PolicyBindingModel | ||||||
| from authentik.tenants.models import DEFAULT_TOKEN_DURATION, DEFAULT_TOKEN_LENGTH | from authentik.tenants.utils import get_unique_identifier | ||||||
| from authentik.tenants.utils import get_current_tenant, get_unique_identifier |  | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| USER_ATTRIBUTE_DEBUG = "goauthentik.io/user/debug" | USER_ATTRIBUTE_DEBUG = "goauthentik.io/user/debug" | ||||||
| @ -43,42 +42,33 @@ USER_ATTRIBUTE_EXPIRES = "goauthentik.io/user/expires" | |||||||
| USER_ATTRIBUTE_DELETE_ON_LOGOUT = "goauthentik.io/user/delete-on-logout" | USER_ATTRIBUTE_DELETE_ON_LOGOUT = "goauthentik.io/user/delete-on-logout" | ||||||
| USER_ATTRIBUTE_SOURCES = "goauthentik.io/user/sources" | USER_ATTRIBUTE_SOURCES = "goauthentik.io/user/sources" | ||||||
| USER_ATTRIBUTE_TOKEN_EXPIRING = "goauthentik.io/user/token-expires"  # nosec | USER_ATTRIBUTE_TOKEN_EXPIRING = "goauthentik.io/user/token-expires"  # nosec | ||||||
| USER_ATTRIBUTE_TOKEN_MAXIMUM_LIFETIME = "goauthentik.io/user/token-maximum-lifetime"  # nosec |  | ||||||
| USER_ATTRIBUTE_CHANGE_USERNAME = "goauthentik.io/user/can-change-username" | USER_ATTRIBUTE_CHANGE_USERNAME = "goauthentik.io/user/can-change-username" | ||||||
| USER_ATTRIBUTE_CHANGE_NAME = "goauthentik.io/user/can-change-name" | USER_ATTRIBUTE_CHANGE_NAME = "goauthentik.io/user/can-change-name" | ||||||
| USER_ATTRIBUTE_CHANGE_EMAIL = "goauthentik.io/user/can-change-email" | USER_ATTRIBUTE_CHANGE_EMAIL = "goauthentik.io/user/can-change-email" | ||||||
| USER_PATH_SYSTEM_PREFIX = "goauthentik.io" | USER_PATH_SYSTEM_PREFIX = "goauthentik.io" | ||||||
| USER_PATH_SERVICE_ACCOUNT = USER_PATH_SYSTEM_PREFIX + "/service-accounts" | USER_PATH_SERVICE_ACCOUNT = USER_PATH_SYSTEM_PREFIX + "/service-accounts" | ||||||
|  |  | ||||||
|  |  | ||||||
| options.DEFAULT_NAMES = options.DEFAULT_NAMES + ( | options.DEFAULT_NAMES = options.DEFAULT_NAMES + ( | ||||||
|     # used_by API that allows models to specify if they shadow an object |     # used_by API that allows models to specify if they shadow an object | ||||||
|     # for example the proxy provider which is built on top of an oauth provider |     # for example the proxy provider which is built on top of an oauth provider | ||||||
|     "authentik_used_by_shadows", |     "authentik_used_by_shadows", | ||||||
|  |     # List fields for which changes are not logged (due to them having dedicated objects) | ||||||
|  |     # for example user's password and last_login | ||||||
|  |     "authentik_signals_ignored_fields", | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  |  | ||||||
| def default_token_duration() -> datetime: | def default_token_duration(): | ||||||
|     """Default duration a Token is valid""" |     """Default duration a Token is valid""" | ||||||
|     current_tenant = get_current_tenant() |     return now() + timedelta(minutes=30) | ||||||
|     token_duration = ( |  | ||||||
|         current_tenant.default_token_duration |  | ||||||
|         if hasattr(current_tenant, "default_token_duration") |  | ||||||
|         else DEFAULT_TOKEN_DURATION |  | ||||||
|     ) |  | ||||||
|     return now() + timedelta_from_string(token_duration) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def default_token_key() -> str: | def default_token_key(): | ||||||
|     """Default token key""" |     """Default token key""" | ||||||
|     current_tenant = get_current_tenant() |  | ||||||
|     token_length = ( |  | ||||||
|         current_tenant.default_token_length |  | ||||||
|         if hasattr(current_tenant, "default_token_length") |  | ||||||
|         else DEFAULT_TOKEN_LENGTH |  | ||||||
|     ) |  | ||||||
|     # We use generate_id since the chars in the key should be easy |     # We use generate_id since the chars in the key should be easy | ||||||
|     # to use in Emails (for verification) and URLs (for recovery) |     # to use in Emails (for verification) and URLs (for recovery) | ||||||
|     return generate_id(token_length) |     return generate_id(CONFIG.get_int("default_token_length")) | ||||||
|  |  | ||||||
|  |  | ||||||
| class UserTypes(models.TextChoices): | class UserTypes(models.TextChoices): | ||||||
| @ -177,13 +167,8 @@ class Group(SerializerModel): | |||||||
|                 "parent", |                 "parent", | ||||||
|             ), |             ), | ||||||
|         ) |         ) | ||||||
|         indexes = [models.Index(fields=["name"])] |  | ||||||
|         verbose_name = _("Group") |         verbose_name = _("Group") | ||||||
|         verbose_name_plural = _("Groups") |         verbose_name_plural = _("Groups") | ||||||
|         permissions = [ |  | ||||||
|             ("add_user_to_group", _("Add user to group")), |  | ||||||
|             ("remove_user_from_group", _("Remove user from group")), |  | ||||||
|         ] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class UserQuerySet(models.QuerySet): | class UserQuerySet(models.QuerySet): | ||||||
| @ -320,12 +305,13 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser): | |||||||
|             ("preview_user", _("Can preview user data sent to providers")), |             ("preview_user", _("Can preview user data sent to providers")), | ||||||
|             ("view_user_applications", _("View applications the user has access to")), |             ("view_user_applications", _("View applications the user has access to")), | ||||||
|         ] |         ] | ||||||
|         indexes = [ |         authentik_signals_ignored_fields = [ | ||||||
|             models.Index(fields=["last_login"]), |             # Logged by the events `password_set` | ||||||
|             models.Index(fields=["password_change_date"]), |             # the `password_set` action/signal doesn't currently convey which user | ||||||
|             models.Index(fields=["uuid"]), |             # initiated the password change, so for now we'll log two actions | ||||||
|             models.Index(fields=["path"]), |             # ("password", "password_change_date"), | ||||||
|             models.Index(fields=["type"]), |             # Logged by `login` | ||||||
|  |             ("last_login",), | ||||||
|         ] |         ] | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -632,7 +618,7 @@ class UserSourceConnection(SerializerModel, CreatedUpdatedModel): | |||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     def __str__(self) -> str: |     def __str__(self) -> str: | ||||||
|         return f"User-source connection (user={self.user_id}, source={self.source_id})" |         return f"User-source connection (user={self.user.username}, source={self.source.slug})" | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         unique_together = (("user", "source"),) |         unique_together = (("user", "source"),) | ||||||
| @ -641,7 +627,7 @@ class UserSourceConnection(SerializerModel, CreatedUpdatedModel): | |||||||
| class ExpiringModel(models.Model): | class ExpiringModel(models.Model): | ||||||
|     """Base Model which can expire, and is automatically cleaned up.""" |     """Base Model which can expire, and is automatically cleaned up.""" | ||||||
|  |  | ||||||
|     expires = models.DateTimeField(default=None, null=True) |     expires = models.DateTimeField(default=default_token_duration) | ||||||
|     expiring = models.BooleanField(default=True) |     expiring = models.BooleanField(default=True) | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
| @ -655,7 +641,7 @@ class ExpiringModel(models.Model): | |||||||
|         return self.delete(*args, **kwargs) |         return self.delete(*args, **kwargs) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def filter_not_expired(cls, **kwargs) -> QuerySet["Token"]: |     def filter_not_expired(cls, **kwargs) -> QuerySet: | ||||||
|         """Filer for tokens which are not expired yet or are not expiring, |         """Filer for tokens which are not expired yet or are not expiring, | ||||||
|         and match filters in `kwargs`""" |         and match filters in `kwargs`""" | ||||||
|         for obj in cls.objects.filter(**kwargs).filter(Q(expires__lt=now(), expiring=True)): |         for obj in cls.objects.filter(**kwargs).filter(Q(expires__lt=now(), expiring=True)): | ||||||
|  | |||||||
| @ -10,14 +10,7 @@ from django.dispatch import receiver | |||||||
| from django.http.request import HttpRequest | from django.http.request import HttpRequest | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.core.models import ( | from authentik.core.models import Application, AuthenticatedSession, BackchannelProvider, User | ||||||
|     Application, |  | ||||||
|     AuthenticatedSession, |  | ||||||
|     BackchannelProvider, |  | ||||||
|     ExpiringModel, |  | ||||||
|     User, |  | ||||||
|     default_token_duration, |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| # Arguments: user: User, password: str | # Arguments: user: User, password: str | ||||||
| password_changed = Signal() | password_changed = Signal() | ||||||
| @ -68,12 +61,3 @@ def backchannel_provider_pre_save(sender: type[Model], instance: Model, **_): | |||||||
|     if not isinstance(instance, BackchannelProvider): |     if not isinstance(instance, BackchannelProvider): | ||||||
|         return |         return | ||||||
|     instance.is_backchannel = True |     instance.is_backchannel = True | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(pre_save) |  | ||||||
| def expiring_model_pre_save(sender: type[Model], instance: Model, **_): |  | ||||||
|     """Ensure expires is set on ExpiringModels that are set to expire""" |  | ||||||
|     if not issubclass(sender, ExpiringModel): |  | ||||||
|         return |  | ||||||
|     if instance.expiring and instance.expires is None: |  | ||||||
|         instance.expires = default_token_duration() |  | ||||||
|  | |||||||
| @ -13,7 +13,7 @@ from django.utils.translation import gettext as _ | |||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.core.models import Source, SourceUserMatchingModes, User, UserSourceConnection | from authentik.core.models import Source, SourceUserMatchingModes, User, UserSourceConnection | ||||||
| from authentik.core.sources.stage import PLAN_CONTEXT_SOURCES_CONNECTION, PostSourceStage | from authentik.core.sources.stage import PLAN_CONTEXT_SOURCES_CONNECTION, PostUserEnrollmentStage | ||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
| from authentik.flows.exceptions import FlowNonApplicableException | from authentik.flows.exceptions import FlowNonApplicableException | ||||||
| from authentik.flows.models import Flow, FlowToken, Stage, in_memory_stage | from authentik.flows.models import Flow, FlowToken, Stage, in_memory_stage | ||||||
| @ -100,6 +100,8 @@ class SourceFlowManager: | |||||||
|         if self.request.user.is_authenticated: |         if self.request.user.is_authenticated: | ||||||
|             new_connection.user = self.request.user |             new_connection.user = self.request.user | ||||||
|             new_connection = self.update_connection(new_connection, **kwargs) |             new_connection = self.update_connection(new_connection, **kwargs) | ||||||
|  |  | ||||||
|  |             new_connection.save() | ||||||
|             return Action.LINK, new_connection |             return Action.LINK, new_connection | ||||||
|  |  | ||||||
|         existing_connections = self.connection_type.objects.filter( |         existing_connections = self.connection_type.objects.filter( | ||||||
| @ -146,6 +148,7 @@ class SourceFlowManager: | |||||||
|         ]: |         ]: | ||||||
|             new_connection.user = user |             new_connection.user = user | ||||||
|             new_connection = self.update_connection(new_connection, **kwargs) |             new_connection = self.update_connection(new_connection, **kwargs) | ||||||
|  |             new_connection.save() | ||||||
|             return Action.LINK, new_connection |             return Action.LINK, new_connection | ||||||
|         if self.source.user_matching_mode in [ |         if self.source.user_matching_mode in [ | ||||||
|             SourceUserMatchingModes.EMAIL_DENY, |             SourceUserMatchingModes.EMAIL_DENY, | ||||||
| @ -206,9 +209,13 @@ class SourceFlowManager: | |||||||
|  |  | ||||||
|     def get_stages_to_append(self, flow: Flow) -> list[Stage]: |     def get_stages_to_append(self, flow: Flow) -> list[Stage]: | ||||||
|         """Hook to override stages which are appended to the flow""" |         """Hook to override stages which are appended to the flow""" | ||||||
|         return [ |         if not self.source.enrollment_flow: | ||||||
|             in_memory_stage(PostSourceStage), |             return [] | ||||||
|         ] |         if flow.slug == self.source.enrollment_flow.slug: | ||||||
|  |             return [ | ||||||
|  |                 in_memory_stage(PostUserEnrollmentStage), | ||||||
|  |             ] | ||||||
|  |         return [] | ||||||
|  |  | ||||||
|     def _prepare_flow( |     def _prepare_flow( | ||||||
|         self, |         self, | ||||||
| @ -262,9 +269,6 @@ class SourceFlowManager: | |||||||
|             ) |             ) | ||||||
|         # We run the Flow planner here so we can pass the Pending user in the context |         # We run the Flow planner here so we can pass the Pending user in the context | ||||||
|         planner = FlowPlanner(flow) |         planner = FlowPlanner(flow) | ||||||
|         # We append some stages so the initial flow we get might be empty |  | ||||||
|         planner.allow_empty_flows = True |  | ||||||
|         planner.use_cache = False |  | ||||||
|         plan = planner.plan(self.request, kwargs) |         plan = planner.plan(self.request, kwargs) | ||||||
|         for stage in self.get_stages_to_append(flow): |         for stage in self.get_stages_to_append(flow): | ||||||
|             plan.append_stage(stage) |             plan.append_stage(stage) | ||||||
| @ -323,7 +327,7 @@ class SourceFlowManager: | |||||||
|             reverse( |             reverse( | ||||||
|                 "authentik_core:if-user", |                 "authentik_core:if-user", | ||||||
|             ) |             ) | ||||||
|             + "#/settings;page-sources" |             + f"#/settings;page-{self.source.slug}" | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def handle_enroll( |     def handle_enroll( | ||||||
|  | |||||||
| @ -10,7 +10,7 @@ from authentik.flows.stage import StageView | |||||||
| PLAN_CONTEXT_SOURCES_CONNECTION = "goauthentik.io/sources/connection" | PLAN_CONTEXT_SOURCES_CONNECTION = "goauthentik.io/sources/connection" | ||||||
|  |  | ||||||
|  |  | ||||||
| class PostSourceStage(StageView): | class PostUserEnrollmentStage(StageView): | ||||||
|     """Dynamically injected stage which saves the Connection after |     """Dynamically injected stage which saves the Connection after | ||||||
|     the user has been enrolled.""" |     the user has been enrolled.""" | ||||||
|  |  | ||||||
| @ -21,12 +21,10 @@ class PostSourceStage(StageView): | |||||||
|         ] |         ] | ||||||
|         user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] |         user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] | ||||||
|         connection.user = user |         connection.user = user | ||||||
|         linked = connection.pk is None |  | ||||||
|         connection.save() |         connection.save() | ||||||
|         if linked: |         Event.new( | ||||||
|             Event.new( |             EventAction.SOURCE_LINKED, | ||||||
|                 EventAction.SOURCE_LINKED, |             message="Linked Source", | ||||||
|                 message="Linked Source", |             source=connection.source, | ||||||
|                 source=connection.source, |         ).from_http(self.request) | ||||||
|             ).from_http(self.request) |  | ||||||
|         return self.executor.stage_ok() |         return self.executor.stage_ok() | ||||||
|  | |||||||
| @ -2,9 +2,7 @@ | |||||||
|  |  | ||||||
| from datetime import datetime, timedelta | from datetime import datetime, timedelta | ||||||
|  |  | ||||||
| from django.conf import ImproperlyConfigured |  | ||||||
| from django.contrib.sessions.backends.cache import KEY_PREFIX | from django.contrib.sessions.backends.cache import KEY_PREFIX | ||||||
| from django.contrib.sessions.backends.db import SessionStore as DBSessionStore |  | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
| from django.utils.timezone import now | from django.utils.timezone import now | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
| @ -17,7 +15,6 @@ from authentik.core.models import ( | |||||||
|     User, |     User, | ||||||
| ) | ) | ||||||
| from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task | from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task | ||||||
| from authentik.lib.config import CONFIG |  | ||||||
| from authentik.root.celery import CELERY_APP | from authentik.root.celery import CELERY_APP | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| @ -42,31 +39,16 @@ def clean_expired_models(self: SystemTask): | |||||||
|     amount = 0 |     amount = 0 | ||||||
|  |  | ||||||
|     for session in AuthenticatedSession.objects.all(): |     for session in AuthenticatedSession.objects.all(): | ||||||
|         match CONFIG.get("session_storage", "cache"): |         cache_key = f"{KEY_PREFIX}{session.session_key}" | ||||||
|             case "cache": |         value = None | ||||||
|                 cache_key = f"{KEY_PREFIX}{session.session_key}" |         try: | ||||||
|                 value = None |             value = cache.get(cache_key) | ||||||
|                 try: |  | ||||||
|                     value = cache.get(cache_key) |  | ||||||
|  |  | ||||||
|                 except Exception as exc: |         except Exception as exc: | ||||||
|                     LOGGER.debug("Failed to get session from cache", exc=exc) |             LOGGER.debug("Failed to get session from cache", exc=exc) | ||||||
|                 if not value: |         if not value: | ||||||
|                     session.delete() |             session.delete() | ||||||
|                     amount += 1 |             amount += 1 | ||||||
|             case "db": |  | ||||||
|                 if not ( |  | ||||||
|                     DBSessionStore.get_model_class() |  | ||||||
|                     .objects.filter(session_key=session.session_key, expire_date__gt=now()) |  | ||||||
|                     .exists() |  | ||||||
|                 ): |  | ||||||
|                     session.delete() |  | ||||||
|                     amount += 1 |  | ||||||
|             case _: |  | ||||||
|                 # Should never happen, as we check for other values in authentik/root/settings.py |  | ||||||
|                 raise ImproperlyConfigured( |  | ||||||
|                     "Invalid session_storage setting, allowed values are db and cache" |  | ||||||
|                 ) |  | ||||||
|     LOGGER.debug("Expired sessions", model=AuthenticatedSession, amount=amount) |     LOGGER.debug("Expired sessions", model=AuthenticatedSession, amount=amount) | ||||||
|  |  | ||||||
|     messages.append(f"Expired {amount} {AuthenticatedSession._meta.verbose_name_plural}") |     messages.append(f"Expired {amount} {AuthenticatedSession._meta.verbose_name_plural}") | ||||||
|  | |||||||
| @ -1,11 +1,10 @@ | |||||||
| """Test Groups API""" | """Test Groups API""" | ||||||
|  |  | ||||||
| from django.urls.base import reverse | from django.urls.base import reverse | ||||||
| from guardian.shortcuts import assign_perm |  | ||||||
| from rest_framework.test import APITestCase | from rest_framework.test import APITestCase | ||||||
|  |  | ||||||
| from authentik.core.models import Group, User | from authentik.core.models import Group, User | ||||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_user | from authentik.core.tests.utils import create_test_admin_user | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -13,22 +12,13 @@ class TestGroupsAPI(APITestCase): | |||||||
|     """Test Groups API""" |     """Test Groups API""" | ||||||
|  |  | ||||||
|     def setUp(self) -> None: |     def setUp(self) -> None: | ||||||
|         self.login_user = create_test_user() |         self.admin = create_test_admin_user() | ||||||
|         self.user = User.objects.create(username="test-user") |         self.user = User.objects.create(username="test-user") | ||||||
|  |  | ||||||
|     def test_list_with_users(self): |  | ||||||
|         """Test listing with users""" |  | ||||||
|         admin = create_test_admin_user() |  | ||||||
|         self.client.force_login(admin) |  | ||||||
|         response = self.client.get(reverse("authentik_api:group-list"), {"include_users": "true"}) |  | ||||||
|         self.assertEqual(response.status_code, 200) |  | ||||||
|  |  | ||||||
|     def test_add_user(self): |     def test_add_user(self): | ||||||
|         """Test add_user""" |         """Test add_user""" | ||||||
|         group = Group.objects.create(name=generate_id()) |         group = Group.objects.create(name=generate_id()) | ||||||
|         assign_perm("authentik_core.add_user_to_group", self.login_user, group) |         self.client.force_login(self.admin) | ||||||
|         assign_perm("authentik_core.view_user", self.login_user) |  | ||||||
|         self.client.force_login(self.login_user) |  | ||||||
|         res = self.client.post( |         res = self.client.post( | ||||||
|             reverse("authentik_api:group-add-user", kwargs={"pk": group.pk}), |             reverse("authentik_api:group-add-user", kwargs={"pk": group.pk}), | ||||||
|             data={ |             data={ | ||||||
| @ -42,9 +32,7 @@ class TestGroupsAPI(APITestCase): | |||||||
|     def test_add_user_404(self): |     def test_add_user_404(self): | ||||||
|         """Test add_user""" |         """Test add_user""" | ||||||
|         group = Group.objects.create(name=generate_id()) |         group = Group.objects.create(name=generate_id()) | ||||||
|         assign_perm("authentik_core.add_user_to_group", self.login_user, group) |         self.client.force_login(self.admin) | ||||||
|         assign_perm("authentik_core.view_user", self.login_user) |  | ||||||
|         self.client.force_login(self.login_user) |  | ||||||
|         res = self.client.post( |         res = self.client.post( | ||||||
|             reverse("authentik_api:group-add-user", kwargs={"pk": group.pk}), |             reverse("authentik_api:group-add-user", kwargs={"pk": group.pk}), | ||||||
|             data={ |             data={ | ||||||
| @ -56,10 +44,8 @@ class TestGroupsAPI(APITestCase): | |||||||
|     def test_remove_user(self): |     def test_remove_user(self): | ||||||
|         """Test remove_user""" |         """Test remove_user""" | ||||||
|         group = Group.objects.create(name=generate_id()) |         group = Group.objects.create(name=generate_id()) | ||||||
|         assign_perm("authentik_core.remove_user_from_group", self.login_user, group) |  | ||||||
|         assign_perm("authentik_core.view_user", self.login_user) |  | ||||||
|         group.users.add(self.user) |         group.users.add(self.user) | ||||||
|         self.client.force_login(self.login_user) |         self.client.force_login(self.admin) | ||||||
|         res = self.client.post( |         res = self.client.post( | ||||||
|             reverse("authentik_api:group-remove-user", kwargs={"pk": group.pk}), |             reverse("authentik_api:group-remove-user", kwargs={"pk": group.pk}), | ||||||
|             data={ |             data={ | ||||||
| @ -73,10 +59,8 @@ class TestGroupsAPI(APITestCase): | |||||||
|     def test_remove_user_404(self): |     def test_remove_user_404(self): | ||||||
|         """Test remove_user""" |         """Test remove_user""" | ||||||
|         group = Group.objects.create(name=generate_id()) |         group = Group.objects.create(name=generate_id()) | ||||||
|         assign_perm("authentik_core.remove_user_from_group", self.login_user, group) |  | ||||||
|         assign_perm("authentik_core.view_user", self.login_user) |  | ||||||
|         group.users.add(self.user) |         group.users.add(self.user) | ||||||
|         self.client.force_login(self.login_user) |         self.client.force_login(self.admin) | ||||||
|         res = self.client.post( |         res = self.client.post( | ||||||
|             reverse("authentik_api:group-remove-user", kwargs={"pk": group.pk}), |             reverse("authentik_api:group-remove-user", kwargs={"pk": group.pk}), | ||||||
|             data={ |             data={ | ||||||
| @ -88,12 +72,11 @@ class TestGroupsAPI(APITestCase): | |||||||
|     def test_parent_self(self): |     def test_parent_self(self): | ||||||
|         """Test parent""" |         """Test parent""" | ||||||
|         group = Group.objects.create(name=generate_id()) |         group = Group.objects.create(name=generate_id()) | ||||||
|         assign_perm("view_group", self.login_user, group) |         self.client.force_login(self.admin) | ||||||
|         assign_perm("change_group", self.login_user, group) |  | ||||||
|         self.client.force_login(self.login_user) |  | ||||||
|         res = self.client.patch( |         res = self.client.patch( | ||||||
|             reverse("authentik_api:group-detail", kwargs={"pk": group.pk}), |             reverse("authentik_api:group-detail", kwargs={"pk": group.pk}), | ||||||
|             data={ |             data={ | ||||||
|  |                 "pk": self.user.pk + 3, | ||||||
|                 "parent": group.pk, |                 "parent": group.pk, | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|  | |||||||
| @ -2,15 +2,11 @@ | |||||||
|  |  | ||||||
| from django.contrib.auth.models import AnonymousUser | from django.contrib.auth.models import AnonymousUser | ||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
| from django.urls import reverse |  | ||||||
| from guardian.utils import get_anonymous_user | from guardian.utils import get_anonymous_user | ||||||
|  |  | ||||||
| from authentik.core.models import SourceUserMatchingModes, User | from authentik.core.models import SourceUserMatchingModes, User | ||||||
| from authentik.core.sources.flow_manager import Action | from authentik.core.sources.flow_manager import Action | ||||||
| from authentik.core.sources.stage import PostSourceStage |  | ||||||
| from authentik.core.tests.utils import create_test_flow | from authentik.core.tests.utils import create_test_flow | ||||||
| from authentik.flows.planner import FlowPlan |  | ||||||
| from authentik.flows.views.executor import SESSION_KEY_PLAN |  | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.lib.tests.utils import get_request | from authentik.lib.tests.utils import get_request | ||||||
| from authentik.policies.denied import AccessDeniedResponse | from authentik.policies.denied import AccessDeniedResponse | ||||||
| @ -25,62 +21,42 @@ class TestSourceFlowManager(TestCase): | |||||||
|  |  | ||||||
|     def setUp(self) -> None: |     def setUp(self) -> None: | ||||||
|         super().setUp() |         super().setUp() | ||||||
|         self.authentication_flow = create_test_flow() |         self.source: OAuthSource = OAuthSource.objects.create(name="test") | ||||||
|         self.enrollment_flow = create_test_flow() |  | ||||||
|         self.source: OAuthSource = OAuthSource.objects.create( |  | ||||||
|             name=generate_id(), |  | ||||||
|             slug=generate_id(), |  | ||||||
|             authentication_flow=self.authentication_flow, |  | ||||||
|             enrollment_flow=self.enrollment_flow, |  | ||||||
|         ) |  | ||||||
|         self.identifier = generate_id() |         self.identifier = generate_id() | ||||||
|  |  | ||||||
|     def test_unauthenticated_enroll(self): |     def test_unauthenticated_enroll(self): | ||||||
|         """Test un-authenticated user enrolling""" |         """Test un-authenticated user enrolling""" | ||||||
|         request = get_request("/", user=AnonymousUser()) |         flow_manager = OAuthSourceFlowManager( | ||||||
|         flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {}) |             self.source, get_request("/", user=AnonymousUser()), self.identifier, {} | ||||||
|  |         ) | ||||||
|         action, _ = flow_manager.get_action() |         action, _ = flow_manager.get_action() | ||||||
|         self.assertEqual(action, Action.ENROLL) |         self.assertEqual(action, Action.ENROLL) | ||||||
|         response = flow_manager.get_flow() |         flow_manager.get_flow() | ||||||
|         self.assertEqual(response.status_code, 302) |  | ||||||
|         flow_plan: FlowPlan = request.session[SESSION_KEY_PLAN] |  | ||||||
|         self.assertEqual(flow_plan.bindings[0].stage.view, PostSourceStage) |  | ||||||
|  |  | ||||||
|     def test_unauthenticated_auth(self): |     def test_unauthenticated_auth(self): | ||||||
|         """Test un-authenticated user authenticating""" |         """Test un-authenticated user authenticating""" | ||||||
|         UserOAuthSourceConnection.objects.create( |         UserOAuthSourceConnection.objects.create( | ||||||
|             user=get_anonymous_user(), source=self.source, identifier=self.identifier |             user=get_anonymous_user(), source=self.source, identifier=self.identifier | ||||||
|         ) |         ) | ||||||
|         request = get_request("/", user=AnonymousUser()) |  | ||||||
|         flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {}) |         flow_manager = OAuthSourceFlowManager( | ||||||
|  |             self.source, get_request("/", user=AnonymousUser()), self.identifier, {} | ||||||
|  |         ) | ||||||
|         action, _ = flow_manager.get_action() |         action, _ = flow_manager.get_action() | ||||||
|         self.assertEqual(action, Action.AUTH) |         self.assertEqual(action, Action.AUTH) | ||||||
|         response = flow_manager.get_flow() |         flow_manager.get_flow() | ||||||
|         self.assertEqual(response.status_code, 302) |  | ||||||
|         flow_plan: FlowPlan = request.session[SESSION_KEY_PLAN] |  | ||||||
|         self.assertEqual(flow_plan.bindings[0].stage.view, PostSourceStage) |  | ||||||
|  |  | ||||||
|     def test_authenticated_link(self): |     def test_authenticated_link(self): | ||||||
|         """Test authenticated user linking""" |         """Test authenticated user linking""" | ||||||
|         user = User.objects.create(username="foo", email="foo@bar.baz") |         UserOAuthSourceConnection.objects.create( | ||||||
|         request = get_request("/", user=user) |             user=get_anonymous_user(), source=self.source, identifier=self.identifier | ||||||
|         flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {}) |  | ||||||
|         action, connection = flow_manager.get_action() |  | ||||||
|         self.assertEqual(action, Action.LINK) |  | ||||||
|         self.assertIsNone(connection.pk) |  | ||||||
|         response = flow_manager.get_flow() |  | ||||||
|         self.assertEqual(response.status_code, 302) |  | ||||||
|         self.assertEqual( |  | ||||||
|             response.url, |  | ||||||
|             reverse("authentik_core:if-user") + "#/settings;page-sources", |  | ||||||
|         ) |         ) | ||||||
|  |         user = User.objects.create(username="foo", email="foo@bar.baz") | ||||||
|     def test_unauthenticated_link(self): |         flow_manager = OAuthSourceFlowManager( | ||||||
|         """Test un-authenticated user linking""" |             self.source, get_request("/", user=user), self.identifier, {} | ||||||
|         flow_manager = OAuthSourceFlowManager(self.source, get_request("/"), self.identifier, {}) |         ) | ||||||
|         action, connection = flow_manager.get_action() |         action, _ = flow_manager.get_action() | ||||||
|         self.assertEqual(action, Action.LINK) |         self.assertEqual(action, Action.LINK) | ||||||
|         self.assertIsNone(connection.pk) |  | ||||||
|         flow_manager.get_flow() |         flow_manager.get_flow() | ||||||
|  |  | ||||||
|     def test_unauthenticated_enroll_email(self): |     def test_unauthenticated_enroll_email(self): | ||||||
|  | |||||||
| @ -1,6 +1,5 @@ | |||||||
| """Test token API""" | """Test token API""" | ||||||
|  |  | ||||||
| from datetime import datetime, timedelta |  | ||||||
| from json import loads | from json import loads | ||||||
|  |  | ||||||
| from django.urls.base import reverse | from django.urls.base import reverse | ||||||
| @ -8,13 +7,8 @@ from guardian.shortcuts import get_anonymous_user | |||||||
| from rest_framework.test import APITestCase | from rest_framework.test import APITestCase | ||||||
|  |  | ||||||
| from authentik.core.api.tokens import TokenSerializer | from authentik.core.api.tokens import TokenSerializer | ||||||
| from authentik.core.models import ( | from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents, User | ||||||
|     USER_ATTRIBUTE_TOKEN_EXPIRING, | from authentik.core.tests.utils import create_test_admin_user | ||||||
|     USER_ATTRIBUTE_TOKEN_MAXIMUM_LIFETIME, |  | ||||||
|     Token, |  | ||||||
|     TokenIntents, |  | ||||||
| ) |  | ||||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_user |  | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -23,7 +17,7 @@ class TestTokenAPI(APITestCase): | |||||||
|  |  | ||||||
|     def setUp(self) -> None: |     def setUp(self) -> None: | ||||||
|         super().setUp() |         super().setUp() | ||||||
|         self.user = create_test_user() |         self.user = User.objects.create(username="testuser") | ||||||
|         self.admin = create_test_admin_user() |         self.admin = create_test_admin_user() | ||||||
|         self.client.force_login(self.user) |         self.client.force_login(self.user) | ||||||
|  |  | ||||||
| @ -82,95 +76,6 @@ class TestTokenAPI(APITestCase): | |||||||
|         self.assertEqual(token.intent, TokenIntents.INTENT_API) |         self.assertEqual(token.intent, TokenIntents.INTENT_API) | ||||||
|         self.assertEqual(token.expiring, False) |         self.assertEqual(token.expiring, False) | ||||||
|  |  | ||||||
|     def test_token_create_expiring(self): |  | ||||||
|         """Test token creation endpoint""" |  | ||||||
|         self.user.attributes[USER_ATTRIBUTE_TOKEN_EXPIRING] = True |  | ||||||
|         self.user.save() |  | ||||||
|         response = self.client.post( |  | ||||||
|             reverse("authentik_api:token-list"), {"identifier": "test-token"} |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 201) |  | ||||||
|         token = Token.objects.get(identifier="test-token") |  | ||||||
|         self.assertEqual(token.user, self.user) |  | ||||||
|         self.assertEqual(token.intent, TokenIntents.INTENT_API) |  | ||||||
|         self.assertEqual(token.expiring, True) |  | ||||||
|  |  | ||||||
|     def test_token_create_expiring_custom_ok(self): |  | ||||||
|         """Test token creation endpoint""" |  | ||||||
|         self.user.attributes[USER_ATTRIBUTE_TOKEN_EXPIRING] = True |  | ||||||
|         self.user.attributes[USER_ATTRIBUTE_TOKEN_MAXIMUM_LIFETIME] = "hours=2" |  | ||||||
|         self.user.save() |  | ||||||
|         expires = datetime.now() + timedelta(hours=1) |  | ||||||
|         response = self.client.post( |  | ||||||
|             reverse("authentik_api:token-list"), |  | ||||||
|             { |  | ||||||
|                 "identifier": "test-token", |  | ||||||
|                 "expires": expires, |  | ||||||
|                 "intent": TokenIntents.INTENT_APP_PASSWORD, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 201) |  | ||||||
|         token = Token.objects.get(identifier="test-token") |  | ||||||
|         self.assertEqual(token.user, self.user) |  | ||||||
|         self.assertEqual(token.intent, TokenIntents.INTENT_APP_PASSWORD) |  | ||||||
|         self.assertEqual(token.expiring, True) |  | ||||||
|         self.assertEqual(token.expires.timestamp(), expires.timestamp()) |  | ||||||
|  |  | ||||||
|     def test_token_create_expiring_custom_nok(self): |  | ||||||
|         """Test token creation endpoint""" |  | ||||||
|         self.user.attributes[USER_ATTRIBUTE_TOKEN_EXPIRING] = True |  | ||||||
|         self.user.attributes[USER_ATTRIBUTE_TOKEN_MAXIMUM_LIFETIME] = "hours=2" |  | ||||||
|         self.user.save() |  | ||||||
|         expires = datetime.now() + timedelta(hours=3) |  | ||||||
|         response = self.client.post( |  | ||||||
|             reverse("authentik_api:token-list"), |  | ||||||
|             { |  | ||||||
|                 "identifier": "test-token", |  | ||||||
|                 "expires": expires, |  | ||||||
|                 "intent": TokenIntents.INTENT_APP_PASSWORD, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 400) |  | ||||||
|  |  | ||||||
|     def test_token_create_expiring_custom_api(self): |  | ||||||
|         """Test token creation endpoint""" |  | ||||||
|         self.user.attributes[USER_ATTRIBUTE_TOKEN_EXPIRING] = True |  | ||||||
|         self.user.attributes[USER_ATTRIBUTE_TOKEN_MAXIMUM_LIFETIME] = "hours=2" |  | ||||||
|         self.user.save() |  | ||||||
|         expires = datetime.now() + timedelta(seconds=3) |  | ||||||
|         response = self.client.post( |  | ||||||
|             reverse("authentik_api:token-list"), |  | ||||||
|             { |  | ||||||
|                 "identifier": "test-token", |  | ||||||
|                 "expires": expires, |  | ||||||
|                 "intent": TokenIntents.INTENT_API, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 201) |  | ||||||
|         token = Token.objects.get(identifier="test-token") |  | ||||||
|         self.assertEqual(token.user, self.user) |  | ||||||
|         self.assertEqual(token.intent, TokenIntents.INTENT_API) |  | ||||||
|         self.assertEqual(token.expiring, True) |  | ||||||
|         self.assertNotEqual(token.expires.timestamp(), expires.timestamp()) |  | ||||||
|  |  | ||||||
|     def test_token_change_user(self): |  | ||||||
|         """Test creating a token and then changing the user""" |  | ||||||
|         ident = generate_id() |  | ||||||
|         response = self.client.post(reverse("authentik_api:token-list"), {"identifier": ident}) |  | ||||||
|         self.assertEqual(response.status_code, 201) |  | ||||||
|         token = Token.objects.get(identifier=ident) |  | ||||||
|         self.assertEqual(token.user, self.user) |  | ||||||
|         self.assertEqual(token.intent, TokenIntents.INTENT_API) |  | ||||||
|         self.assertEqual(token.expiring, True) |  | ||||||
|         self.assertTrue(self.user.has_perm("authentik_core.view_token_key", token)) |  | ||||||
|         response = self.client.put( |  | ||||||
|             reverse("authentik_api:token-detail", kwargs={"identifier": ident}), |  | ||||||
|             data={"identifier": "user_token_poc_v3", "intent": "api", "user": self.admin.pk}, |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 400) |  | ||||||
|         token.refresh_from_db() |  | ||||||
|         self.assertEqual(token.user, self.user) |  | ||||||
|  |  | ||||||
|     def test_list(self): |     def test_list(self): | ||||||
|         """Test Token List (Test normal authentication)""" |         """Test Token List (Test normal authentication)""" | ||||||
|         Token.objects.all().delete() |         Token.objects.all().delete() | ||||||
|  | |||||||
| @ -41,12 +41,6 @@ class TestUsersAPI(APITestCase): | |||||||
|         ) |         ) | ||||||
|         self.assertEqual(response.status_code, 200) |         self.assertEqual(response.status_code, 200) | ||||||
|  |  | ||||||
|     def test_list_with_groups(self): |  | ||||||
|         """Test listing with groups""" |  | ||||||
|         self.client.force_login(self.admin) |  | ||||||
|         response = self.client.get(reverse("authentik_api:user-list"), {"include_groups": "true"}) |  | ||||||
|         self.assertEqual(response.status_code, 200) |  | ||||||
|  |  | ||||||
|     def test_metrics(self): |     def test_metrics(self): | ||||||
|         """Test user's metrics""" |         """Test user's metrics""" | ||||||
|         self.client.force_login(self.admin) |         self.client.force_login(self.admin) | ||||||
|  | |||||||
| @ -8,6 +8,7 @@ from rest_framework.test import APITestCase | |||||||
|  |  | ||||||
| from authentik.core.models import User | from authentik.core.models import User | ||||||
| from authentik.core.tests.utils import create_test_admin_user | from authentik.core.tests.utils import create_test_admin_user | ||||||
|  | from authentik.lib.config import CONFIG | ||||||
| from authentik.tenants.utils import get_current_tenant | from authentik.tenants.utils import get_current_tenant | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -24,6 +25,7 @@ class TestUsersAvatars(APITestCase): | |||||||
|         tenant.avatars = mode |         tenant.avatars = mode | ||||||
|         tenant.save() |         tenant.save() | ||||||
|  |  | ||||||
|  |     @CONFIG.patch("avatars", "none") | ||||||
|     def test_avatars_none(self): |     def test_avatars_none(self): | ||||||
|         """Test avatars none""" |         """Test avatars none""" | ||||||
|         self.set_avatar_mode("none") |         self.set_avatar_mode("none") | ||||||
|  | |||||||
| @ -4,7 +4,7 @@ from django.utils.text import slugify | |||||||
|  |  | ||||||
| from authentik.brands.models import Brand | from authentik.brands.models import Brand | ||||||
| from authentik.core.models import Group, User | from authentik.core.models import Group, User | ||||||
| from authentik.crypto.builder import CertificateBuilder, PrivateKeyAlg | from authentik.crypto.builder import CertificateBuilder | ||||||
| from authentik.crypto.models import CertificateKeyPair | from authentik.crypto.models import CertificateKeyPair | ||||||
| from authentik.flows.models import Flow, FlowDesignation | from authentik.flows.models import Flow, FlowDesignation | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| @ -50,10 +50,12 @@ def create_test_brand(**kwargs) -> Brand: | |||||||
|     return Brand.objects.create(domain=uid, default=True, **kwargs) |     return Brand.objects.create(domain=uid, default=True, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
| def create_test_cert(alg=PrivateKeyAlg.RSA) -> CertificateKeyPair: | def create_test_cert(use_ec_private_key=False) -> CertificateKeyPair: | ||||||
|     """Generate a certificate for testing""" |     """Generate a certificate for testing""" | ||||||
|     builder = CertificateBuilder(f"{generate_id()}.self-signed.goauthentik.io") |     builder = CertificateBuilder( | ||||||
|     builder.alg = alg |         name=f"{generate_id()}.self-signed.goauthentik.io", | ||||||
|  |         use_ec_private_key=use_ec_private_key, | ||||||
|  |     ) | ||||||
|     builder.build( |     builder.build( | ||||||
|         subject_alt_names=[f"{generate_id()}.self-signed.goauthentik.io"], |         subject_alt_names=[f"{generate_id()}.self-signed.goauthentik.io"], | ||||||
|         validity_days=360, |         validity_days=360, | ||||||
|  | |||||||
| @ -6,7 +6,6 @@ from django.conf import settings | |||||||
| from django.contrib.auth.decorators import login_required | from django.contrib.auth.decorators import login_required | ||||||
| from django.urls import path | from django.urls import path | ||||||
| from django.views.decorators.csrf import ensure_csrf_cookie | from django.views.decorators.csrf import ensure_csrf_cookie | ||||||
| from django.views.generic import RedirectView |  | ||||||
|  |  | ||||||
| from authentik.core.api.applications import ApplicationViewSet | from authentik.core.api.applications import ApplicationViewSet | ||||||
| from authentik.core.api.authenticated_sessions import AuthenticatedSessionViewSet | from authentik.core.api.authenticated_sessions import AuthenticatedSessionViewSet | ||||||
| @ -20,7 +19,12 @@ from authentik.core.api.transactional_applications import TransactionalApplicati | |||||||
| from authentik.core.api.users import UserViewSet | from authentik.core.api.users import UserViewSet | ||||||
| from authentik.core.views import apps | from authentik.core.views import apps | ||||||
| from authentik.core.views.debug import AccessDeniedView | from authentik.core.views.debug import AccessDeniedView | ||||||
| from authentik.core.views.interface import FlowInterfaceView, InterfaceView | from authentik.core.views.interface import ( | ||||||
|  |     BrandDefaultRedirectView, | ||||||
|  |     FlowInterfaceView, | ||||||
|  |     InterfaceView, | ||||||
|  |     RootRedirectView, | ||||||
|  | ) | ||||||
| from authentik.core.views.session import EndSessionView | from authentik.core.views.session import EndSessionView | ||||||
| from authentik.root.asgi_middleware import SessionMiddleware | from authentik.root.asgi_middleware import SessionMiddleware | ||||||
| from authentik.root.messages.consumer import MessageConsumer | from authentik.root.messages.consumer import MessageConsumer | ||||||
| @ -29,13 +33,11 @@ from authentik.root.middleware import ChannelsLoggingMiddleware | |||||||
| urlpatterns = [ | urlpatterns = [ | ||||||
|     path( |     path( | ||||||
|         "", |         "", | ||||||
|         login_required( |         login_required(RootRedirectView.as_view()), | ||||||
|             RedirectView.as_view(pattern_name="authentik_core:if-user", query_string=True) |  | ||||||
|         ), |  | ||||||
|         name="root-redirect", |         name="root-redirect", | ||||||
|     ), |     ), | ||||||
|     path( |     path( | ||||||
|         # We have to use this format since everything else uses applications/o or applications/saml |         # We have to use this format since everything else uses application/o or application/saml | ||||||
|         "application/launch/<slug:application_slug>/", |         "application/launch/<slug:application_slug>/", | ||||||
|         apps.RedirectToAppLaunch.as_view(), |         apps.RedirectToAppLaunch.as_view(), | ||||||
|         name="application-launch", |         name="application-launch", | ||||||
| @ -43,12 +45,12 @@ urlpatterns = [ | |||||||
|     # Interfaces |     # Interfaces | ||||||
|     path( |     path( | ||||||
|         "if/admin/", |         "if/admin/", | ||||||
|         ensure_csrf_cookie(InterfaceView.as_view(template_name="if/admin.html")), |         ensure_csrf_cookie(BrandDefaultRedirectView.as_view(template_name="if/admin.html")), | ||||||
|         name="if-admin", |         name="if-admin", | ||||||
|     ), |     ), | ||||||
|     path( |     path( | ||||||
|         "if/user/", |         "if/user/", | ||||||
|         ensure_csrf_cookie(InterfaceView.as_view(template_name="if/user.html")), |         ensure_csrf_cookie(BrandDefaultRedirectView.as_view(template_name="if/user.html")), | ||||||
|         name="if-user", |         name="if-user", | ||||||
|     ), |     ), | ||||||
|     path( |     path( | ||||||
|  | |||||||
| @ -3,15 +3,43 @@ | |||||||
| from json import dumps | from json import dumps | ||||||
| from typing import Any | from typing import Any | ||||||
|  |  | ||||||
| from django.shortcuts import get_object_or_404 | from django.http import HttpRequest | ||||||
| from django.views.generic.base import TemplateView | from django.http.response import HttpResponse | ||||||
|  | from django.shortcuts import get_object_or_404, redirect | ||||||
|  | from django.utils.translation import gettext as _ | ||||||
|  | from django.views.generic.base import RedirectView, TemplateView | ||||||
| from rest_framework.request import Request | from rest_framework.request import Request | ||||||
|  |  | ||||||
| from authentik import get_build_hash | from authentik import get_build_hash | ||||||
| from authentik.admin.tasks import LOCAL_VERSION | from authentik.admin.tasks import LOCAL_VERSION | ||||||
| from authentik.api.v3.config import ConfigView | from authentik.api.v3.config import ConfigView | ||||||
| from authentik.brands.api import CurrentBrandSerializer | from authentik.brands.api import CurrentBrandSerializer | ||||||
|  | from authentik.brands.models import Brand | ||||||
|  | from authentik.core.models import UserTypes | ||||||
| from authentik.flows.models import Flow | from authentik.flows.models import Flow | ||||||
|  | from authentik.policies.denied import AccessDeniedResponse | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class RootRedirectView(RedirectView): | ||||||
|  |     """Root redirect view, redirect to brand's default application if set""" | ||||||
|  |  | ||||||
|  |     pattern_name = "authentik_core:if-user" | ||||||
|  |     query_string = True | ||||||
|  |  | ||||||
|  |     def redirect_to_app(self, request: HttpRequest): | ||||||
|  |         if request.user.is_authenticated and request.user.type == UserTypes.EXTERNAL: | ||||||
|  |             brand: Brand = request.brand | ||||||
|  |             if brand.default_application: | ||||||
|  |                 return redirect( | ||||||
|  |                     "authentik_core:application-launch", | ||||||
|  |                     application_slug=brand.default_application.slug, | ||||||
|  |                 ) | ||||||
|  |         return None | ||||||
|  |  | ||||||
|  |     def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: | ||||||
|  |         if redirect_response := RootRedirectView().redirect_to_app(request): | ||||||
|  |             return redirect_response | ||||||
|  |         return super().dispatch(request, *args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
| class InterfaceView(TemplateView): | class InterfaceView(TemplateView): | ||||||
| @ -27,6 +55,22 @@ class InterfaceView(TemplateView): | |||||||
|         return super().get_context_data(**kwargs) |         return super().get_context_data(**kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class BrandDefaultRedirectView(InterfaceView): | ||||||
|  |     """By default redirect to default app""" | ||||||
|  |  | ||||||
|  |     def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: | ||||||
|  |         if request.user.is_authenticated and request.user.type == UserTypes.EXTERNAL: | ||||||
|  |             brand: Brand = request.brand | ||||||
|  |             if brand.default_application: | ||||||
|  |                 return redirect( | ||||||
|  |                     "authentik_core:application-launch", | ||||||
|  |                     application_slug=brand.default_application.slug, | ||||||
|  |                 ) | ||||||
|  |             response = AccessDeniedResponse(self.request) | ||||||
|  |             response.error_message = _("Interface can only be accessed by internal users.") | ||||||
|  |         return super().dispatch(request, *args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
| class FlowInterfaceView(InterfaceView): | class FlowInterfaceView(InterfaceView): | ||||||
|     """Flow interface""" |     """Flow interface""" | ||||||
|  |  | ||||||
|  | |||||||
| @ -14,13 +14,7 @@ from drf_spectacular.types import OpenApiTypes | |||||||
| from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema | from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema | ||||||
| from rest_framework.decorators import action | from rest_framework.decorators import action | ||||||
| from rest_framework.exceptions import ValidationError | from rest_framework.exceptions import ValidationError | ||||||
| from rest_framework.fields import ( | from rest_framework.fields import CharField, DateTimeField, IntegerField, SerializerMethodField | ||||||
|     CharField, |  | ||||||
|     ChoiceField, |  | ||||||
|     DateTimeField, |  | ||||||
|     IntegerField, |  | ||||||
|     SerializerMethodField, |  | ||||||
| ) |  | ||||||
| from rest_framework.filters import OrderingFilter, SearchFilter | from rest_framework.filters import OrderingFilter, SearchFilter | ||||||
| from rest_framework.request import Request | from rest_framework.request import Request | ||||||
| from rest_framework.response import Response | from rest_framework.response import Response | ||||||
| @ -32,7 +26,7 @@ from authentik.api.authorization import SecretKeyFilter | |||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.core.api.utils import PassiveSerializer | from authentik.core.api.utils import PassiveSerializer | ||||||
| from authentik.crypto.apps import MANAGED_KEY | from authentik.crypto.apps import MANAGED_KEY | ||||||
| from authentik.crypto.builder import CertificateBuilder, PrivateKeyAlg | from authentik.crypto.builder import CertificateBuilder | ||||||
| from authentik.crypto.models import CertificateKeyPair | from authentik.crypto.models import CertificateKeyPair | ||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
| from authentik.rbac.decorators import permission_required | from authentik.rbac.decorators import permission_required | ||||||
| @ -184,7 +178,6 @@ class CertificateGenerationSerializer(PassiveSerializer): | |||||||
|     common_name = CharField() |     common_name = CharField() | ||||||
|     subject_alt_name = CharField(required=False, allow_blank=True, label=_("Subject-alt name")) |     subject_alt_name = CharField(required=False, allow_blank=True, label=_("Subject-alt name")) | ||||||
|     validity_days = IntegerField(initial=365) |     validity_days = IntegerField(initial=365) | ||||||
|     alg = ChoiceField(default=PrivateKeyAlg.RSA, choices=PrivateKeyAlg.choices) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class CertificateKeyPairFilter(FilterSet): | class CertificateKeyPairFilter(FilterSet): | ||||||
| @ -247,7 +240,6 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet): | |||||||
|         raw_san = data.validated_data.get("subject_alt_name", "") |         raw_san = data.validated_data.get("subject_alt_name", "") | ||||||
|         sans = raw_san.split(",") if raw_san != "" else [] |         sans = raw_san.split(",") if raw_san != "" else [] | ||||||
|         builder = CertificateBuilder(data.validated_data["common_name"]) |         builder = CertificateBuilder(data.validated_data["common_name"]) | ||||||
|         builder.alg = data.validated_data["alg"] |  | ||||||
|         builder.build( |         builder.build( | ||||||
|             subject_alt_names=sans, |             subject_alt_names=sans, | ||||||
|             validity_days=int(data.validated_data["validity_days"]), |             validity_days=int(data.validated_data["validity_days"]), | ||||||
|  | |||||||
| @ -9,28 +9,20 @@ from cryptography.hazmat.primitives import hashes, serialization | |||||||
| from cryptography.hazmat.primitives.asymmetric import ec, rsa | from cryptography.hazmat.primitives.asymmetric import ec, rsa | ||||||
| from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes | from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes | ||||||
| from cryptography.x509.oid import NameOID | from cryptography.x509.oid import NameOID | ||||||
| from django.db import models |  | ||||||
| from django.utils.translation import gettext_lazy as _ |  | ||||||
|  |  | ||||||
| from authentik import __version__ | from authentik import __version__ | ||||||
| from authentik.crypto.models import CertificateKeyPair | from authentik.crypto.models import CertificateKeyPair | ||||||
|  |  | ||||||
|  |  | ||||||
| class PrivateKeyAlg(models.TextChoices): |  | ||||||
|     """Algorithm to create private key with""" |  | ||||||
|  |  | ||||||
|     RSA = "rsa", _("rsa") |  | ||||||
|     ECDSA = "ecdsa", _("ecdsa") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class CertificateBuilder: | class CertificateBuilder: | ||||||
|     """Build self-signed certificates""" |     """Build self-signed certificates""" | ||||||
|  |  | ||||||
|     common_name: str |     common_name: str | ||||||
|     alg: PrivateKeyAlg |  | ||||||
|  |  | ||||||
|     def __init__(self, name: str): |     _use_ec_private_key: bool | ||||||
|         self.alg = PrivateKeyAlg.RSA |  | ||||||
|  |     def __init__(self, name: str, use_ec_private_key=False): | ||||||
|  |         self._use_ec_private_key = use_ec_private_key | ||||||
|         self.__public_key = None |         self.__public_key = None | ||||||
|         self.__private_key = None |         self.__private_key = None | ||||||
|         self.__builder = None |         self.__builder = None | ||||||
| @ -50,13 +42,11 @@ class CertificateBuilder: | |||||||
|  |  | ||||||
|     def generate_private_key(self) -> PrivateKeyTypes: |     def generate_private_key(self) -> PrivateKeyTypes: | ||||||
|         """Generate private key""" |         """Generate private key""" | ||||||
|         if self.alg == PrivateKeyAlg.ECDSA: |         if self._use_ec_private_key: | ||||||
|             return ec.generate_private_key(curve=ec.SECP256R1()) |             return ec.generate_private_key(curve=ec.SECP256R1()) | ||||||
|         if self.alg == PrivateKeyAlg.RSA: |         return rsa.generate_private_key( | ||||||
|             return rsa.generate_private_key( |             public_exponent=65537, key_size=4096, backend=default_backend() | ||||||
|                 public_exponent=65537, key_size=4096, backend=default_backend() |         ) | ||||||
|             ) |  | ||||||
|         raise ValueError(f"Invalid alg: {self.alg}") |  | ||||||
|  |  | ||||||
|     def build( |     def build( | ||||||
|         self, |         self, | ||||||
|  | |||||||
| @ -13,9 +13,9 @@ class AuthentikEnterpriseAuditConfig(EnterpriseConfig): | |||||||
|     verbose_name = "authentik Enterprise.Audit" |     verbose_name = "authentik Enterprise.Audit" | ||||||
|     default = True |     default = True | ||||||
|  |  | ||||||
|     def ready(self): |     @EnterpriseConfig.reconcile_global | ||||||
|  |     def install_middleware(self): | ||||||
|         """Install enterprise audit middleware""" |         """Install enterprise audit middleware""" | ||||||
|         orig_import = "authentik.events.middleware.AuditMiddleware" |         orig_import = "authentik.events.middleware.AuditMiddleware" | ||||||
|         new_import = "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware" |         new_import = "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware" | ||||||
|         settings.MIDDLEWARE = [new_import if x == orig_import else x for x in settings.MIDDLEWARE] |         settings.MIDDLEWARE = [new_import if x == orig_import else x for x in settings.MIDDLEWARE] | ||||||
|         return super().ready() |  | ||||||
|  | |||||||
| @ -2,16 +2,16 @@ | |||||||
|  |  | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from functools import partial | from functools import partial | ||||||
| from typing import Any |  | ||||||
|  |  | ||||||
| from django.apps.registry import apps | from django.apps.registry import apps | ||||||
| from django.core.files import File | from django.core.files import File | ||||||
| from django.db import connection | from django.db import connection | ||||||
| from django.db.models import ManyToManyRel, Model | from django.db.models import Model | ||||||
| from django.db.models.expressions import BaseExpression, Combinable | from django.db.models.expressions import BaseExpression, Combinable | ||||||
| from django.db.models.signals import post_init | from django.db.models.signals import post_init | ||||||
| from django.http import HttpRequest | from django.http import HttpRequest | ||||||
|  |  | ||||||
|  | from authentik.core.models import User | ||||||
| from authentik.events.middleware import AuditMiddleware, should_log_model | from authentik.events.middleware import AuditMiddleware, should_log_model | ||||||
| from authentik.events.utils import cleanse_dict, sanitize_item | from authentik.events.utils import cleanse_dict, sanitize_item | ||||||
|  |  | ||||||
| @ -28,10 +28,13 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | |||||||
|         super().connect(request) |         super().connect(request) | ||||||
|         if not self.enabled: |         if not self.enabled: | ||||||
|             return |             return | ||||||
|  |         user = getattr(request, "user", self.anonymous_user) | ||||||
|  |         if not user.is_authenticated: | ||||||
|  |             user = self.anonymous_user | ||||||
|         if not hasattr(request, "request_id"): |         if not hasattr(request, "request_id"): | ||||||
|             return |             return | ||||||
|         post_init.connect( |         post_init.connect( | ||||||
|             partial(self.post_init_handler, request=request), |             partial(self.post_init_handler, user=user, request=request), | ||||||
|             dispatch_uid=request.request_id, |             dispatch_uid=request.request_id, | ||||||
|             weak=False, |             weak=False, | ||||||
|         ) |         ) | ||||||
| @ -45,7 +48,7 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | |||||||
|         post_init.disconnect(dispatch_uid=request.request_id) |         post_init.disconnect(dispatch_uid=request.request_id) | ||||||
|  |  | ||||||
|     def serialize_simple(self, model: Model) -> dict: |     def serialize_simple(self, model: Model) -> dict: | ||||||
|         """Serialize a model in a very simple way. No ForeignKeys or other relationships are |         """Serialize a model in a very simple way. No ForeginKeys or other relationships are | ||||||
|         resolved""" |         resolved""" | ||||||
|         data = {} |         data = {} | ||||||
|         deferred_fields = model.get_deferred_fields() |         deferred_fields = model.get_deferred_fields() | ||||||
| @ -71,12 +74,9 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | |||||||
|         for key, value in before.items(): |         for key, value in before.items(): | ||||||
|             if after.get(key) != value: |             if after.get(key) != value: | ||||||
|                 diff[key] = {"previous_value": value, "new_value": after.get(key)} |                 diff[key] = {"previous_value": value, "new_value": after.get(key)} | ||||||
|         for key, value in after.items(): |  | ||||||
|             if key not in before and key not in diff and before.get(key) != value: |  | ||||||
|                 diff[key] = {"previous_value": before.get(key), "new_value": value} |  | ||||||
|         return sanitize_item(diff) |         return sanitize_item(diff) | ||||||
|  |  | ||||||
|     def post_init_handler(self, request: HttpRequest, sender, instance: Model, **_): |     def post_init_handler(self, user: User, request: HttpRequest, sender, instance: Model, **_): | ||||||
|         """post_init django model handler""" |         """post_init django model handler""" | ||||||
|         if not should_log_model(instance): |         if not should_log_model(instance): | ||||||
|             return |             return | ||||||
| @ -90,6 +90,7 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | |||||||
|  |  | ||||||
|     def post_save_handler( |     def post_save_handler( | ||||||
|         self, |         self, | ||||||
|  |         user: User, | ||||||
|         request: HttpRequest, |         request: HttpRequest, | ||||||
|         sender, |         sender, | ||||||
|         instance: Model, |         instance: Model, | ||||||
| @ -102,37 +103,15 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | |||||||
|         thread_kwargs = {} |         thread_kwargs = {} | ||||||
|         if hasattr(instance, "_previous_state") or created: |         if hasattr(instance, "_previous_state") or created: | ||||||
|             prev_state = getattr(instance, "_previous_state", {}) |             prev_state = getattr(instance, "_previous_state", {}) | ||||||
|             if created: |  | ||||||
|                 prev_state = {} |  | ||||||
|             # Get current state |             # Get current state | ||||||
|             new_state = self.serialize_simple(instance) |             new_state = self.serialize_simple(instance) | ||||||
|             diff = self.diff(prev_state, new_state) |             diff = self.diff(prev_state, new_state) | ||||||
|             thread_kwargs["diff"] = diff |             thread_kwargs["diff"] = diff | ||||||
|         return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_) |             if not created: | ||||||
|  |                 ignored_field_sets = getattr(instance._meta, "authentik_signals_ignored_fields", []) | ||||||
|     def m2m_changed_handler(  # noqa: PLR0913 |                 for field_set in ignored_field_sets: | ||||||
|         self, |                     if set(diff.keys()) == set(field_set): | ||||||
|         request: HttpRequest, |                         return None | ||||||
|         sender, |         return super().post_save_handler( | ||||||
|         instance: Model, |             user, request, sender, instance, created, thread_kwargs, **_ | ||||||
|         action: str, |         ) | ||||||
|         pk_set: set[Any], |  | ||||||
|         thread_kwargs: dict | None = None, |  | ||||||
|         **_, |  | ||||||
|     ): |  | ||||||
|         thread_kwargs = {} |  | ||||||
|         m2m_field = None |  | ||||||
|         # For the audit log we don't care about `pre_` or `post_` so we trim that part off |  | ||||||
|         _, _, action_direction = action.partition("_") |  | ||||||
|         # resolve the "through" model to an actual field |  | ||||||
|         for field in instance._meta.get_fields(): |  | ||||||
|             if not isinstance(field, ManyToManyRel): |  | ||||||
|                 continue |  | ||||||
|             if field.through == sender: |  | ||||||
|                 m2m_field = field |  | ||||||
|         if m2m_field: |  | ||||||
|             # If we're clearing we just set the "flag" to True |  | ||||||
|             if action_direction == "clear": |  | ||||||
|                 pk_set = True |  | ||||||
|             thread_kwargs["diff"] = {m2m_field.related_name: {action_direction: pk_set}} |  | ||||||
|         return super().m2m_changed_handler(request, sender, instance, action, thread_kwargs) |  | ||||||
|  | |||||||
| @ -1,210 +0,0 @@ | |||||||
| from unittest.mock import PropertyMock, patch |  | ||||||
|  |  | ||||||
| from django.apps import apps |  | ||||||
| from django.conf import settings |  | ||||||
| from django.urls import reverse |  | ||||||
| from rest_framework.test import APITestCase |  | ||||||
|  |  | ||||||
| from authentik.core.models import Group, User |  | ||||||
| from authentik.core.tests.utils import create_test_admin_user |  | ||||||
| from authentik.events.models import Event, EventAction |  | ||||||
| from authentik.events.utils import sanitize_item |  | ||||||
| from authentik.lib.generators import generate_id |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestEnterpriseAudit(APITestCase): |  | ||||||
|     """Test audit middleware""" |  | ||||||
|  |  | ||||||
|     def setUp(self) -> None: |  | ||||||
|         self.user = create_test_admin_user() |  | ||||||
|  |  | ||||||
|     def test_import(self): |  | ||||||
|         """Ensure middleware is imported when app.ready is called""" |  | ||||||
|         # Revert import swap |  | ||||||
|         orig_import = "authentik.events.middleware.AuditMiddleware" |  | ||||||
|         new_import = "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware" |  | ||||||
|         settings.MIDDLEWARE = [orig_import if x == new_import else x for x in settings.MIDDLEWARE] |  | ||||||
|         # Re-call ready() |  | ||||||
|         apps.get_app_config("authentik_enterprise_audit").ready() |  | ||||||
|         self.assertIn( |  | ||||||
|             "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware", settings.MIDDLEWARE |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     @patch( |  | ||||||
|         "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled", |  | ||||||
|         PropertyMock(return_value=True), |  | ||||||
|     ) |  | ||||||
|     def test_create(self): |  | ||||||
|         """Test create audit log""" |  | ||||||
|         self.client.force_login(self.user) |  | ||||||
|         username = generate_id() |  | ||||||
|         response = self.client.post( |  | ||||||
|             reverse("authentik_api:user-list"), |  | ||||||
|             data={"name": generate_id(), "username": username, "groups": [], "path": "foo"}, |  | ||||||
|         ) |  | ||||||
|         user = User.objects.get(username=username) |  | ||||||
|         self.assertEqual(response.status_code, 201) |  | ||||||
|         events = Event.objects.filter( |  | ||||||
|             action=EventAction.MODEL_CREATED, |  | ||||||
|             context__model__model_name="user", |  | ||||||
|             context__model__app="authentik_core", |  | ||||||
|             context__model__pk=user.pk, |  | ||||||
|         ) |  | ||||||
|         event = events.first() |  | ||||||
|         self.assertIsNotNone(event) |  | ||||||
|         self.assertIsNotNone(event.context["diff"]) |  | ||||||
|         diff = event.context["diff"] |  | ||||||
|         self.assertEqual( |  | ||||||
|             diff, |  | ||||||
|             { |  | ||||||
|                 "name": { |  | ||||||
|                     "new_value": user.name, |  | ||||||
|                     "previous_value": None, |  | ||||||
|                 }, |  | ||||||
|                 "path": {"new_value": "foo", "previous_value": None}, |  | ||||||
|                 "type": {"new_value": "internal", "previous_value": None}, |  | ||||||
|                 "uuid": { |  | ||||||
|                     "new_value": user.uuid.hex, |  | ||||||
|                     "previous_value": None, |  | ||||||
|                 }, |  | ||||||
|                 "email": {"new_value": "", "previous_value": None}, |  | ||||||
|                 "username": { |  | ||||||
|                     "new_value": user.username, |  | ||||||
|                     "previous_value": None, |  | ||||||
|                 }, |  | ||||||
|                 "is_active": {"new_value": True, "previous_value": None}, |  | ||||||
|                 "attributes": {"new_value": {}, "previous_value": None}, |  | ||||||
|                 "date_joined": { |  | ||||||
|                     "new_value": sanitize_item(user.date_joined), |  | ||||||
|                     "previous_value": None, |  | ||||||
|                 }, |  | ||||||
|                 "first_name": {"new_value": "", "previous_value": None}, |  | ||||||
|                 "id": {"new_value": user.pk, "previous_value": None}, |  | ||||||
|                 "last_name": {"new_value": "", "previous_value": None}, |  | ||||||
|                 "password": {"new_value": "********************", "previous_value": None}, |  | ||||||
|                 "password_change_date": { |  | ||||||
|                     "new_value": sanitize_item(user.password_change_date), |  | ||||||
|                     "previous_value": None, |  | ||||||
|                 }, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     @patch( |  | ||||||
|         "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled", |  | ||||||
|         PropertyMock(return_value=True), |  | ||||||
|     ) |  | ||||||
|     def test_update(self): |  | ||||||
|         """Test update audit log""" |  | ||||||
|         self.client.force_login(self.user) |  | ||||||
|         user = create_test_admin_user() |  | ||||||
|         current_name = user.name |  | ||||||
|         new_name = generate_id() |  | ||||||
|         response = self.client.patch( |  | ||||||
|             reverse("authentik_api:user-detail", kwargs={"pk": user.id}), |  | ||||||
|             data={"name": new_name}, |  | ||||||
|         ) |  | ||||||
|         user.refresh_from_db() |  | ||||||
|         self.assertEqual(response.status_code, 200) |  | ||||||
|         events = Event.objects.filter( |  | ||||||
|             action=EventAction.MODEL_UPDATED, |  | ||||||
|             context__model__model_name="user", |  | ||||||
|             context__model__app="authentik_core", |  | ||||||
|             context__model__pk=user.pk, |  | ||||||
|         ) |  | ||||||
|         event = events.first() |  | ||||||
|         self.assertIsNotNone(event) |  | ||||||
|         self.assertIsNotNone(event.context["diff"]) |  | ||||||
|         diff = event.context["diff"] |  | ||||||
|         self.assertEqual( |  | ||||||
|             diff, |  | ||||||
|             { |  | ||||||
|                 "name": { |  | ||||||
|                     "new_value": new_name, |  | ||||||
|                     "previous_value": current_name, |  | ||||||
|                 }, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     @patch( |  | ||||||
|         "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled", |  | ||||||
|         PropertyMock(return_value=True), |  | ||||||
|     ) |  | ||||||
|     def test_delete(self): |  | ||||||
|         """Test delete audit log""" |  | ||||||
|         self.client.force_login(self.user) |  | ||||||
|         user = create_test_admin_user() |  | ||||||
|         response = self.client.delete( |  | ||||||
|             reverse("authentik_api:user-detail", kwargs={"pk": user.id}), |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 204) |  | ||||||
|         events = Event.objects.filter( |  | ||||||
|             action=EventAction.MODEL_DELETED, |  | ||||||
|             context__model__model_name="user", |  | ||||||
|             context__model__app="authentik_core", |  | ||||||
|             context__model__pk=user.pk, |  | ||||||
|         ) |  | ||||||
|         event = events.first() |  | ||||||
|         self.assertIsNotNone(event) |  | ||||||
|         self.assertNotIn("diff", event.context) |  | ||||||
|  |  | ||||||
|     @patch( |  | ||||||
|         "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled", |  | ||||||
|         PropertyMock(return_value=True), |  | ||||||
|     ) |  | ||||||
|     def test_m2m_add(self): |  | ||||||
|         """Test m2m add audit log""" |  | ||||||
|         self.client.force_login(self.user) |  | ||||||
|         user = create_test_admin_user() |  | ||||||
|         group = Group.objects.create(name=generate_id()) |  | ||||||
|         response = self.client.post( |  | ||||||
|             reverse("authentik_api:group-add-user", kwargs={"pk": group.group_uuid}), |  | ||||||
|             data={ |  | ||||||
|                 "pk": user.pk, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 204) |  | ||||||
|         events = Event.objects.filter( |  | ||||||
|             action=EventAction.MODEL_UPDATED, |  | ||||||
|             context__model__model_name="group", |  | ||||||
|             context__model__app="authentik_core", |  | ||||||
|             context__model__pk=group.pk.hex, |  | ||||||
|         ) |  | ||||||
|         event = events.first() |  | ||||||
|         self.assertIsNotNone(event) |  | ||||||
|         self.assertIsNotNone(event.context["diff"]) |  | ||||||
|         diff = event.context["diff"] |  | ||||||
|         self.assertEqual( |  | ||||||
|             diff, |  | ||||||
|             {"users": {"add": [user.pk]}}, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     @patch( |  | ||||||
|         "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled", |  | ||||||
|         PropertyMock(return_value=True), |  | ||||||
|     ) |  | ||||||
|     def test_m2m_remove(self): |  | ||||||
|         """Test m2m remove audit log""" |  | ||||||
|         self.client.force_login(self.user) |  | ||||||
|         user = create_test_admin_user() |  | ||||||
|         group = Group.objects.create(name=generate_id()) |  | ||||||
|         response = self.client.post( |  | ||||||
|             reverse("authentik_api:group-remove-user", kwargs={"pk": group.group_uuid}), |  | ||||||
|             data={ |  | ||||||
|                 "pk": user.pk, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 204) |  | ||||||
|         events = Event.objects.filter( |  | ||||||
|             action=EventAction.MODEL_UPDATED, |  | ||||||
|             context__model__model_name="group", |  | ||||||
|             context__model__app="authentik_core", |  | ||||||
|             context__model__pk=group.pk.hex, |  | ||||||
|         ) |  | ||||||
|         event = events.first() |  | ||||||
|         self.assertIsNotNone(event) |  | ||||||
|         self.assertIsNotNone(event.context["diff"]) |  | ||||||
|         diff = event.context["diff"] |  | ||||||
|         self.assertEqual( |  | ||||||
|             diff, |  | ||||||
|             {"users": {"remove": [user.pk]}}, |  | ||||||
|         ) |  | ||||||
| @ -1,18 +0,0 @@ | |||||||
| # Generated by Django 5.0.2 on 2024-02-29 10:15 |  | ||||||
|  |  | ||||||
| from django.db import migrations, models |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Migration(migrations.Migration): |  | ||||||
|  |  | ||||||
|     dependencies = [ |  | ||||||
|         ("authentik_providers_rac", "0001_squashed_0003_alter_connectiontoken_options_and_more"), |  | ||||||
|     ] |  | ||||||
|  |  | ||||||
|     operations = [ |  | ||||||
|         migrations.AlterField( |  | ||||||
|             model_name="connectiontoken", |  | ||||||
|             name="expires", |  | ||||||
|             field=models.DateTimeField(default=None, null=True), |  | ||||||
|         ), |  | ||||||
|     ] |  | ||||||
| @ -201,7 +201,10 @@ class ConnectionToken(ExpiringModel): | |||||||
|         return settings |         return settings | ||||||
|  |  | ||||||
|     def __str__(self): |     def __str__(self): | ||||||
|         return f"RAC Connection token {self.session_id} to {self.provider_id}/{self.endpoint_id}" |         return ( | ||||||
|  |             f"RAC Connection token {self.session.user} to " | ||||||
|  |             f"{self.endpoint.provider.name}/{self.endpoint.name}" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         verbose_name = _("RAC Connection token") |         verbose_name = _("RAC Connection token") | ||||||
|  | |||||||
| @ -12,6 +12,7 @@ from rest_framework.fields import ( | |||||||
|     ChoiceField, |     ChoiceField, | ||||||
|     DateTimeField, |     DateTimeField, | ||||||
|     FloatField, |     FloatField, | ||||||
|  |     ListField, | ||||||
|     SerializerMethodField, |     SerializerMethodField, | ||||||
| ) | ) | ||||||
| from rest_framework.request import Request | from rest_framework.request import Request | ||||||
| @ -20,7 +21,6 @@ from rest_framework.serializers import ModelSerializer | |||||||
| from rest_framework.viewsets import ReadOnlyModelViewSet | from rest_framework.viewsets import ReadOnlyModelViewSet | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.events.logs import LogEventSerializer |  | ||||||
| from authentik.events.models import SystemTask, TaskStatus | from authentik.events.models import SystemTask, TaskStatus | ||||||
| from authentik.rbac.decorators import permission_required | from authentik.rbac.decorators import permission_required | ||||||
|  |  | ||||||
| @ -39,7 +39,7 @@ class SystemTaskSerializer(ModelSerializer): | |||||||
|     duration = FloatField(read_only=True) |     duration = FloatField(read_only=True) | ||||||
|  |  | ||||||
|     status = ChoiceField(choices=[(x.value, x.name) for x in TaskStatus]) |     status = ChoiceField(choices=[(x.value, x.name) for x in TaskStatus]) | ||||||
|     messages = LogEventSerializer(many=True) |     messages = ListField(child=CharField()) | ||||||
|  |  | ||||||
|     def get_full_name(self, instance: SystemTask) -> str: |     def get_full_name(self, instance: SystemTask) -> str: | ||||||
|         """Get full name with UID""" |         """Get full name with UID""" | ||||||
|  | |||||||
| @ -1,82 +0,0 @@ | |||||||
| from collections.abc import Generator |  | ||||||
| from contextlib import contextmanager |  | ||||||
| from dataclasses import dataclass, field |  | ||||||
| from datetime import datetime |  | ||||||
| from typing import Any |  | ||||||
|  |  | ||||||
| from django.utils.timezone import now |  | ||||||
| from rest_framework.fields import CharField, ChoiceField, DateTimeField, DictField |  | ||||||
| from structlog import configure, get_config |  | ||||||
| from structlog.stdlib import NAME_TO_LEVEL, ProcessorFormatter |  | ||||||
| from structlog.testing import LogCapture |  | ||||||
| from structlog.types import EventDict |  | ||||||
|  |  | ||||||
| from authentik.core.api.utils import PassiveSerializer |  | ||||||
| from authentik.events.utils import sanitize_dict |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @dataclass() |  | ||||||
| class LogEvent: |  | ||||||
|  |  | ||||||
|     event: str |  | ||||||
|     log_level: str |  | ||||||
|     logger: str |  | ||||||
|     timestamp: datetime = field(default_factory=now) |  | ||||||
|     attributes: dict[str, Any] = field(default_factory=dict) |  | ||||||
|  |  | ||||||
|     @staticmethod |  | ||||||
|     def from_event_dict(item: EventDict) -> "LogEvent": |  | ||||||
|         event = item.pop("event") |  | ||||||
|         log_level = item.pop("level").lower() |  | ||||||
|         timestamp = datetime.fromisoformat(item.pop("timestamp")) |  | ||||||
|         item.pop("pid", None) |  | ||||||
|         # Sometimes log entries have both `level` and `log_level` set, but `level` is always set |  | ||||||
|         item.pop("log_level", None) |  | ||||||
|         return LogEvent( |  | ||||||
|             event, log_level, item.pop("logger"), timestamp, attributes=sanitize_dict(item) |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class LogEventSerializer(PassiveSerializer): |  | ||||||
|     """Single log message with all context logged.""" |  | ||||||
|  |  | ||||||
|     timestamp = DateTimeField() |  | ||||||
|     log_level = ChoiceField(choices=tuple((x, x) for x in NAME_TO_LEVEL.keys())) |  | ||||||
|     logger = CharField() |  | ||||||
|     event = CharField() |  | ||||||
|     attributes = DictField() |  | ||||||
|  |  | ||||||
|     # TODO(2024.6?): This is a migration helper to return a correct API response for logs that |  | ||||||
|     # have been saved in an older format (mostly just list[str] with just the messages) |  | ||||||
|     def to_representation(self, instance): |  | ||||||
|         if isinstance(instance, str): |  | ||||||
|             instance = LogEvent(instance, "", "") |  | ||||||
|         elif isinstance(instance, list): |  | ||||||
|             instance = [LogEvent(x, "", "") for x in instance] |  | ||||||
|         return super().to_representation(instance) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @contextmanager |  | ||||||
| def capture_logs(log_default_output=True) -> Generator[list[LogEvent], None, None]: |  | ||||||
|     """Capture log entries created""" |  | ||||||
|     logs = [] |  | ||||||
|     cap = LogCapture() |  | ||||||
|     # Modify `_Configuration.default_processors` set via `configure` but always |  | ||||||
|     # keep the list instance intact to not break references held by bound |  | ||||||
|     # loggers. |  | ||||||
|     processors: list = get_config()["processors"] |  | ||||||
|     old_processors = processors.copy() |  | ||||||
|     try: |  | ||||||
|         # clear processors list and use LogCapture for testing |  | ||||||
|         if ProcessorFormatter.wrap_for_formatter in processors: |  | ||||||
|             processors.remove(ProcessorFormatter.wrap_for_formatter) |  | ||||||
|         processors.append(cap) |  | ||||||
|         configure(processors=processors) |  | ||||||
|         yield logs |  | ||||||
|         for raw_log in cap.entries: |  | ||||||
|             logs.append(LogEvent.from_event_dict(raw_log)) |  | ||||||
|     finally: |  | ||||||
|         # remove LogCapture and restore original processors |  | ||||||
|         processors.clear() |  | ||||||
|         processors.extend(old_processors) |  | ||||||
|         configure(processors=processors) |  | ||||||
| @ -1,8 +1,6 @@ | |||||||
| """Events middleware""" | """Events middleware""" | ||||||
|  |  | ||||||
| from collections.abc import Callable | from collections.abc import Callable | ||||||
| from contextlib import contextmanager |  | ||||||
| from contextvars import ContextVar |  | ||||||
| from functools import partial | from functools import partial | ||||||
| from threading import Thread | from threading import Thread | ||||||
| from typing import Any | from typing import Any | ||||||
| @ -33,9 +31,6 @@ IGNORED_MODELS = tuple( | |||||||
|     ) |     ) | ||||||
| ) | ) | ||||||
|  |  | ||||||
| _CTX_OVERWRITE_USER = ContextVar[User | None]("authentik_events_log_overwrite_user", default=None) |  | ||||||
| _CTX_IGNORE = ContextVar[bool]("authentik_events_log_ignore", default=False) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def should_log_model(model: Model) -> bool: | def should_log_model(model: Model) -> bool: | ||||||
|     """Return true if operation on `model` should be logged""" |     """Return true if operation on `model` should be logged""" | ||||||
| @ -49,28 +44,6 @@ def should_log_m2m(model: Model) -> bool: | |||||||
|     return False |     return False | ||||||
|  |  | ||||||
|  |  | ||||||
| @contextmanager |  | ||||||
| def audit_overwrite_user(user: User): |  | ||||||
|     """Overwrite user being logged for model AuditMiddleware. Commonly used |  | ||||||
|     for example in flows where a pending user is given, but the request is not authenticated yet""" |  | ||||||
|     _CTX_OVERWRITE_USER.set(user) |  | ||||||
|     try: |  | ||||||
|         yield |  | ||||||
|     finally: |  | ||||||
|         _CTX_OVERWRITE_USER.set(None) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @contextmanager |  | ||||||
| def audit_ignore(): |  | ||||||
|     """Ignore model operations in the block. Useful for objects which need to be modified |  | ||||||
|     but are not excluded (e.g. WebAuthn devices)""" |  | ||||||
|     _CTX_IGNORE.set(True) |  | ||||||
|     try: |  | ||||||
|         yield |  | ||||||
|     finally: |  | ||||||
|         _CTX_IGNORE.set(False) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class EventNewThread(Thread): | class EventNewThread(Thread): | ||||||
|     """Create Event in background thread""" |     """Create Event in background thread""" | ||||||
|  |  | ||||||
| @ -110,32 +83,26 @@ class AuditMiddleware: | |||||||
|  |  | ||||||
|         self.anonymous_user = get_anonymous_user() |         self.anonymous_user = get_anonymous_user() | ||||||
|  |  | ||||||
|     def get_user(self, request: HttpRequest) -> User: |  | ||||||
|         user = _CTX_OVERWRITE_USER.get() |  | ||||||
|         if user: |  | ||||||
|             return user |  | ||||||
|         user = getattr(request, "user", self.anonymous_user) |  | ||||||
|         if not user.is_authenticated: |  | ||||||
|             self._ensure_fallback_user() |  | ||||||
|             return self.anonymous_user |  | ||||||
|         return user |  | ||||||
|  |  | ||||||
|     def connect(self, request: HttpRequest): |     def connect(self, request: HttpRequest): | ||||||
|         """Connect signal for automatic logging""" |         """Connect signal for automatic logging""" | ||||||
|  |         self._ensure_fallback_user() | ||||||
|  |         user = getattr(request, "user", self.anonymous_user) | ||||||
|  |         if not user.is_authenticated: | ||||||
|  |             user = self.anonymous_user | ||||||
|         if not hasattr(request, "request_id"): |         if not hasattr(request, "request_id"): | ||||||
|             return |             return | ||||||
|         post_save.connect( |         post_save.connect( | ||||||
|             partial(self.post_save_handler, request=request), |             partial(self.post_save_handler, user=user, request=request), | ||||||
|             dispatch_uid=request.request_id, |             dispatch_uid=request.request_id, | ||||||
|             weak=False, |             weak=False, | ||||||
|         ) |         ) | ||||||
|         pre_delete.connect( |         pre_delete.connect( | ||||||
|             partial(self.pre_delete_handler, request=request), |             partial(self.pre_delete_handler, user=user, request=request), | ||||||
|             dispatch_uid=request.request_id, |             dispatch_uid=request.request_id, | ||||||
|             weak=False, |             weak=False, | ||||||
|         ) |         ) | ||||||
|         m2m_changed.connect( |         m2m_changed.connect( | ||||||
|             partial(self.m2m_changed_handler, request=request), |             partial(self.m2m_changed_handler, user=user, request=request), | ||||||
|             dispatch_uid=request.request_id, |             dispatch_uid=request.request_id, | ||||||
|             weak=False, |             weak=False, | ||||||
|         ) |         ) | ||||||
| @ -180,6 +147,7 @@ class AuditMiddleware: | |||||||
|  |  | ||||||
|     def post_save_handler( |     def post_save_handler( | ||||||
|         self, |         self, | ||||||
|  |         user: User, | ||||||
|         request: HttpRequest, |         request: HttpRequest, | ||||||
|         sender, |         sender, | ||||||
|         instance: Model, |         instance: Model, | ||||||
| @ -190,22 +158,16 @@ class AuditMiddleware: | |||||||
|         """Signal handler for all object's post_save""" |         """Signal handler for all object's post_save""" | ||||||
|         if not should_log_model(instance): |         if not should_log_model(instance): | ||||||
|             return |             return | ||||||
|         if _CTX_IGNORE.get(): |  | ||||||
|             return |  | ||||||
|         user = self.get_user(request) |  | ||||||
|  |  | ||||||
|         action = EventAction.MODEL_CREATED if created else EventAction.MODEL_UPDATED |         action = EventAction.MODEL_CREATED if created else EventAction.MODEL_UPDATED | ||||||
|         thread = EventNewThread(action, request, user=user, model=model_to_dict(instance)) |         thread = EventNewThread(action, request, user=user, model=model_to_dict(instance)) | ||||||
|         thread.kwargs.update(thread_kwargs or {}) |         thread.kwargs.update(thread_kwargs or {}) | ||||||
|         thread.run() |         thread.run() | ||||||
|  |  | ||||||
|     def pre_delete_handler(self, request: HttpRequest, sender, instance: Model, **_): |     def pre_delete_handler(self, user: User, request: HttpRequest, sender, instance: Model, **_): | ||||||
|         """Signal handler for all object's pre_delete""" |         """Signal handler for all object's pre_delete""" | ||||||
|         if not should_log_model(instance):  # pragma: no cover |         if not should_log_model(instance):  # pragma: no cover | ||||||
|             return |             return | ||||||
|         if _CTX_IGNORE.get(): |  | ||||||
|             return |  | ||||||
|         user = self.get_user(request) |  | ||||||
|  |  | ||||||
|         EventNewThread( |         EventNewThread( | ||||||
|             EventAction.MODEL_DELETED, |             EventAction.MODEL_DELETED, | ||||||
| @ -215,27 +177,17 @@ class AuditMiddleware: | |||||||
|         ).run() |         ).run() | ||||||
|  |  | ||||||
|     def m2m_changed_handler( |     def m2m_changed_handler( | ||||||
|         self, |         self, user: User, request: HttpRequest, sender, instance: Model, action: str, **_ | ||||||
|         request: HttpRequest, |  | ||||||
|         sender, |  | ||||||
|         instance: Model, |  | ||||||
|         action: str, |  | ||||||
|         thread_kwargs: dict | None = None, |  | ||||||
|         **_, |  | ||||||
|     ): |     ): | ||||||
|         """Signal handler for all object's m2m_changed""" |         """Signal handler for all object's m2m_changed""" | ||||||
|         if action not in ["pre_add", "pre_remove", "post_clear"]: |         if action not in ["pre_add", "pre_remove", "post_clear"]: | ||||||
|             return |             return | ||||||
|         if not should_log_m2m(instance): |         if not should_log_m2m(instance): | ||||||
|             return |             return | ||||||
|         if _CTX_IGNORE.get(): |  | ||||||
|             return |  | ||||||
|         user = self.get_user(request) |  | ||||||
|  |  | ||||||
|         EventNewThread( |         EventNewThread( | ||||||
|             EventAction.MODEL_UPDATED, |             EventAction.MODEL_UPDATED, | ||||||
|             request, |             request, | ||||||
|             user=user, |             user=user, | ||||||
|             model=model_to_dict(instance), |             model=model_to_dict(instance), | ||||||
|             **thread_kwargs, |  | ||||||
|         ).run() |         ).run() | ||||||
|  | |||||||
| @ -1,21 +0,0 @@ | |||||||
| # Generated by Django 5.0.2 on 2024-02-29 10:15 |  | ||||||
|  |  | ||||||
| from django.db import migrations, models |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Migration(migrations.Migration): |  | ||||||
|  |  | ||||||
|     dependencies = [ |  | ||||||
|         ( |  | ||||||
|             "authentik_events", |  | ||||||
|             "0004_systemtask_squashed_0005_remove_systemtask_finish_timestamp_and_more", |  | ||||||
|         ), |  | ||||||
|     ] |  | ||||||
|  |  | ||||||
|     operations = [ |  | ||||||
|         migrations.AlterField( |  | ||||||
|             model_name="systemtask", |  | ||||||
|             name="expires", |  | ||||||
|             field=models.DateTimeField(default=None, null=True), |  | ||||||
|         ), |  | ||||||
|     ] |  | ||||||
| @ -1,39 +0,0 @@ | |||||||
| # Generated by Django 5.0.4 on 2024-04-15 16:17 |  | ||||||
|  |  | ||||||
| from django.db import migrations, models |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Migration(migrations.Migration): |  | ||||||
|  |  | ||||||
|     dependencies = [ |  | ||||||
|         ("authentik_events", "0006_alter_systemtask_expires"), |  | ||||||
|     ] |  | ||||||
|  |  | ||||||
|     operations = [ |  | ||||||
|         migrations.AddIndex( |  | ||||||
|             model_name="event", |  | ||||||
|             index=models.Index(fields=["action"], name="authentik_e_action_9a9dd9_idx"), |  | ||||||
|         ), |  | ||||||
|         migrations.AddIndex( |  | ||||||
|             model_name="event", |  | ||||||
|             index=models.Index(fields=["user"], name="authentik_e_user_1be48d_idx"), |  | ||||||
|         ), |  | ||||||
|         migrations.AddIndex( |  | ||||||
|             model_name="event", |  | ||||||
|             index=models.Index(fields=["app"], name="authentik_e_app_6a05ce_idx"), |  | ||||||
|         ), |  | ||||||
|         migrations.AddIndex( |  | ||||||
|             model_name="event", |  | ||||||
|             index=models.Index(fields=["created"], name="authentik_e_created_6f0834_idx"), |  | ||||||
|         ), |  | ||||||
|         migrations.AddIndex( |  | ||||||
|             model_name="event", |  | ||||||
|             index=models.Index(fields=["client_ip"], name="authentik_e_client__51f4dd_idx"), |  | ||||||
|         ), |  | ||||||
|         migrations.AddIndex( |  | ||||||
|             model_name="event", |  | ||||||
|             index=models.Index( |  | ||||||
|                 models.F("context__authorized_application"), name="authentik_e_ctx_app__idx" |  | ||||||
|             ), |  | ||||||
|         ), |  | ||||||
|     ] |  | ||||||
| @ -305,16 +305,6 @@ class Event(SerializerModel, ExpiringModel): | |||||||
|     class Meta: |     class Meta: | ||||||
|         verbose_name = _("Event") |         verbose_name = _("Event") | ||||||
|         verbose_name_plural = _("Events") |         verbose_name_plural = _("Events") | ||||||
|         indexes = [ |  | ||||||
|             models.Index(fields=["action"]), |  | ||||||
|             models.Index(fields=["user"]), |  | ||||||
|             models.Index(fields=["app"]), |  | ||||||
|             models.Index(fields=["created"]), |  | ||||||
|             models.Index(fields=["client_ip"]), |  | ||||||
|             models.Index( |  | ||||||
|                 models.F("context__authorized_application"), name="authentik_e_ctx_app__idx" |  | ||||||
|             ), |  | ||||||
|         ] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class TransportMode(models.TextChoices): | class TransportMode(models.TextChoices): | ||||||
| @ -556,7 +546,7 @@ class Notification(SerializerModel): | |||||||
|             if len(self.body) > NOTIFICATION_SUMMARY_LENGTH |             if len(self.body) > NOTIFICATION_SUMMARY_LENGTH | ||||||
|             else self.body |             else self.body | ||||||
|         ) |         ) | ||||||
|         return f"Notification for user {self.user_id}: {body_trunc}" |         return f"Notification for user {self.user}: {body_trunc}" | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         verbose_name = _("Notification") |         verbose_name = _("Notification") | ||||||
|  | |||||||
| @ -9,7 +9,6 @@ from django.utils.translation import gettext_lazy as _ | |||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
| from tenant_schemas_celery.task import TenantTask | from tenant_schemas_celery.task import TenantTask | ||||||
|  |  | ||||||
| from authentik.events.logs import LogEvent |  | ||||||
| from authentik.events.models import Event, EventAction, TaskStatus | from authentik.events.models import Event, EventAction, TaskStatus | ||||||
| from authentik.events.models import SystemTask as DBSystemTask | from authentik.events.models import SystemTask as DBSystemTask | ||||||
| from authentik.events.utils import sanitize_item | from authentik.events.utils import sanitize_item | ||||||
| @ -25,7 +24,7 @@ class SystemTask(TenantTask): | |||||||
|     save_on_success: bool |     save_on_success: bool | ||||||
|  |  | ||||||
|     _status: TaskStatus |     _status: TaskStatus | ||||||
|     _messages: list[LogEvent] |     _messages: list[str] | ||||||
|  |  | ||||||
|     _uid: str | None |     _uid: str | None | ||||||
|     # Precise start time from perf_counter |     # Precise start time from perf_counter | ||||||
| @ -45,20 +44,15 @@ class SystemTask(TenantTask): | |||||||
|         """Set UID, so in the case of an unexpected error its saved correctly""" |         """Set UID, so in the case of an unexpected error its saved correctly""" | ||||||
|         self._uid = uid |         self._uid = uid | ||||||
|  |  | ||||||
|     def set_status(self, status: TaskStatus, *messages: LogEvent): |     def set_status(self, status: TaskStatus, *messages: str): | ||||||
|         """Set result for current run, will overwrite previous result.""" |         """Set result for current run, will overwrite previous result.""" | ||||||
|         self._status = status |         self._status = status | ||||||
|         self._messages = list(messages) |         self._messages = messages | ||||||
|         for idx, msg in enumerate(self._messages): |  | ||||||
|             if not isinstance(msg, LogEvent): |  | ||||||
|                 self._messages[idx] = LogEvent(msg, logger=self.__name__, log_level="info") |  | ||||||
|  |  | ||||||
|     def set_error(self, exception: Exception): |     def set_error(self, exception: Exception): | ||||||
|         """Set result to error and save exception""" |         """Set result to error and save exception""" | ||||||
|         self._status = TaskStatus.ERROR |         self._status = TaskStatus.ERROR | ||||||
|         self._messages = [ |         self._messages = [exception_to_string(exception)] | ||||||
|             LogEvent(exception_to_string(exception), logger=self.__name__, log_level="error") |  | ||||||
|         ] |  | ||||||
|  |  | ||||||
|     def before_start(self, task_id, args, kwargs): |     def before_start(self, task_id, args, kwargs): | ||||||
|         self._start_precise = perf_counter() |         self._start_precise = perf_counter() | ||||||
| @ -104,7 +98,8 @@ class SystemTask(TenantTask): | |||||||
|     def on_failure(self, exc, task_id, args, kwargs, einfo): |     def on_failure(self, exc, task_id, args, kwargs, einfo): | ||||||
|         super().on_failure(exc, task_id, args, kwargs, einfo=einfo) |         super().on_failure(exc, task_id, args, kwargs, einfo=einfo) | ||||||
|         if not self._status: |         if not self._status: | ||||||
|             self.set_error(exc) |             self._status = TaskStatus.ERROR | ||||||
|  |             self._messages = exception_to_string(exc) | ||||||
|         DBSystemTask.objects.update_or_create( |         DBSystemTask.objects.update_or_create( | ||||||
|             name=self.__name__, |             name=self.__name__, | ||||||
|             uid=self._uid, |             uid=self._uid, | ||||||
| @ -119,7 +114,7 @@ class SystemTask(TenantTask): | |||||||
|                 "task_call_kwargs": sanitize_item(kwargs), |                 "task_call_kwargs": sanitize_item(kwargs), | ||||||
|                 "status": self._status, |                 "status": self._status, | ||||||
|                 "messages": sanitize_item(self._messages), |                 "messages": sanitize_item(self._messages), | ||||||
|                 "expires": now() + timedelta(hours=self.result_timeout_hours + 3), |                 "expires": now() + timedelta(hours=self.result_timeout_hours), | ||||||
|                 "expiring": True, |                 "expiring": True, | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|  | |||||||
| @ -3,11 +3,9 @@ | |||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
| from rest_framework.test import APITestCase | from rest_framework.test import APITestCase | ||||||
|  |  | ||||||
| from authentik.core.models import Application, Token, TokenIntents | from authentik.core.models import Application | ||||||
| from authentik.core.tests.utils import create_test_admin_user | from authentik.core.tests.utils import create_test_admin_user | ||||||
| from authentik.events.middleware import audit_ignore, audit_overwrite_user |  | ||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
| from authentik.lib.generators import generate_id |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestEventsMiddleware(APITestCase): | class TestEventsMiddleware(APITestCase): | ||||||
| @ -17,100 +15,35 @@ class TestEventsMiddleware(APITestCase): | |||||||
|         super().setUp() |         super().setUp() | ||||||
|         self.user = create_test_admin_user() |         self.user = create_test_admin_user() | ||||||
|         self.client.force_login(self.user) |         self.client.force_login(self.user) | ||||||
|         Event.objects.all().delete() |  | ||||||
|  |  | ||||||
|     def test_create(self): |     def test_create(self): | ||||||
|         """Test model creation event""" |         """Test model creation event""" | ||||||
|         uid = generate_id() |  | ||||||
|         self.client.post( |         self.client.post( | ||||||
|             reverse("authentik_api:application-list"), |             reverse("authentik_api:application-list"), | ||||||
|             data={"name": uid, "slug": uid}, |             data={"name": "test-create", "slug": "test-create"}, | ||||||
|  |         ) | ||||||
|  |         self.assertTrue(Application.objects.filter(name="test-create").exists()) | ||||||
|  |         self.assertTrue( | ||||||
|  |             Event.objects.filter( | ||||||
|  |                 action=EventAction.MODEL_CREATED, | ||||||
|  |                 context__model__model_name="application", | ||||||
|  |                 context__model__app="authentik_core", | ||||||
|  |                 context__model__name="test-create", | ||||||
|  |             ).exists() | ||||||
|         ) |         ) | ||||||
|         self.assertTrue(Application.objects.filter(name=uid).exists()) |  | ||||||
|         event = Event.objects.filter( |  | ||||||
|             action=EventAction.MODEL_CREATED, |  | ||||||
|             context__model__model_name="application", |  | ||||||
|             context__model__app="authentik_core", |  | ||||||
|             context__model__name=uid, |  | ||||||
|         ).first() |  | ||||||
|         self.assertIsNotNone(event) |  | ||||||
|  |  | ||||||
|     def test_delete(self): |     def test_delete(self): | ||||||
|         """Test model creation event""" |         """Test model creation event""" | ||||||
|         uid = generate_id() |         Application.objects.create(name="test-delete", slug="test-delete") | ||||||
|         Application.objects.create(name=uid, slug=uid) |         self.client.delete( | ||||||
|         self.client.delete(reverse("authentik_api:application-detail", kwargs={"slug": uid})) |             reverse("authentik_api:application-detail", kwargs={"slug": "test-delete"}) | ||||||
|  |         ) | ||||||
|         self.assertFalse(Application.objects.filter(name="test").exists()) |         self.assertFalse(Application.objects.filter(name="test").exists()) | ||||||
|         self.assertTrue( |         self.assertTrue( | ||||||
|             Event.objects.filter( |             Event.objects.filter( | ||||||
|                 action=EventAction.MODEL_DELETED, |                 action=EventAction.MODEL_DELETED, | ||||||
|                 context__model__model_name="application", |                 context__model__model_name="application", | ||||||
|                 context__model__app="authentik_core", |                 context__model__app="authentik_core", | ||||||
|                 context__model__name=uid, |                 context__model__name="test-delete", | ||||||
|             ).exists() |             ).exists() | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def test_audit_ignore(self): |  | ||||||
|         """Test audit_ignore context manager""" |  | ||||||
|         uid = generate_id() |  | ||||||
|         with audit_ignore(): |  | ||||||
|             self.client.post( |  | ||||||
|                 reverse("authentik_api:application-list"), |  | ||||||
|                 data={"name": uid, "slug": uid}, |  | ||||||
|             ) |  | ||||||
|         self.assertTrue(Application.objects.filter(name=uid).exists()) |  | ||||||
|         self.assertFalse( |  | ||||||
|             Event.objects.filter( |  | ||||||
|                 action=EventAction.MODEL_CREATED, |  | ||||||
|                 context__model__model_name="application", |  | ||||||
|                 context__model__app="authentik_core", |  | ||||||
|                 context__model__name=uid, |  | ||||||
|             ).exists() |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_audit_overwrite_user(self): |  | ||||||
|         """Test audit_overwrite_user context manager""" |  | ||||||
|         uid = generate_id() |  | ||||||
|         new_user = create_test_admin_user() |  | ||||||
|         with audit_overwrite_user(new_user): |  | ||||||
|             self.client.post( |  | ||||||
|                 reverse("authentik_api:application-list"), |  | ||||||
|                 data={"name": uid, "slug": uid}, |  | ||||||
|             ) |  | ||||||
|         self.assertTrue(Application.objects.filter(name=uid).exists()) |  | ||||||
|         self.assertTrue( |  | ||||||
|             Event.objects.filter( |  | ||||||
|                 action=EventAction.MODEL_CREATED, |  | ||||||
|                 context__model__model_name="application", |  | ||||||
|                 context__model__app="authentik_core", |  | ||||||
|                 context__model__name=uid, |  | ||||||
|                 user__username=new_user.username, |  | ||||||
|             ).exists() |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_create_with_api(self): |  | ||||||
|         """Test model creation event (with API token auth)""" |  | ||||||
|         self.client.logout() |  | ||||||
|         token = Token.objects.create(user=self.user, intent=TokenIntents.INTENT_API, expiring=False) |  | ||||||
|         uid = generate_id() |  | ||||||
|         self.client.post( |  | ||||||
|             reverse("authentik_api:application-list"), |  | ||||||
|             data={"name": uid, "slug": uid}, |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertTrue(Application.objects.filter(name=uid).exists()) |  | ||||||
|         event = Event.objects.filter( |  | ||||||
|             action=EventAction.MODEL_CREATED, |  | ||||||
|             context__model__model_name="application", |  | ||||||
|             context__model__app="authentik_core", |  | ||||||
|             context__model__name=uid, |  | ||||||
|         ).first() |  | ||||||
|         self.assertIsNotNone(event) |  | ||||||
|         self.assertEqual( |  | ||||||
|             event.user, |  | ||||||
|             { |  | ||||||
|                 "pk": self.user.pk, |  | ||||||
|                 "email": self.user.email, |  | ||||||
|                 "username": self.user.username, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|  | |||||||
| @ -1,35 +0,0 @@ | |||||||
| """authentik event models tests""" |  | ||||||
|  |  | ||||||
| from collections.abc import Callable |  | ||||||
|  |  | ||||||
| from django.db.models import Model |  | ||||||
| from django.test import TestCase |  | ||||||
|  |  | ||||||
| from authentik.core.models import default_token_key |  | ||||||
| from authentik.lib.utils.reflection import get_apps |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestModels(TestCase): |  | ||||||
|     """Test Models""" |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def model_tester_factory(test_model: type[Model]) -> Callable: |  | ||||||
|     """Test models' __str__ and __repr__""" |  | ||||||
|  |  | ||||||
|     def tester(self: TestModels): |  | ||||||
|         allowed = 0 |  | ||||||
|         # Token-like objects need to lookup the current tenant to get the default token length |  | ||||||
|         for field in test_model._meta.fields: |  | ||||||
|             if field.default == default_token_key: |  | ||||||
|                 allowed += 1 |  | ||||||
|         with self.assertNumQueries(allowed): |  | ||||||
|             str(test_model()) |  | ||||||
|         with self.assertNumQueries(allowed): |  | ||||||
|             repr(test_model()) |  | ||||||
|  |  | ||||||
|     return tester |  | ||||||
|  |  | ||||||
|  |  | ||||||
| for app in get_apps(): |  | ||||||
|     for model in app.get_models(): |  | ||||||
|         setattr(TestModels, f"test_{app.label}_{model.__name__}", model_tester_factory(model)) |  | ||||||
| @ -47,4 +47,3 @@ class FlowStageBindingViewSet(UsedByMixin, ModelViewSet): | |||||||
|     filterset_fields = "__all__" |     filterset_fields = "__all__" | ||||||
|     search_fields = ["stage__name"] |     search_fields = ["stage__name"] | ||||||
|     ordering = ["order"] |     ordering = ["order"] | ||||||
|     ordering_fields = ["order", "stage__name"] |  | ||||||
|  | |||||||
| @ -7,7 +7,7 @@ from django.utils.translation import gettext as _ | |||||||
| from drf_spectacular.types import OpenApiTypes | from drf_spectacular.types import OpenApiTypes | ||||||
| from drf_spectacular.utils import OpenApiResponse, extend_schema | from drf_spectacular.utils import OpenApiResponse, extend_schema | ||||||
| from rest_framework.decorators import action | from rest_framework.decorators import action | ||||||
| from rest_framework.fields import BooleanField, CharField, ReadOnlyField | from rest_framework.fields import BooleanField, CharField, DictField, ListField, ReadOnlyField | ||||||
| from rest_framework.parsers import MultiPartParser | from rest_framework.parsers import MultiPartParser | ||||||
| from rest_framework.request import Request | from rest_framework.request import Request | ||||||
| from rest_framework.response import Response | from rest_framework.response import Response | ||||||
| @ -19,7 +19,7 @@ from authentik.blueprints.v1.exporter import FlowExporter | |||||||
| from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT, Importer | from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT, Importer | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.core.api.utils import CacheSerializer, LinkSerializer, PassiveSerializer | from authentik.core.api.utils import CacheSerializer, LinkSerializer, PassiveSerializer | ||||||
| from authentik.events.logs import LogEventSerializer | from authentik.events.utils import sanitize_dict | ||||||
| from authentik.flows.api.flows_diagram import FlowDiagram, FlowDiagramSerializer | from authentik.flows.api.flows_diagram import FlowDiagram, FlowDiagramSerializer | ||||||
| from authentik.flows.exceptions import FlowNonApplicableException | from authentik.flows.exceptions import FlowNonApplicableException | ||||||
| from authentik.flows.models import Flow | from authentik.flows.models import Flow | ||||||
| @ -107,7 +107,7 @@ class FlowSetSerializer(FlowSerializer): | |||||||
| class FlowImportResultSerializer(PassiveSerializer): | class FlowImportResultSerializer(PassiveSerializer): | ||||||
|     """Logs of an attempted flow import""" |     """Logs of an attempted flow import""" | ||||||
|  |  | ||||||
|     logs = LogEventSerializer(many=True, read_only=True) |     logs = ListField(child=DictField(), read_only=True) | ||||||
|     success = BooleanField(read_only=True) |     success = BooleanField(read_only=True) | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -184,7 +184,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet): | |||||||
|  |  | ||||||
|         importer = Importer.from_string(file.read().decode()) |         importer = Importer.from_string(file.read().decode()) | ||||||
|         valid, logs = importer.validate() |         valid, logs = importer.validate() | ||||||
|         import_response.initial_data["logs"] = [LogEventSerializer(log).data for log in logs] |         import_response.initial_data["logs"] = [sanitize_dict(log) for log in logs] | ||||||
|         import_response.initial_data["success"] = valid |         import_response.initial_data["success"] = valid | ||||||
|         import_response.is_valid() |         import_response.is_valid() | ||||||
|         if not valid: |         if not valid: | ||||||
| @ -278,7 +278,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet): | |||||||
|         }, |         }, | ||||||
|     ) |     ) | ||||||
|     @action(detail=True, pagination_class=None, filter_backends=[]) |     @action(detail=True, pagination_class=None, filter_backends=[]) | ||||||
|     def execute(self, request: Request, slug: str): |     def execute(self, request: Request, _slug: str): | ||||||
|         """Execute flow for current user""" |         """Execute flow for current user""" | ||||||
|         # Because we pre-plan the flow here, and not in the planner, we need to manually clear |         # Because we pre-plan the flow here, and not in the planner, we need to manually clear | ||||||
|         # the history of the inspector |         # the history of the inspector | ||||||
|  | |||||||
| @ -31,9 +31,10 @@ class AuthentikFlowsConfig(ManagedAppConfig): | |||||||
|     verbose_name = "authentik Flows" |     verbose_name = "authentik Flows" | ||||||
|     default = True |     default = True | ||||||
|  |  | ||||||
|     def import_related(self): |     @ManagedAppConfig.reconcile_global | ||||||
|  |     def load_stages(self): | ||||||
|  |         """Ensure all stages are loaded""" | ||||||
|         from authentik.flows.models import Stage |         from authentik.flows.models import Stage | ||||||
|  |  | ||||||
|         for stage in all_subclasses(Stage): |         for stage in all_subclasses(Stage): | ||||||
|             _ = stage().view |             _ = stage().view | ||||||
|         return super().import_related() |  | ||||||
|  | |||||||
| @ -59,11 +59,11 @@ class FlowPlan: | |||||||
|     markers: list[StageMarker] = field(default_factory=list) |     markers: list[StageMarker] = field(default_factory=list) | ||||||
|  |  | ||||||
|     def append_stage(self, stage: Stage, marker: StageMarker | None = None): |     def append_stage(self, stage: Stage, marker: StageMarker | None = None): | ||||||
|         """Append `stage` to the end of the plan, optionally with stage marker""" |         """Append `stage` to all stages, optionally with stage marker""" | ||||||
|         return self.append(FlowStageBinding(stage=stage), marker) |         return self.append(FlowStageBinding(stage=stage), marker) | ||||||
|  |  | ||||||
|     def append(self, binding: FlowStageBinding, marker: StageMarker | None = None): |     def append(self, binding: FlowStageBinding, marker: StageMarker | None = None): | ||||||
|         """Append `stage` to the end of the plan, optionally with stage marker""" |         """Append `stage` to all stages, optionally with stage marker""" | ||||||
|         self.bindings.append(binding) |         self.bindings.append(binding) | ||||||
|         self.markers.append(marker or StageMarker()) |         self.markers.append(marker or StageMarker()) | ||||||
|  |  | ||||||
| @ -203,8 +203,7 @@ class FlowPlanner: | |||||||
|                 "f(plan): building plan", |                 "f(plan): building plan", | ||||||
|             ) |             ) | ||||||
|             plan = self._build_plan(user, request, default_context) |             plan = self._build_plan(user, request, default_context) | ||||||
|             if self.use_cache: |             cache.set(cache_key(self.flow, user), plan, CACHE_TIMEOUT) | ||||||
|                 cache.set(cache_key(self.flow, user), plan, CACHE_TIMEOUT) |  | ||||||
|             if not plan.bindings and not self.allow_empty_flows: |             if not plan.bindings and not self.allow_empty_flows: | ||||||
|                 raise EmptyFlowException() |                 raise EmptyFlowException() | ||||||
|             return plan |             return plan | ||||||
|  | |||||||
| @ -6,7 +6,6 @@ from rest_framework.test import APITestCase | |||||||
| from authentik.core.tests.utils import create_test_admin_user | from authentik.core.tests.utils import create_test_admin_user | ||||||
| from authentik.flows.api.stages import StageSerializer, StageViewSet | from authentik.flows.api.stages import StageSerializer, StageViewSet | ||||||
| from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding, Stage | from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding, Stage | ||||||
| from authentik.lib.generators import generate_id |  | ||||||
| from authentik.policies.dummy.models import DummyPolicy | from authentik.policies.dummy.models import DummyPolicy | ||||||
| from authentik.policies.models import PolicyBinding | from authentik.policies.models import PolicyBinding | ||||||
| from authentik.stages.dummy.models import DummyStage | from authentik.stages.dummy.models import DummyStage | ||||||
| @ -102,21 +101,3 @@ class TestFlowsAPI(APITestCase): | |||||||
|             reverse("authentik_api:stage-types"), |             reverse("authentik_api:stage-types"), | ||||||
|         ) |         ) | ||||||
|         self.assertEqual(response.status_code, 200) |         self.assertEqual(response.status_code, 200) | ||||||
|  |  | ||||||
|     def test_execute(self): |  | ||||||
|         """Test execute endpoint""" |  | ||||||
|         user = create_test_admin_user() |  | ||||||
|         self.client.force_login(user) |  | ||||||
|  |  | ||||||
|         flow = Flow.objects.create( |  | ||||||
|             name=generate_id(), |  | ||||||
|             slug=generate_id(), |  | ||||||
|             designation=FlowDesignation.AUTHENTICATION, |  | ||||||
|         ) |  | ||||||
|         FlowStageBinding.objects.create( |  | ||||||
|             target=flow, stage=DummyStage.objects.create(name=generate_id()), order=0 |  | ||||||
|         ) |  | ||||||
|         response = self.client.get( |  | ||||||
|             reverse("authentik_api:flow-execute", kwargs={"slug": flow.slug}) |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 200) |  | ||||||
|  | |||||||
| @ -24,6 +24,7 @@ from sentry_sdk.hub import Hub | |||||||
| from structlog.stdlib import BoundLogger, get_logger | from structlog.stdlib import BoundLogger, get_logger | ||||||
|  |  | ||||||
| from authentik.brands.models import Brand | from authentik.brands.models import Brand | ||||||
|  | from authentik.brands.utils import cors_allow | ||||||
| from authentik.core.models import Application | from authentik.core.models import Application | ||||||
| from authentik.events.models import Event, EventAction, cleanse_dict | from authentik.events.models import Event, EventAction, cleanse_dict | ||||||
| from authentik.flows.apps import HIST_FLOW_EXECUTION_STAGE_TIME | from authentik.flows.apps import HIST_FLOW_EXECUTION_STAGE_TIME | ||||||
| @ -155,6 +156,14 @@ class FlowExecutorView(APIView): | |||||||
|         return plan |         return plan | ||||||
|  |  | ||||||
|     def dispatch(self, request: HttpRequest, flow_slug: str) -> HttpResponse: |     def dispatch(self, request: HttpRequest, flow_slug: str) -> HttpResponse: | ||||||
|  |         response = self.dispatch_wrapper(request, flow_slug) | ||||||
|  |         origins = [] | ||||||
|  |         if request.brand.origin != "": | ||||||
|  |             origins.append(request.brand.origin) | ||||||
|  |         cors_allow(request, response, *origins) | ||||||
|  |         return response | ||||||
|  |  | ||||||
|  |     def dispatch_wrapper(self, request: HttpRequest, flow_slug: str) -> HttpResponse: | ||||||
|         with Hub.current.start_span( |         with Hub.current.start_span( | ||||||
|             op="authentik.flow.executor.dispatch", description=self.flow.slug |             op="authentik.flow.executor.dispatch", description=self.flow.slug | ||||||
|         ) as span: |         ) as span: | ||||||
| @ -450,7 +459,7 @@ class FlowExecutorView(APIView): | |||||||
|         return to_stage_response(self.request, challenge_view.get(self.request)) |         return to_stage_response(self.request, challenge_view.get(self.request)) | ||||||
|  |  | ||||||
|     def cancel(self): |     def cancel(self): | ||||||
|         """Cancel current flow execution""" |         """Cancel current execution and return a redirect""" | ||||||
|         keys_to_delete = [ |         keys_to_delete = [ | ||||||
|             SESSION_KEY_APPLICATION_PRE, |             SESSION_KEY_APPLICATION_PRE, | ||||||
|             SESSION_KEY_PLAN, |             SESSION_KEY_PLAN, | ||||||
|  | |||||||
| @ -11,7 +11,7 @@ from django.http import HttpRequest, HttpResponseNotFound | |||||||
| from django.templatetags.static import static | from django.templatetags.static import static | ||||||
| from lxml import etree  # nosec | from lxml import etree  # nosec | ||||||
| from lxml.etree import Element, SubElement  # nosec | from lxml.etree import Element, SubElement  # nosec | ||||||
| from requests.exceptions import ConnectionError, HTTPError, RequestException, Timeout | from requests.exceptions import RequestException | ||||||
|  |  | ||||||
| from authentik.lib.config import get_path_from_dict | from authentik.lib.config import get_path_from_dict | ||||||
| from authentik.lib.utils.http import get_http_session | from authentik.lib.utils.http import get_http_session | ||||||
| @ -23,8 +23,6 @@ if TYPE_CHECKING: | |||||||
| GRAVATAR_URL = "https://secure.gravatar.com" | GRAVATAR_URL = "https://secure.gravatar.com" | ||||||
| DEFAULT_AVATAR = static("dist/assets/images/user_default.png") | DEFAULT_AVATAR = static("dist/assets/images/user_default.png") | ||||||
| CACHE_KEY_GRAVATAR = "goauthentik.io/lib/avatars/" | CACHE_KEY_GRAVATAR = "goauthentik.io/lib/avatars/" | ||||||
| CACHE_KEY_GRAVATAR_AVAILABLE = "goauthentik.io/lib/avatars/gravatar_available" |  | ||||||
| GRAVATAR_STATUS_TTL_SECONDS = 60 * 60 * 8  # 8 Hours |  | ||||||
|  |  | ||||||
| SVG_XML_NS = "http://www.w3.org/2000/svg" | SVG_XML_NS = "http://www.w3.org/2000/svg" | ||||||
| SVG_NS_MAP = {None: SVG_XML_NS} | SVG_NS_MAP = {None: SVG_XML_NS} | ||||||
| @ -52,9 +50,6 @@ def avatar_mode_attribute(user: "User", mode: str) -> str | None: | |||||||
|  |  | ||||||
| def avatar_mode_gravatar(user: "User", mode: str) -> str | None: | def avatar_mode_gravatar(user: "User", mode: str) -> str | None: | ||||||
|     """Gravatar avatars""" |     """Gravatar avatars""" | ||||||
|     if not cache.get(CACHE_KEY_GRAVATAR_AVAILABLE, True): |  | ||||||
|         return None |  | ||||||
|  |  | ||||||
|     # gravatar uses md5 for their URLs, so md5 can't be avoided |     # gravatar uses md5 for their URLs, so md5 can't be avoided | ||||||
|     mail_hash = md5(user.email.lower().encode("utf-8")).hexdigest()  # nosec |     mail_hash = md5(user.email.lower().encode("utf-8")).hexdigest()  # nosec | ||||||
|     parameters = [("size", "158"), ("rating", "g"), ("default", "404")] |     parameters = [("size", "158"), ("rating", "g"), ("default", "404")] | ||||||
| @ -74,8 +69,6 @@ def avatar_mode_gravatar(user: "User", mode: str) -> str | None: | |||||||
|             cache.set(full_key, None) |             cache.set(full_key, None) | ||||||
|             return None |             return None | ||||||
|         res.raise_for_status() |         res.raise_for_status() | ||||||
|     except (Timeout, ConnectionError, HTTPError): |  | ||||||
|         cache.set(CACHE_KEY_GRAVATAR_AVAILABLE, False, timeout=GRAVATAR_STATUS_TTL_SECONDS) |  | ||||||
|     except RequestException: |     except RequestException: | ||||||
|         return gravatar_url |         return gravatar_url | ||||||
|     cache.set(full_key, gravatar_url) |     cache.set(full_key, gravatar_url) | ||||||
|  | |||||||
| @ -14,7 +14,7 @@ from pathlib import Path | |||||||
| from sys import argv, stderr | from sys import argv, stderr | ||||||
| from time import time | from time import time | ||||||
| from typing import Any | from typing import Any | ||||||
| from urllib.parse import quote_plus, urlparse | from urllib.parse import urlparse | ||||||
|  |  | ||||||
| import yaml | import yaml | ||||||
| from django.conf import ImproperlyConfigured | from django.conf import ImproperlyConfigured | ||||||
| @ -331,26 +331,6 @@ class ConfigLoader: | |||||||
| CONFIG = ConfigLoader() | CONFIG = ConfigLoader() | ||||||
|  |  | ||||||
|  |  | ||||||
| def redis_url(db: int) -> str: |  | ||||||
|     """Helper to create a Redis URL for a specific database""" |  | ||||||
|     _redis_protocol_prefix = "redis://" |  | ||||||
|     _redis_tls_requirements = "" |  | ||||||
|     if CONFIG.get_bool("redis.tls", False): |  | ||||||
|         _redis_protocol_prefix = "rediss://" |  | ||||||
|         _redis_tls_requirements = f"?ssl_cert_reqs={CONFIG.get('redis.tls_reqs')}" |  | ||||||
|         if _redis_ca := CONFIG.get("redis.tls_ca_cert", None): |  | ||||||
|             _redis_tls_requirements += f"&ssl_ca_certs={_redis_ca}" |  | ||||||
|     _redis_url = ( |  | ||||||
|         f"{_redis_protocol_prefix}" |  | ||||||
|         f"{quote_plus(CONFIG.get('redis.username'))}:" |  | ||||||
|         f"{quote_plus(CONFIG.get('redis.password'))}@" |  | ||||||
|         f"{quote_plus(CONFIG.get('redis.host'))}:" |  | ||||||
|         f"{CONFIG.get_int('redis.port')}" |  | ||||||
|         f"/{db}{_redis_tls_requirements}" |  | ||||||
|     ) |  | ||||||
|     return _redis_url |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     if len(argv) < 2:  # noqa: PLR2004 |     if len(argv) < 2:  # noqa: PLR2004 | ||||||
|         print(dumps(CONFIG.raw, indent=4, cls=AttrEncoder)) |         print(dumps(CONFIG.raw, indent=4, cls=AttrEncoder)) | ||||||
|  | |||||||
| @ -35,7 +35,6 @@ redis: | |||||||
|   password: "" |   password: "" | ||||||
|   tls: false |   tls: false | ||||||
|   tls_reqs: "none" |   tls_reqs: "none" | ||||||
|   tls_ca_cert: null |  | ||||||
|  |  | ||||||
| # broker: | # broker: | ||||||
| #   url: "" | #   url: "" | ||||||
| @ -53,15 +52,12 @@ cache: | |||||||
|  |  | ||||||
| # result_backend: | # result_backend: | ||||||
| #   url: "" | #   url: "" | ||||||
| #   transport_options: "" |  | ||||||
|  |  | ||||||
| debug: false | debug: false | ||||||
| remote_debug: false | remote_debug: false | ||||||
|  |  | ||||||
| log_level: info | log_level: info | ||||||
|  |  | ||||||
| session_storage: cache |  | ||||||
|  |  | ||||||
| error_reporting: | error_reporting: | ||||||
|   enabled: false |   enabled: false | ||||||
|   sentry_dsn: https://151ba72610234c4c97c5bcff4e1cffd8@authentik.error-reporting.a7k.io/4504163677503489 |   sentry_dsn: https://151ba72610234c4c97c5bcff4e1cffd8@authentik.error-reporting.a7k.io/4504163677503489 | ||||||
| @ -114,6 +110,7 @@ events: | |||||||
|     asn: "/geoip/GeoLite2-ASN.mmdb" |     asn: "/geoip/GeoLite2-ASN.mmdb" | ||||||
|  |  | ||||||
| cert_discovery_dir: /certs | cert_discovery_dir: /certs | ||||||
|  | default_token_length: 60 | ||||||
|  |  | ||||||
| tenants: | tenants: | ||||||
|   enabled: false |   enabled: false | ||||||
|  | |||||||
| @ -2,11 +2,11 @@ | |||||||
|  |  | ||||||
| from uuid import uuid4 | from uuid import uuid4 | ||||||
|  |  | ||||||
|  | from django.conf import settings | ||||||
| from requests.sessions import PreparedRequest, Session | from requests.sessions import PreparedRequest, Session | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik import get_full_version | from authentik import get_full_version | ||||||
| from authentik.lib.config import CONFIG |  | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
| @ -35,6 +35,6 @@ class DebugSession(Session): | |||||||
|  |  | ||||||
| def get_http_session() -> Session: | def get_http_session() -> Session: | ||||||
|     """Get a requests session with common headers""" |     """Get a requests session with common headers""" | ||||||
|     session = DebugSession() if CONFIG.get_bool("debug") else Session() |     session = DebugSession() if settings.DEBUG else Session() | ||||||
|     session.headers["User-Agent"] = authentik_user_agent() |     session.headers["User-Agent"] = authentik_user_agent() | ||||||
|     return session |     return session | ||||||
|  | |||||||
| @ -3,14 +3,12 @@ | |||||||
| import os | import os | ||||||
| from importlib import import_module | from importlib import import_module | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from tempfile import gettempdir |  | ||||||
|  |  | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
|  | from kubernetes.config.incluster_config import SERVICE_HOST_ENV_NAME | ||||||
|  |  | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
|  |  | ||||||
| SERVICE_HOST_ENV_NAME = "KUBERNETES_SERVICE_HOST" |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def all_subclasses(cls, sort=True): | def all_subclasses(cls, sort=True): | ||||||
|     """Recursively return all subclassess of cls""" |     """Recursively return all subclassess of cls""" | ||||||
| @ -57,7 +55,7 @@ def get_env() -> str: | |||||||
|         return "dev" |         return "dev" | ||||||
|     if SERVICE_HOST_ENV_NAME in os.environ: |     if SERVICE_HOST_ENV_NAME in os.environ: | ||||||
|         return "kubernetes" |         return "kubernetes" | ||||||
|     if (Path(gettempdir()) / "authentik-mode").exists(): |     if Path("/tmp/authentik-mode").exists():  # nosec | ||||||
|         return "compose" |         return "compose" | ||||||
|     if "AK_APPLIANCE" in os.environ: |     if "AK_APPLIANCE" in os.environ: | ||||||
|         return os.environ["AK_APPLIANCE"] |         return os.environ["AK_APPLIANCE"] | ||||||
|  | |||||||
| @ -45,14 +45,14 @@ class AuthentikOutpostConfig(ManagedAppConfig): | |||||||
|                 outpost.managed = MANAGED_OUTPOST |                 outpost.managed = MANAGED_OUTPOST | ||||||
|                 outpost.save() |                 outpost.save() | ||||||
|                 return |                 return | ||||||
|             outpost, created = Outpost.objects.update_or_create( |             outpost, updated = Outpost.objects.update_or_create( | ||||||
|                 defaults={ |                 defaults={ | ||||||
|                     "type": OutpostType.PROXY, |                     "type": OutpostType.PROXY, | ||||||
|                     "name": MANAGED_OUTPOST_NAME, |                     "name": MANAGED_OUTPOST_NAME, | ||||||
|                 }, |                 }, | ||||||
|                 managed=MANAGED_OUTPOST, |                 managed=MANAGED_OUTPOST, | ||||||
|             ) |             ) | ||||||
|             if created: |             if updated: | ||||||
|                 if KubernetesServiceConnection.objects.exists(): |                 if KubernetesServiceConnection.objects.exists(): | ||||||
|                     outpost.service_connection = KubernetesServiceConnection.objects.first() |                     outpost.service_connection = KubernetesServiceConnection.objects.first() | ||||||
|                 elif DockerServiceConnection.objects.exists(): |                 elif DockerServiceConnection.objects.exists(): | ||||||
|  | |||||||
| @ -3,9 +3,9 @@ | |||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
|  |  | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  | from structlog.testing import capture_logs | ||||||
|  |  | ||||||
| from authentik import __version__, get_build_hash | from authentik import __version__, get_build_hash | ||||||
| from authentik.events.logs import LogEvent, capture_logs |  | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
| from authentik.lib.sentry import SentryIgnoredException | from authentik.lib.sentry import SentryIgnoredException | ||||||
| from authentik.outposts.models import ( | from authentik.outposts.models import ( | ||||||
| @ -63,21 +63,21 @@ class BaseController: | |||||||
|         """Called by scheduled task to reconcile deployment/service/etc""" |         """Called by scheduled task to reconcile deployment/service/etc""" | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     def up_with_logs(self) -> list[LogEvent]: |     def up_with_logs(self) -> list[str]: | ||||||
|         """Call .up() but capture all log output and return it.""" |         """Call .up() but capture all log output and return it.""" | ||||||
|         with capture_logs() as logs: |         with capture_logs() as logs: | ||||||
|             self.up() |             self.up() | ||||||
|         return logs |         return [x["event"] for x in logs] | ||||||
|  |  | ||||||
|     def down(self): |     def down(self): | ||||||
|         """Handler to delete everything we've created""" |         """Handler to delete everything we've created""" | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     def down_with_logs(self) -> list[LogEvent]: |     def down_with_logs(self) -> list[str]: | ||||||
|         """Call .down() but capture all log output and return it.""" |         """Call .down() but capture all log output and return it.""" | ||||||
|         with capture_logs() as logs: |         with capture_logs() as logs: | ||||||
|             self.down() |             self.down() | ||||||
|         return logs |         return [x["event"] for x in logs] | ||||||
|  |  | ||||||
|     def __enter__(self): |     def __enter__(self): | ||||||
|         return self |         return self | ||||||
|  | |||||||
| @ -9,10 +9,10 @@ from kubernetes.client.exceptions import OpenApiException | |||||||
| from kubernetes.config.config_exception import ConfigException | from kubernetes.config.config_exception import ConfigException | ||||||
| from kubernetes.config.incluster_config import load_incluster_config | from kubernetes.config.incluster_config import load_incluster_config | ||||||
| from kubernetes.config.kube_config import load_kube_config_from_dict | from kubernetes.config.kube_config import load_kube_config_from_dict | ||||||
|  | from structlog.testing import capture_logs | ||||||
| from urllib3.exceptions import HTTPError | from urllib3.exceptions import HTTPError | ||||||
| from yaml import dump_all | from yaml import dump_all | ||||||
|  |  | ||||||
| from authentik.events.logs import LogEvent, capture_logs |  | ||||||
| from authentik.outposts.controllers.base import BaseClient, BaseController, ControllerException | from authentik.outposts.controllers.base import BaseClient, BaseController, ControllerException | ||||||
| from authentik.outposts.controllers.k8s.base import KubernetesObjectReconciler | from authentik.outposts.controllers.k8s.base import KubernetesObjectReconciler | ||||||
| from authentik.outposts.controllers.k8s.deployment import DeploymentReconciler | from authentik.outposts.controllers.k8s.deployment import DeploymentReconciler | ||||||
| @ -91,7 +91,7 @@ class KubernetesController(BaseController): | |||||||
|         except (OpenApiException, HTTPError, ServiceConnectionInvalid) as exc: |         except (OpenApiException, HTTPError, ServiceConnectionInvalid) as exc: | ||||||
|             raise ControllerException(str(exc)) from exc |             raise ControllerException(str(exc)) from exc | ||||||
|  |  | ||||||
|     def up_with_logs(self) -> list[LogEvent]: |     def up_with_logs(self) -> list[str]: | ||||||
|         try: |         try: | ||||||
|             all_logs = [] |             all_logs = [] | ||||||
|             for reconcile_key in self.reconcile_order: |             for reconcile_key in self.reconcile_order: | ||||||
| @ -104,9 +104,7 @@ class KubernetesController(BaseController): | |||||||
|                         continue |                         continue | ||||||
|                     reconciler = reconciler_cls(self) |                     reconciler = reconciler_cls(self) | ||||||
|                     reconciler.up() |                     reconciler.up() | ||||||
|                 for log in logs: |                 all_logs += [f"{reconcile_key.title()}: {x['event']}" for x in logs] | ||||||
|                     log.logger = reconcile_key.title() |  | ||||||
|                 all_logs.extend(logs) |  | ||||||
|             return all_logs |             return all_logs | ||||||
|         except (OpenApiException, HTTPError, ServiceConnectionInvalid) as exc: |         except (OpenApiException, HTTPError, ServiceConnectionInvalid) as exc: | ||||||
|             raise ControllerException(str(exc)) from exc |             raise ControllerException(str(exc)) from exc | ||||||
| @ -124,7 +122,7 @@ class KubernetesController(BaseController): | |||||||
|         except (OpenApiException, HTTPError, ServiceConnectionInvalid) as exc: |         except (OpenApiException, HTTPError, ServiceConnectionInvalid) as exc: | ||||||
|             raise ControllerException(str(exc)) from exc |             raise ControllerException(str(exc)) from exc | ||||||
|  |  | ||||||
|     def down_with_logs(self) -> list[LogEvent]: |     def down_with_logs(self) -> list[str]: | ||||||
|         try: |         try: | ||||||
|             all_logs = [] |             all_logs = [] | ||||||
|             for reconcile_key in self.reconcile_order: |             for reconcile_key in self.reconcile_order: | ||||||
| @ -137,9 +135,7 @@ class KubernetesController(BaseController): | |||||||
|                         continue |                         continue | ||||||
|                     reconciler = reconciler_cls(self) |                     reconciler = reconciler_cls(self) | ||||||
|                     reconciler.down() |                     reconciler.down() | ||||||
|                 for log in logs: |                 all_logs += [f"{reconcile_key.title()}: {x['event']}" for x in logs] | ||||||
|                     log.logger = reconcile_key.title() |  | ||||||
|                 all_logs.extend(logs) |  | ||||||
|             return all_logs |             return all_logs | ||||||
|         except (OpenApiException, HTTPError, ServiceConnectionInvalid) as exc: |         except (OpenApiException, HTTPError, ServiceConnectionInvalid) as exc: | ||||||
|             raise ControllerException(str(exc)) from exc |             raise ControllerException(str(exc)) from exc | ||||||
|  | |||||||
| @ -149,8 +149,10 @@ def outpost_controller( | |||||||
|         if not controller_type: |         if not controller_type: | ||||||
|             return |             return | ||||||
|         with controller_type(outpost, outpost.service_connection) as controller: |         with controller_type(outpost, outpost.service_connection) as controller: | ||||||
|             LOGGER.debug("---------------Outpost Controller logs starting----------------") |  | ||||||
|             logs = getattr(controller, f"{action}_with_logs")() |             logs = getattr(controller, f"{action}_with_logs")() | ||||||
|  |             LOGGER.debug("---------------Outpost Controller logs starting----------------") | ||||||
|  |             for log in logs: | ||||||
|  |                 LOGGER.debug(log) | ||||||
|             LOGGER.debug("-----------------Outpost Controller logs end-------------------") |             LOGGER.debug("-----------------Outpost Controller logs end-------------------") | ||||||
|     except (ControllerException, ServiceConnectionInvalid) as exc: |     except (ControllerException, ServiceConnectionInvalid) as exc: | ||||||
|         self.set_error(exc) |         self.set_error(exc) | ||||||
|  | |||||||
| @ -1,11 +1,10 @@ | |||||||
| """Serializer for policy execution""" | """Serializer for policy execution""" | ||||||
|  |  | ||||||
| from rest_framework.fields import BooleanField, CharField, ListField | from rest_framework.fields import BooleanField, CharField, DictField, ListField | ||||||
| from rest_framework.relations import PrimaryKeyRelatedField | from rest_framework.relations import PrimaryKeyRelatedField | ||||||
|  |  | ||||||
| from authentik.core.api.utils import JSONDictField, PassiveSerializer | from authentik.core.api.utils import JSONDictField, PassiveSerializer | ||||||
| from authentik.core.models import User | from authentik.core.models import User | ||||||
| from authentik.events.logs import LogEventSerializer |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class PolicyTestSerializer(PassiveSerializer): | class PolicyTestSerializer(PassiveSerializer): | ||||||
| @ -20,4 +19,4 @@ class PolicyTestResultSerializer(PassiveSerializer): | |||||||
|  |  | ||||||
|     passing = BooleanField() |     passing = BooleanField() | ||||||
|     messages = ListField(child=CharField(), read_only=True) |     messages = ListField(child=CharField(), read_only=True) | ||||||
|     log_messages = LogEventSerializer(many=True, read_only=True) |     log_messages = ListField(child=DictField(), read_only=True) | ||||||
|  | |||||||
| @ -11,11 +11,12 @@ from rest_framework.response import Response | |||||||
| from rest_framework.serializers import ModelSerializer, SerializerMethodField | from rest_framework.serializers import ModelSerializer, SerializerMethodField | ||||||
| from rest_framework.viewsets import GenericViewSet | from rest_framework.viewsets import GenericViewSet | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  | from structlog.testing import capture_logs | ||||||
|  |  | ||||||
| from authentik.core.api.applications import user_app_cache_key | from authentik.core.api.applications import user_app_cache_key | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.core.api.utils import CacheSerializer, MetaNameSerializer, TypeCreateSerializer | from authentik.core.api.utils import CacheSerializer, MetaNameSerializer, TypeCreateSerializer | ||||||
| from authentik.events.logs import LogEventSerializer, capture_logs | from authentik.events.utils import sanitize_dict | ||||||
| from authentik.lib.utils.reflection import all_subclasses | from authentik.lib.utils.reflection import all_subclasses | ||||||
| from authentik.policies.api.exec import PolicyTestResultSerializer, PolicyTestSerializer | from authentik.policies.api.exec import PolicyTestResultSerializer, PolicyTestSerializer | ||||||
| from authentik.policies.models import Policy, PolicyBinding | from authentik.policies.models import Policy, PolicyBinding | ||||||
| @ -165,9 +166,9 @@ class PolicyViewSet( | |||||||
|             result = proc.execute() |             result = proc.execute() | ||||||
|         log_messages = [] |         log_messages = [] | ||||||
|         for log in logs: |         for log in logs: | ||||||
|             if log.attributes.get("process", "") == "PolicyProcess": |             if log.get("process", "") == "PolicyProcess": | ||||||
|                 continue |                 continue | ||||||
|             log_messages.append(LogEventSerializer(log).data) |             log_messages.append(sanitize_dict(log)) | ||||||
|         result.log_messages = log_messages |         result.log_messages = log_messages | ||||||
|         response = PolicyTestResultSerializer(result) |         response = PolicyTestResultSerializer(result) | ||||||
|         return Response(response.data) |         return Response(response.data) | ||||||
|  | |||||||
| @ -39,7 +39,6 @@ class Migration(migrations.Migration): | |||||||
|                     ("authentik.sources.oauth", "authentik Sources.OAuth"), |                     ("authentik.sources.oauth", "authentik Sources.OAuth"), | ||||||
|                     ("authentik.sources.plex", "authentik Sources.Plex"), |                     ("authentik.sources.plex", "authentik Sources.Plex"), | ||||||
|                     ("authentik.sources.saml", "authentik Sources.SAML"), |                     ("authentik.sources.saml", "authentik Sources.SAML"), | ||||||
|                     ("authentik.sources.scim", "authentik Sources.SCIM"), |  | ||||||
|                     ("authentik.stages.authenticator_duo", "authentik Stages.Authenticator.Duo"), |                     ("authentik.stages.authenticator_duo", "authentik Stages.Authenticator.Duo"), | ||||||
|                     ("authentik.stages.authenticator_sms", "authentik Stages.Authenticator.SMS"), |                     ("authentik.stages.authenticator_sms", "authentik Stages.Authenticator.SMS"), | ||||||
|                     ( |                     ( | ||||||
|  | |||||||
| @ -13,7 +13,6 @@ from authentik.events.context_processors.base import get_context_processors | |||||||
|  |  | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
|     from authentik.core.models import User |     from authentik.core.models import User | ||||||
|     from authentik.events.logs import LogEvent |  | ||||||
|     from authentik.policies.models import PolicyBinding |     from authentik.policies.models import PolicyBinding | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| @ -75,7 +74,7 @@ class PolicyResult: | |||||||
|     source_binding: PolicyBinding | None |     source_binding: PolicyBinding | None | ||||||
|     source_results: list[PolicyResult] | None |     source_results: list[PolicyResult] | None | ||||||
|  |  | ||||||
|     log_messages: list[LogEvent] | None |     log_messages: list[dict] | None | ||||||
|  |  | ||||||
|     def __init__(self, passing: bool, *messages: str): |     def __init__(self, passing: bool, *messages: str): | ||||||
|         self.passing = passing |         self.passing = passing | ||||||
|  | |||||||
| @ -1,9 +1,9 @@ | |||||||
| """authentik oauth provider app config""" | """authentik oauth provider app config""" | ||||||
|  |  | ||||||
| from django.apps import AppConfig | from authentik.blueprints.apps import ManagedAppConfig | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuthentikProviderOAuth2Config(AppConfig): | class AuthentikProviderOAuth2Config(ManagedAppConfig): | ||||||
|     """authentik oauth provider app config""" |     """authentik oauth provider app config""" | ||||||
|  |  | ||||||
|     name = "authentik.providers.oauth2" |     name = "authentik.providers.oauth2" | ||||||
| @ -13,3 +13,4 @@ class AuthentikProviderOAuth2Config(AppConfig): | |||||||
|         "authentik.providers.oauth2.urls_root": "", |         "authentik.providers.oauth2.urls_root": "", | ||||||
|         "authentik.providers.oauth2.urls": "application/o/", |         "authentik.providers.oauth2.urls": "application/o/", | ||||||
|     } |     } | ||||||
|  |     default = True | ||||||
|  | |||||||
| @ -8,7 +8,6 @@ from django.http import HttpRequest | |||||||
| from django.utils import timezone | from django.utils import timezone | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
|  |  | ||||||
| from authentik.core.models import default_token_duration |  | ||||||
| from authentik.events.signals import get_login_event | from authentik.events.signals import get_login_event | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.providers.oauth2.constants import ( | from authentik.providers.oauth2.constants import ( | ||||||
| @ -88,9 +87,7 @@ class IDToken: | |||||||
|     ) -> "IDToken": |     ) -> "IDToken": | ||||||
|         """Create ID Token""" |         """Create ID Token""" | ||||||
|         id_token = IDToken(provider, token, **kwargs) |         id_token = IDToken(provider, token, **kwargs) | ||||||
|         id_token.exp = int( |         id_token.exp = int(token.expires.timestamp()) | ||||||
|             (token.expires if token.expires is not None else default_token_duration()).timestamp() |  | ||||||
|         ) |  | ||||||
|         id_token.iss = provider.get_issuer(request) |         id_token.iss = provider.get_issuer(request) | ||||||
|         id_token.aud = provider.client_id |         id_token.aud = provider.client_id | ||||||
|         id_token.claims = {} |         id_token.claims = {} | ||||||
|  | |||||||
| @ -1,36 +0,0 @@ | |||||||
| # Generated by Django 5.0.2 on 2024-02-29 10:15 |  | ||||||
|  |  | ||||||
| from django.db import migrations, models |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Migration(migrations.Migration): |  | ||||||
|  |  | ||||||
|     dependencies = [ |  | ||||||
|         ( |  | ||||||
|             "authentik_providers_oauth2", |  | ||||||
|             "0017_accesstoken_session_id_authorizationcode_session_id_and_more", |  | ||||||
|         ), |  | ||||||
|     ] |  | ||||||
|  |  | ||||||
|     operations = [ |  | ||||||
|         migrations.AlterField( |  | ||||||
|             model_name="accesstoken", |  | ||||||
|             name="expires", |  | ||||||
|             field=models.DateTimeField(default=None, null=True), |  | ||||||
|         ), |  | ||||||
|         migrations.AlterField( |  | ||||||
|             model_name="authorizationcode", |  | ||||||
|             name="expires", |  | ||||||
|             field=models.DateTimeField(default=None, null=True), |  | ||||||
|         ), |  | ||||||
|         migrations.AlterField( |  | ||||||
|             model_name="devicetoken", |  | ||||||
|             name="expires", |  | ||||||
|             field=models.DateTimeField(default=None, null=True), |  | ||||||
|         ), |  | ||||||
|         migrations.AlterField( |  | ||||||
|             model_name="refreshtoken", |  | ||||||
|             name="expires", |  | ||||||
|             field=models.DateTimeField(default=None, null=True), |  | ||||||
|         ), |  | ||||||
|     ] |  | ||||||
| @ -326,7 +326,7 @@ class AuthorizationCode(SerializerModel, ExpiringModel, BaseGrantModel): | |||||||
|         verbose_name_plural = _("Authorization Codes") |         verbose_name_plural = _("Authorization Codes") | ||||||
|  |  | ||||||
|     def __str__(self): |     def __str__(self): | ||||||
|         return f"Authorization code for {self.provider_id} for user {self.user_id}" |         return f"Authorization code for {self.provider} for user {self.user}" | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def serializer(self) -> Serializer: |     def serializer(self) -> Serializer: | ||||||
| @ -356,7 +356,7 @@ class AccessToken(SerializerModel, ExpiringModel, BaseGrantModel): | |||||||
|         verbose_name_plural = _("OAuth2 Access Tokens") |         verbose_name_plural = _("OAuth2 Access Tokens") | ||||||
|  |  | ||||||
|     def __str__(self): |     def __str__(self): | ||||||
|         return f"Access Token for {self.provider_id} for user {self.user_id}" |         return f"Access Token for {self.provider} for user {self.user}" | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def id_token(self) -> IDToken: |     def id_token(self) -> IDToken: | ||||||
| @ -399,7 +399,7 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel): | |||||||
|         verbose_name_plural = _("OAuth2 Refresh Tokens") |         verbose_name_plural = _("OAuth2 Refresh Tokens") | ||||||
|  |  | ||||||
|     def __str__(self): |     def __str__(self): | ||||||
|         return f"Refresh Token for {self.provider_id} for user {self.user_id}" |         return f"Refresh Token for {self.provider} for user {self.user}" | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def id_token(self) -> IDToken: |     def id_token(self) -> IDToken: | ||||||
| @ -443,4 +443,4 @@ class DeviceToken(ExpiringModel): | |||||||
|         verbose_name_plural = _("Device Tokens") |         verbose_name_plural = _("Device Tokens") | ||||||
|  |  | ||||||
|     def __str__(self): |     def __str__(self): | ||||||
|         return f"Device Token for {self.provider_id}" |         return f"Device Token for {self.provider}" | ||||||
|  | |||||||
							
								
								
									
										15
									
								
								authentik/providers/oauth2/signals.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								authentik/providers/oauth2/signals.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,15 @@ | |||||||
|  | from hashlib import sha256 | ||||||
|  |  | ||||||
|  | from django.contrib.auth.signals import user_logged_out | ||||||
|  | from django.dispatch import receiver | ||||||
|  | from django.http import HttpRequest | ||||||
|  |  | ||||||
|  | from authentik.core.models import User | ||||||
|  | from authentik.providers.oauth2.models import AccessToken | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @receiver(user_logged_out) | ||||||
|  | def user_logged_out_oauth_access_token(sender, request: HttpRequest, user: User, **_): | ||||||
|  |     """Revoke access tokens upon user logout""" | ||||||
|  |     hashed_session_key = sha256(request.session.session_key.encode("ascii")).hexdigest() | ||||||
|  |     AccessToken.objects.filter(user=user, session_id=hashed_session_key).delete() | ||||||
| @ -4,10 +4,9 @@ from urllib.parse import urlencode | |||||||
|  |  | ||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
|  |  | ||||||
| from authentik.core.models import Application, Group | from authentik.core.models import Application | ||||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_brand, create_test_flow | from authentik.core.tests.utils import create_test_admin_user, create_test_brand, create_test_flow | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.policies.models import PolicyBinding |  | ||||||
| from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider | from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider | ||||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||||
| from authentik.providers.oauth2.views.device_init import QS_KEY_CODE | from authentik.providers.oauth2.views.device_init import QS_KEY_CODE | ||||||
| @ -78,23 +77,3 @@ class TesOAuth2DeviceInit(OAuthTestCase): | |||||||
|             + "?" |             + "?" | ||||||
|             + urlencode({QS_KEY_CODE: token.user_code}), |             + urlencode({QS_KEY_CODE: token.user_code}), | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def test_device_init_denied(self): |  | ||||||
|         """Test device init""" |  | ||||||
|         group = Group.objects.create(name="foo") |  | ||||||
|         PolicyBinding.objects.create( |  | ||||||
|             group=group, |  | ||||||
|             target=self.application, |  | ||||||
|             order=0, |  | ||||||
|         ) |  | ||||||
|         token = DeviceToken.objects.create( |  | ||||||
|             user_code="foo", |  | ||||||
|             provider=self.provider, |  | ||||||
|         ) |  | ||||||
|         res = self.client.get( |  | ||||||
|             reverse("authentik_providers_oauth2_root:device-login") |  | ||||||
|             + "?" |  | ||||||
|             + urlencode({QS_KEY_CODE: token.user_code}) |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(res.status_code, 200) |  | ||||||
|         self.assertIn(b"Permission denied", res.content) |  | ||||||
|  | |||||||
| @ -10,7 +10,6 @@ from jwt import PyJWKSet | |||||||
|  |  | ||||||
| from authentik.core.models import Application | from authentik.core.models import Application | ||||||
| from authentik.core.tests.utils import create_test_cert, create_test_flow | from authentik.core.tests.utils import create_test_cert, create_test_flow | ||||||
| from authentik.crypto.builder import PrivateKeyAlg |  | ||||||
| from authentik.crypto.models import CertificateKeyPair | from authentik.crypto.models import CertificateKeyPair | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.providers.oauth2.models import OAuth2Provider | from authentik.providers.oauth2.models import OAuth2Provider | ||||||
| @ -83,7 +82,7 @@ class TestJWKS(OAuthTestCase): | |||||||
|             client_id="test", |             client_id="test", | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="http://local.invalid", |             redirect_uris="http://local.invalid", | ||||||
|             signing_key=create_test_cert(PrivateKeyAlg.ECDSA), |             signing_key=create_test_cert(use_ec_private_key=True), | ||||||
|         ) |         ) | ||||||
|         app = Application.objects.create(name="test", slug="test", provider=provider) |         app = Application.objects.create(name="test", slug="test", provider=provider) | ||||||
|         response = self.client.get( |         response = self.client.get( | ||||||
|  | |||||||
| @ -208,7 +208,7 @@ class TestToken(OAuthTestCase): | |||||||
|                 "token_type": TOKEN_TYPE, |                 "token_type": TOKEN_TYPE, | ||||||
|                 "expires_in": 3600, |                 "expires_in": 3600, | ||||||
|                 "id_token": provider.encode( |                 "id_token": provider.encode( | ||||||
|                     access.id_token.to_dict(), |                     refresh.id_token.to_dict(), | ||||||
|                 ), |                 ), | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
| @ -267,7 +267,7 @@ class TestToken(OAuthTestCase): | |||||||
|                 "token_type": TOKEN_TYPE, |                 "token_type": TOKEN_TYPE, | ||||||
|                 "expires_in": 3600, |                 "expires_in": 3600, | ||||||
|                 "id_token": provider.encode( |                 "id_token": provider.encode( | ||||||
|                     access.id_token.to_dict(), |                     refresh.id_token.to_dict(), | ||||||
|                 ), |                 ), | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|  | |||||||
| @ -4,11 +4,9 @@ import re | |||||||
| from base64 import b64decode | from base64 import b64decode | ||||||
| from binascii import Error | from binascii import Error | ||||||
| from typing import Any | from typing import Any | ||||||
| from urllib.parse import urlparse |  | ||||||
|  |  | ||||||
| from django.http import HttpRequest, HttpResponse, JsonResponse | from django.http import HttpRequest, HttpResponse, JsonResponse | ||||||
| from django.http.response import HttpResponseRedirect | from django.http.response import HttpResponseRedirect | ||||||
| from django.utils.cache import patch_vary_headers |  | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.core.middleware import CTX_AUTH_VIA, KEY_USER | from authentik.core.middleware import CTX_AUTH_VIA, KEY_USER | ||||||
| @ -30,49 +28,6 @@ class TokenResponse(JsonResponse): | |||||||
|         self["Pragma"] = "no-cache" |         self["Pragma"] = "no-cache" | ||||||
|  |  | ||||||
|  |  | ||||||
| def cors_allow(request: HttpRequest, response: HttpResponse, *allowed_origins: str): |  | ||||||
|     """Add headers to permit CORS requests from allowed_origins, with or without credentials, |  | ||||||
|     with any headers.""" |  | ||||||
|     origin = request.META.get("HTTP_ORIGIN") |  | ||||||
|     if not origin: |  | ||||||
|         return response |  | ||||||
|  |  | ||||||
|     # OPTIONS requests don't have an authorization header -> hence |  | ||||||
|     # we can't extract the provider this request is for |  | ||||||
|     # so for options requests we allow the calling origin without checking |  | ||||||
|     allowed = request.method == "OPTIONS" |  | ||||||
|     received_origin = urlparse(origin) |  | ||||||
|     for allowed_origin in allowed_origins: |  | ||||||
|         url = urlparse(allowed_origin) |  | ||||||
|         if ( |  | ||||||
|             received_origin.scheme == url.scheme |  | ||||||
|             and received_origin.hostname == url.hostname |  | ||||||
|             and received_origin.port == url.port |  | ||||||
|         ): |  | ||||||
|             allowed = True |  | ||||||
|     if not allowed: |  | ||||||
|         LOGGER.warning( |  | ||||||
|             "CORS: Origin is not an allowed origin", |  | ||||||
|             requested=received_origin, |  | ||||||
|             allowed=allowed_origins, |  | ||||||
|         ) |  | ||||||
|         return response |  | ||||||
|  |  | ||||||
|     # From the CORS spec: The string "*" cannot be used for a resource that supports credentials. |  | ||||||
|     response["Access-Control-Allow-Origin"] = origin |  | ||||||
|     patch_vary_headers(response, ["Origin"]) |  | ||||||
|     response["Access-Control-Allow-Credentials"] = "true" |  | ||||||
|  |  | ||||||
|     if request.method == "OPTIONS": |  | ||||||
|         if "HTTP_ACCESS_CONTROL_REQUEST_HEADERS" in request.META: |  | ||||||
|             response["Access-Control-Allow-Headers"] = request.META[ |  | ||||||
|                 "HTTP_ACCESS_CONTROL_REQUEST_HEADERS" |  | ||||||
|             ] |  | ||||||
|         response["Access-Control-Allow-Methods"] = "GET, POST, OPTIONS" |  | ||||||
|  |  | ||||||
|     return response |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def extract_access_token(request: HttpRequest) -> str | None: | def extract_access_token(request: HttpRequest) -> str | None: | ||||||
|     """ |     """ | ||||||
|     Get the access token using Authorization Request Header Field method. |     Get the access token using Authorization Request Header Field method. | ||||||
|  | |||||||
| @ -11,11 +11,10 @@ from django.views.decorators.csrf import csrf_exempt | |||||||
| from rest_framework.throttling import AnonRateThrottle | from rest_framework.throttling import AnonRateThrottle | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.core.models import Application |  | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
| from authentik.lib.utils.time import timedelta_from_string | from authentik.lib.utils.time import timedelta_from_string | ||||||
| from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider | from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider | ||||||
| from authentik.providers.oauth2.views.device_init import QS_KEY_CODE | from authentik.providers.oauth2.views.device_init import QS_KEY_CODE, get_application | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
| @ -38,9 +37,7 @@ class DeviceView(View): | |||||||
|         ).first() |         ).first() | ||||||
|         if not provider: |         if not provider: | ||||||
|             return HttpResponseBadRequest() |             return HttpResponseBadRequest() | ||||||
|         try: |         if not get_application(provider): | ||||||
|             _ = provider.application |  | ||||||
|         except Application.DoesNotExist: |  | ||||||
|             return HttpResponseBadRequest() |             return HttpResponseBadRequest() | ||||||
|         self.provider = provider |         self.provider = provider | ||||||
|         self.client_id = client_id |         self.client_id = client_id | ||||||
|  | |||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	