Compare commits
	
		
			12 Commits
		
	
	
		
			version/20
			...
			core/b2c-i
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 195091ed3b | |||
| 4de3f1f4b8 | |||
| af4f1b3421 | |||
| 77b816ad51 | |||
| b28dd485a0 | |||
| 4701389745 | |||
| 0d0097e956 | |||
| b42eb0706d | |||
| 3afe386e18 | |||
| 34dd9c0b63 | |||
| b2f2fd241d | |||
| 828f477548 | 
| @ -1,5 +1,5 @@ | ||||
| [bumpversion] | ||||
| current_version = 2024.4.3 | ||||
| current_version = 2024.2.2 | ||||
| tag = True | ||||
| commit = True | ||||
| parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? | ||||
| @ -21,8 +21,6 @@ optional_value = final | ||||
|  | ||||
| [bumpversion:file:schema.yml] | ||||
|  | ||||
| [bumpversion:file:blueprints/schema.json] | ||||
|  | ||||
| [bumpversion:file:authentik/__init__.py] | ||||
|  | ||||
| [bumpversion:file:internal/constants/constants.go] | ||||
|  | ||||
							
								
								
									
										2
									
								
								.github/FUNDING.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/FUNDING.yml
									
									
									
									
										vendored
									
									
								
							| @ -1 +1 @@ | ||||
| custom: https://goauthentik.io/pricing/ | ||||
| github: [BeryJu] | ||||
|  | ||||
| @ -12,7 +12,7 @@ should_build = str(os.environ.get("DOCKER_USERNAME", None) is not None).lower() | ||||
| branch_name = os.environ["GITHUB_REF"] | ||||
| if os.environ.get("GITHUB_HEAD_REF", "") != "": | ||||
|     branch_name = os.environ["GITHUB_HEAD_REF"] | ||||
| safe_branch_name = branch_name.replace("refs/heads/", "").replace("/", "-").replace("'", "-") | ||||
| safe_branch_name = branch_name.replace("refs/heads/", "").replace("/", "-") | ||||
|  | ||||
| image_names = os.getenv("IMAGE_NAME").split(",") | ||||
| image_arch = os.getenv("IMAGE_ARCH") or None | ||||
|  | ||||
							
								
								
									
										8
									
								
								.github/actions/setup/action.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/actions/setup/action.yml
									
									
									
									
										vendored
									
									
								
							| @ -16,25 +16,25 @@ runs: | ||||
|         sudo apt-get update | ||||
|         sudo apt-get install --no-install-recommends -y libpq-dev openssl libxmlsec1-dev pkg-config gettext | ||||
|     - name: Setup python and restore poetry | ||||
|       uses: actions/setup-python@v5 | ||||
|       uses: actions/setup-python@v4 | ||||
|       with: | ||||
|         python-version-file: "pyproject.toml" | ||||
|         cache: "poetry" | ||||
|     - name: Setup node | ||||
|       uses: actions/setup-node@v4 | ||||
|       uses: actions/setup-node@v3 | ||||
|       with: | ||||
|         node-version-file: web/package.json | ||||
|         cache: "npm" | ||||
|         cache-dependency-path: web/package-lock.json | ||||
|     - name: Setup go | ||||
|       uses: actions/setup-go@v5 | ||||
|       uses: actions/setup-go@v4 | ||||
|       with: | ||||
|         go-version-file: "go.mod" | ||||
|     - name: Setup dependencies | ||||
|       shell: bash | ||||
|       run: | | ||||
|         export PSQL_TAG=${{ inputs.postgresql_version }} | ||||
|         docker compose -f .github/actions/setup/docker-compose.yml up -d | ||||
|         docker-compose -f .github/actions/setup/docker-compose.yml up -d | ||||
|         poetry install | ||||
|         cd web && npm ci | ||||
|     - name: Generate config | ||||
|  | ||||
							
								
								
									
										65
									
								
								.github/workflows/api-py-publish.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										65
									
								
								.github/workflows/api-py-publish.yml
									
									
									
									
										vendored
									
									
								
							| @ -1,65 +0,0 @@ | ||||
| name: authentik-api-py-publish | ||||
| on: | ||||
|   push: | ||||
|     branches: [main] | ||||
|     paths: | ||||
|       - "schema.yml" | ||||
|   workflow_dispatch: | ||||
| jobs: | ||||
|   build: | ||||
|     runs-on: ubuntu-latest | ||||
|     permissions: | ||||
|       id-token: write | ||||
|     steps: | ||||
|       - id: generate_token | ||||
|         uses: tibdex/github-app-token@v2 | ||||
|         with: | ||||
|           app_id: ${{ secrets.GH_APP_ID }} | ||||
|           private_key: ${{ secrets.GH_APP_PRIVATE_KEY }} | ||||
|       - uses: actions/checkout@v4 | ||||
|         with: | ||||
|           token: ${{ steps.generate_token.outputs.token }} | ||||
|       - name: Install poetry & deps | ||||
|         shell: bash | ||||
|         run: | | ||||
|           pipx install poetry || true | ||||
|           sudo apt-get update | ||||
|           sudo apt-get install --no-install-recommends -y libpq-dev openssl libxmlsec1-dev pkg-config gettext | ||||
|       - name: Setup python and restore poetry | ||||
|         uses: actions/setup-python@v5 | ||||
|         with: | ||||
|           python-version-file: "pyproject.toml" | ||||
|           cache: "poetry" | ||||
|       - name: Generate API Client | ||||
|         run: make gen-client-py | ||||
|       - name: Publish package | ||||
|         working-directory: gen-py-api/ | ||||
|         run: | | ||||
|           poetry build | ||||
|       - name: Publish package to PyPI | ||||
|         uses: pypa/gh-action-pypi-publish@release/v1 | ||||
|         with: | ||||
|           packages-dir: gen-py-api/dist/ | ||||
|       # We can't easily upgrade the API client being used due to poetry being poetry | ||||
|       # so we'll have to rely on dependabot | ||||
|       # - name: Upgrade / | ||||
|       #   run: | | ||||
|       #     export VERSION=$(cd gen-py-api && poetry version -s) | ||||
|       #     poetry add "authentik_client=$VERSION" --allow-prereleases --lock | ||||
|       # - uses: peter-evans/create-pull-request@v6 | ||||
|       #   id: cpr | ||||
|       #   with: | ||||
|       #     token: ${{ steps.generate_token.outputs.token }} | ||||
|       #     branch: update-root-api-client | ||||
|       #     commit-message: "root: bump API Client version" | ||||
|       #     title: "root: bump API Client version" | ||||
|       #     body: "root: bump API Client version" | ||||
|       #     delete-branch: true | ||||
|       #     signoff: true | ||||
|       #     # ID from https://api.github.com/users/authentik-automation[bot] | ||||
|       #     author: authentik-automation[bot] <135050075+authentik-automation[bot]@users.noreply.github.com> | ||||
|       # - uses: peter-evans/enable-pull-request-automerge@v3 | ||||
|       #   with: | ||||
|       #     token: ${{ steps.generate_token.outputs.token }} | ||||
|       #     pull-request-number: ${{ steps.cpr.outputs.pull-request-number }} | ||||
|       #     merge-method: squash | ||||
							
								
								
									
										4
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							| @ -160,8 +160,6 @@ jobs: | ||||
|             glob: tests/e2e/test_provider_ldap* tests/e2e/test_source_ldap* | ||||
|           - name: radius | ||||
|             glob: tests/e2e/test_provider_radius* | ||||
|           - name: scim | ||||
|             glob: tests/e2e/test_source_scim* | ||||
|           - name: flows | ||||
|             glob: tests/e2e/test_flows* | ||||
|     steps: | ||||
| @ -170,7 +168,7 @@ jobs: | ||||
|         uses: ./.github/actions/setup | ||||
|       - name: Setup e2e env (chrome, etc) | ||||
|         run: | | ||||
|           docker compose -f tests/e2e/docker-compose.yml up -d | ||||
|           docker-compose -f tests/e2e/docker-compose.yml up -d | ||||
|       - id: cache-web | ||||
|         uses: actions/cache@v4 | ||||
|         with: | ||||
|  | ||||
							
								
								
									
										8
									
								
								.github/workflows/ci-web.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/ci-web.yml
									
									
									
									
										vendored
									
									
								
							| @ -34,13 +34,6 @@ jobs: | ||||
|       - name: Eslint | ||||
|         working-directory: ${{ matrix.project }}/ | ||||
|         run: npm run lint | ||||
|   lint-lockfile: | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
|       - uses: actions/checkout@v4 | ||||
|       - working-directory: web/ | ||||
|         run: | | ||||
|           [ -z "$(jq -r '.packages | to_entries[] | select((.key | startswith("node_modules")) and (.value | has("resolved") | not)) | .key' < package-lock.json)" ] | ||||
|   lint-build: | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
| @ -102,7 +95,6 @@ jobs: | ||||
|         run: npm run lit-analyse | ||||
|   ci-web-mark: | ||||
|     needs: | ||||
|       - lint-lockfile | ||||
|       - lint-eslint | ||||
|       - lint-prettier | ||||
|       - lint-lit-analyse | ||||
|  | ||||
							
								
								
									
										8
									
								
								.github/workflows/ci-website.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/ci-website.yml
									
									
									
									
										vendored
									
									
								
							| @ -12,13 +12,6 @@ on: | ||||
|       - version-* | ||||
|  | ||||
| jobs: | ||||
|   lint-lockfile: | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
|       - uses: actions/checkout@v4 | ||||
|       - working-directory: website/ | ||||
|         run: | | ||||
|           [ -z "$(jq -r '.packages | to_entries[] | select((.key | startswith("node_modules")) and (.value | has("resolved") | not)) | .key' < package-lock.json)" ] | ||||
|   lint-prettier: | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
| @ -69,7 +62,6 @@ jobs: | ||||
|         run: npm run ${{ matrix.job }} | ||||
|   ci-website-mark: | ||||
|     needs: | ||||
|       - lint-lockfile | ||||
|       - lint-prettier | ||||
|       - test | ||||
|       - build | ||||
|  | ||||
							
								
								
									
										43
									
								
								.github/workflows/gen-update-webauthn-mds.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										43
									
								
								.github/workflows/gen-update-webauthn-mds.yml
									
									
									
									
										vendored
									
									
								
							| @ -1,43 +0,0 @@ | ||||
| name: authentik-gen-update-webauthn-mds | ||||
| on: | ||||
|   workflow_dispatch: | ||||
|   schedule: | ||||
|     - cron: '30 1 1,15 * *' | ||||
|  | ||||
| env: | ||||
|   POSTGRES_DB: authentik | ||||
|   POSTGRES_USER: authentik | ||||
|   POSTGRES_PASSWORD: "EK-5jnKfjrGRm<77" | ||||
|  | ||||
| jobs: | ||||
|   build: | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
|       - id: generate_token | ||||
|         uses: tibdex/github-app-token@v2 | ||||
|         with: | ||||
|           app_id: ${{ secrets.GH_APP_ID }} | ||||
|           private_key: ${{ secrets.GH_APP_PRIVATE_KEY }} | ||||
|       - uses: actions/checkout@v4 | ||||
|         with: | ||||
|           token: ${{ steps.generate_token.outputs.token }} | ||||
|       - name: Setup authentik env | ||||
|         uses: ./.github/actions/setup | ||||
|       - run: poetry run ak update_webauthn_mds | ||||
|       - uses: peter-evans/create-pull-request@v6 | ||||
|         id: cpr | ||||
|         with: | ||||
|           token: ${{ steps.generate_token.outputs.token }} | ||||
|           branch: update-fido-mds-client | ||||
|           commit-message: "stages/authenticator_webauthn: Update FIDO MDS3 & Passkey aaguid blobs" | ||||
|           title: "stages/authenticator_webauthn: Update FIDO MDS3 & Passkey aaguid blobs" | ||||
|           body: "stages/authenticator_webauthn: Update FIDO MDS3 & Passkey aaguid blobs" | ||||
|           delete-branch: true | ||||
|           signoff: true | ||||
|           # ID from https://api.github.com/users/authentik-automation[bot] | ||||
|           author: authentik-automation[bot] <135050075+authentik-automation[bot]@users.noreply.github.com> | ||||
|       - uses: peter-evans/enable-pull-request-automerge@v3 | ||||
|         with: | ||||
|           token: ${{ steps.generate_token.outputs.token }} | ||||
|           pull-request-number: ${{ steps.cpr.outputs.pull-request-number }} | ||||
|           merge-method: squash | ||||
							
								
								
									
										8
									
								
								.github/workflows/release-publish.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/release-publish.yml
									
									
									
									
										vendored
									
									
								
							| @ -157,10 +157,10 @@ jobs: | ||||
|         run: | | ||||
|           echo "PG_PASS=$(openssl rand -base64 32)" >> .env | ||||
|           echo "AUTHENTIK_SECRET_KEY=$(openssl rand -base64 32)" >> .env | ||||
|           docker compose pull -q | ||||
|           docker compose up --no-start | ||||
|           docker compose start postgresql redis | ||||
|           docker compose run -u root server test-all | ||||
|           docker-compose pull -q | ||||
|           docker-compose up --no-start | ||||
|           docker-compose start postgresql redis | ||||
|           docker-compose run -u root server test-all | ||||
|   sentry-release: | ||||
|     needs: | ||||
|       - build-server | ||||
|  | ||||
							
								
								
									
										6
									
								
								.github/workflows/release-tag.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								.github/workflows/release-tag.yml
									
									
									
									
										vendored
									
									
								
							| @ -21,9 +21,9 @@ jobs: | ||||
|           docker build -t testing:latest . | ||||
|           echo "AUTHENTIK_IMAGE=testing" >> .env | ||||
|           echo "AUTHENTIK_TAG=latest" >> .env | ||||
|           docker compose up --no-start | ||||
|           docker compose start postgresql redis | ||||
|           docker compose run -u root server test-all | ||||
|           docker-compose up --no-start | ||||
|           docker-compose start postgresql redis | ||||
|           docker-compose run -u root server test-all | ||||
|       - id: generate_token | ||||
|         uses: tibdex/github-app-token@v2 | ||||
|         with: | ||||
|  | ||||
							
								
								
									
										2
									
								
								.github/workflows/repo-stale.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/repo-stale.yml
									
									
									
									
										vendored
									
									
								
							| @ -23,7 +23,7 @@ jobs: | ||||
|           repo-token: ${{ steps.generate_token.outputs.token }} | ||||
|           days-before-stale: 60 | ||||
|           days-before-close: 7 | ||||
|           exempt-issue-labels: pinned,security,pr_wanted,enhancement,bug/confirmed,enhancement/confirmed,question,status/reviewing | ||||
|           exempt-issue-labels: pinned,security,pr_wanted,enhancement,bug/confirmed,enhancement/confirmed,question | ||||
|           stale-issue-label: wontfix | ||||
|           stale-issue-message: > | ||||
|             This issue has been automatically marked as stale because it has not had | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| name: authentik-api-ts-publish | ||||
| name: authentik-web-api-publish | ||||
| on: | ||||
|   push: | ||||
|     branches: [main] | ||||
							
								
								
									
										10
									
								
								Dockerfile
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								Dockerfile
									
									
									
									
									
								
							| @ -38,7 +38,7 @@ COPY ./gen-ts-api /work/web/node_modules/@goauthentik/api | ||||
| RUN npm run build | ||||
|  | ||||
| # Stage 3: Build go proxy | ||||
| FROM --platform=${BUILDPLATFORM} docker.io/golang:1.22.2-bookworm AS go-builder | ||||
| FROM --platform=${BUILDPLATFORM} docker.io/golang:1.22.1-bookworm AS go-builder | ||||
|  | ||||
| ARG TARGETOS | ||||
| ARG TARGETARCH | ||||
| @ -70,10 +70,10 @@ RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \ | ||||
|     GOARM="${TARGETVARIANT#v}" go build -o /go/authentik ./cmd/server | ||||
|  | ||||
| # Stage 4: MaxMind GeoIP | ||||
| FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.0.1 as geoip | ||||
| FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v6.1 as geoip | ||||
|  | ||||
| ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City GeoLite2-ASN" | ||||
| ENV GEOIPUPDATE_VERBOSE="1" | ||||
| ENV GEOIPUPDATE_VERBOSE="true" | ||||
| ENV GEOIPUPDATE_ACCOUNT_ID_FILE="/run/secrets/GEOIPUPDATE_ACCOUNT_ID" | ||||
| ENV GEOIPUPDATE_LICENSE_KEY_FILE="/run/secrets/GEOIPUPDATE_LICENSE_KEY" | ||||
|  | ||||
| @ -84,7 +84,7 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \ | ||||
|     /bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0" | ||||
|  | ||||
| # Stage 5: Python dependencies | ||||
| FROM docker.io/python:3.12.3-slim-bookworm AS python-deps | ||||
| FROM docker.io/python:3.12.2-slim-bookworm AS python-deps | ||||
|  | ||||
| WORKDIR /ak-root/poetry | ||||
|  | ||||
| @ -110,7 +110,7 @@ RUN --mount=type=bind,target=./pyproject.toml,src=./pyproject.toml \ | ||||
|         poetry install --only=main --no-ansi --no-interaction --no-root" | ||||
|  | ||||
| # Stage 6: Run | ||||
| FROM docker.io/python:3.12.3-slim-bookworm AS final-image | ||||
| FROM docker.io/python:3.12.2-slim-bookworm AS final-image | ||||
|  | ||||
| ARG GIT_BUILD_HASH | ||||
| ARG VERSION | ||||
|  | ||||
							
								
								
									
										30
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										30
									
								
								Makefile
									
									
									
									
									
								
							| @ -9,7 +9,6 @@ PY_SOURCES = authentik tests scripts lifecycle .github | ||||
| DOCKER_IMAGE ?= "authentik:test" | ||||
|  | ||||
| GEN_API_TS = "gen-ts-api" | ||||
| GEN_API_PY = "gen-py-api" | ||||
| GEN_API_GO = "gen-go-api" | ||||
|  | ||||
| pg_user := $(shell python -m authentik.lib.config postgresql.user 2>/dev/null) | ||||
| @ -48,10 +47,10 @@ test-go: | ||||
| test-docker:  ## Run all tests in a docker-compose | ||||
| 	echo "PG_PASS=$(openssl rand -base64 32)" >> .env | ||||
| 	echo "AUTHENTIK_SECRET_KEY=$(openssl rand -base64 32)" >> .env | ||||
| 	docker compose pull -q | ||||
| 	docker compose up --no-start | ||||
| 	docker compose start postgresql redis | ||||
| 	docker compose run -u root server test-all | ||||
| 	docker-compose pull -q | ||||
| 	docker-compose up --no-start | ||||
| 	docker-compose start postgresql redis | ||||
| 	docker-compose run -u root server test-all | ||||
| 	rm -f .env | ||||
|  | ||||
| test: ## Run the server tests and produce a coverage report (locally) | ||||
| @ -65,7 +64,7 @@ lint-fix:  ## Lint and automatically fix errors in the python source code. Repor | ||||
| 	codespell -w $(CODESPELL_ARGS) | ||||
|  | ||||
| lint: ## Lint the python and golang sources | ||||
| 	bandit -r $(PY_SOURCES) -x web/node_modules -x tests/wdio/node_modules -x website/node_modules | ||||
| 	bandit -r $(PY_SOURCES) -x node_modules | ||||
| 	golangci-lint run -v | ||||
|  | ||||
| core-install: | ||||
| @ -138,10 +137,7 @@ gen-clean-ts:  ## Remove generated API client for Typescript | ||||
| gen-clean-go:  ## Remove generated API client for Go | ||||
| 	rm -rf ./${GEN_API_GO}/ | ||||
|  | ||||
| gen-clean-py:  ## Remove generated API client for Python | ||||
| 	rm -rf ./${GEN_API_PY}/ | ||||
|  | ||||
| gen-clean: gen-clean-ts gen-clean-go gen-clean-py  ## Remove generated API clients | ||||
| gen-clean: gen-clean-ts gen-clean-go  ## Remove generated API clients | ||||
|  | ||||
| gen-client-ts: gen-clean-ts  ## Build and install the authentik API for Typescript into the authentik UI Application | ||||
| 	docker run \ | ||||
| @ -159,20 +155,6 @@ gen-client-ts: gen-clean-ts  ## Build and install the authentik API for Typescri | ||||
| 	cd ./${GEN_API_TS} && npm i | ||||
| 	\cp -rf ./${GEN_API_TS}/* web/node_modules/@goauthentik/api | ||||
|  | ||||
| gen-client-py: gen-clean-py ## Build and install the authentik API for Python | ||||
| 	docker run \ | ||||
| 		--rm -v ${PWD}:/local \ | ||||
| 		--user ${UID}:${GID} \ | ||||
| 		docker.io/openapitools/openapi-generator-cli:v7.4.0 generate \ | ||||
| 		-i /local/schema.yml \ | ||||
| 		-g python \ | ||||
| 		-o /local/${GEN_API_PY} \ | ||||
| 		-c /local/scripts/api-py-config.yaml \ | ||||
| 		--additional-properties=packageVersion=${NPM_VERSION} \ | ||||
| 		--git-repo-id authentik \ | ||||
| 		--git-user-id goauthentik | ||||
| 	pip install ./${GEN_API_PY} | ||||
|  | ||||
| gen-client-go: gen-clean-go  ## Build and install the authentik API for Golang | ||||
| 	mkdir -p ./${GEN_API_GO} ./${GEN_API_GO}/templates | ||||
| 	wget https://raw.githubusercontent.com/goauthentik/client-go/main/config.yaml -O ./${GEN_API_GO}/config.yaml | ||||
|  | ||||
| @ -26,9 +26,9 @@ For bigger setups, there is a Helm Chart [here](https://github.com/goauthentik/h | ||||
| ## Screenshots | ||||
|  | ||||
| | Light                                                  | Dark                                                  | | ||||
| | ----------------------------------------------------------- | ---------------------------------------------------------- | | ||||
| |   |   | | ||||
| |  |  | | ||||
| | ------------------------------------------------------ | ----------------------------------------------------- | | ||||
| |   |   | | ||||
| |  |  | | ||||
|  | ||||
| ## Development | ||||
|  | ||||
|  | ||||
| @ -19,9 +19,9 @@ Even if the issue is not a CVE, we still greatly appreciate your help in hardeni | ||||
| (.x being the latest patch release for each version) | ||||
|  | ||||
| | Version | Supported | | ||||
| | --------- | --------- | | ||||
| | 2023.10.x | ✅        | | ||||
| | 2024.2.x  | ✅        | | ||||
| | --- | --- | | ||||
| | 2023.6.x | ✅ | | ||||
| | 2023.8.x | ✅ | | ||||
|  | ||||
| ## Reporting a Vulnerability | ||||
|  | ||||
| @ -32,7 +32,7 @@ To report a vulnerability, send an email to [security@goauthentik.io](mailto:se | ||||
| authentik reserves the right to reclassify CVSS as necessary. To determine severity, we will use the CVSS calculator from NVD (https://nvd.nist.gov/vuln-metrics/cvss/v3-calculator). The calculated CVSS score will then be translated into one of the following categories: | ||||
|  | ||||
| | Score | Severity | | ||||
| | ---------- | -------- | | ||||
| | --- | --- | | ||||
| | 0.0 | None | | ||||
| | 0.1 – 3.9 | Low | | ||||
| | 4.0 – 6.9 | Medium | | ||||
|  | ||||
| @ -2,7 +2,7 @@ | ||||
|  | ||||
| from os import environ | ||||
|  | ||||
| __version__ = "2024.4.3" | ||||
| __version__ = "2024.2.2" | ||||
| ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -10,3 +10,26 @@ class AuthentikAPIConfig(AppConfig): | ||||
|     label = "authentik_api" | ||||
|     mountpoint = "api/" | ||||
|     verbose_name = "authentik API" | ||||
|  | ||||
|     def ready(self) -> None: | ||||
|         from drf_spectacular.extensions import OpenApiAuthenticationExtension | ||||
|  | ||||
|         from authentik.api.authentication import TokenAuthentication | ||||
|  | ||||
|         # Class is defined here as it needs to be created early enough that drf-spectacular will | ||||
|         # find it, but also won't cause any import issues | ||||
|  | ||||
|         class TokenSchema(OpenApiAuthenticationExtension): | ||||
|             """Auth schema""" | ||||
|  | ||||
|             target_class = TokenAuthentication | ||||
|             name = "authentik" | ||||
|  | ||||
|             def get_security_definition(self, auto_schema): | ||||
|                 """Auth schema""" | ||||
|                 return { | ||||
|                     "type": "apiKey", | ||||
|                     "in": "header", | ||||
|                     "name": "Authorization", | ||||
|                     "scheme": "bearer", | ||||
|                 } | ||||
|  | ||||
| @ -4,7 +4,6 @@ from hmac import compare_digest | ||||
| from typing import Any | ||||
|  | ||||
| from django.conf import settings | ||||
| from drf_spectacular.extensions import OpenApiAuthenticationExtension | ||||
| from rest_framework.authentication import BaseAuthentication, get_authorization_header | ||||
| from rest_framework.exceptions import AuthenticationFailed | ||||
| from rest_framework.request import Request | ||||
| @ -103,14 +102,3 @@ class TokenAuthentication(BaseAuthentication): | ||||
|             return None | ||||
|  | ||||
|         return (user, None)  # pragma: no cover | ||||
|  | ||||
|  | ||||
| class TokenSchema(OpenApiAuthenticationExtension): | ||||
|     """Auth schema""" | ||||
|  | ||||
|     target_class = TokenAuthentication | ||||
|     name = "authentik" | ||||
|  | ||||
|     def get_security_definition(self, auto_schema): | ||||
|         """Auth schema""" | ||||
|         return {"type": "http", "scheme": "bearer"} | ||||
|  | ||||
| @ -12,7 +12,6 @@ from drf_spectacular.settings import spectacular_settings | ||||
| from drf_spectacular.types import OpenApiTypes | ||||
| from rest_framework.settings import api_settings | ||||
|  | ||||
| from authentik.api.apps import AuthentikAPIConfig | ||||
| from authentik.api.pagination import PAGINATION_COMPONENT_NAME, PAGINATION_SCHEMA | ||||
|  | ||||
|  | ||||
| @ -102,12 +101,3 @@ def postprocess_schema_responses(result, generator: SchemaGenerator, **kwargs): | ||||
|             comp = result["components"]["schemas"][component] | ||||
|             comp["additionalProperties"] = {} | ||||
|     return result | ||||
|  | ||||
|  | ||||
| def preprocess_schema_exclude_non_api(endpoints, **kwargs): | ||||
|     """Filter out all API Views which are not mounted under /api""" | ||||
|     return [ | ||||
|         (path, path_regex, method, callback) | ||||
|         for path, path_regex, method, callback in endpoints | ||||
|         if path.startswith("/" + AuthentikAPIConfig.mountpoint) | ||||
|     ] | ||||
|  | ||||
| @ -8,8 +8,6 @@ from django.apps import AppConfig | ||||
| from django.db import DatabaseError, InternalError, ProgrammingError | ||||
| from structlog.stdlib import BoundLogger, get_logger | ||||
|  | ||||
| from authentik.root.signals import startup | ||||
|  | ||||
|  | ||||
| class ManagedAppConfig(AppConfig): | ||||
|     """Basic reconciliation logic for apps""" | ||||
| @ -25,12 +23,9 @@ class ManagedAppConfig(AppConfig): | ||||
|  | ||||
|     def ready(self) -> None: | ||||
|         self.import_related() | ||||
|         startup.connect(self._on_startup_callback, dispatch_uid=self.label) | ||||
|         return super().ready() | ||||
|  | ||||
|     def _on_startup_callback(self, sender, **_): | ||||
|         self._reconcile_global() | ||||
|         self._reconcile_tenant() | ||||
|         return super().ready() | ||||
|  | ||||
|     def import_related(self): | ||||
|         """Automatically import related modules which rely on just being imported | ||||
|  | ||||
| @ -4,14 +4,12 @@ from json import dumps | ||||
| from typing import Any | ||||
|  | ||||
| from django.core.management.base import BaseCommand, no_translations | ||||
| from django.db.models import Model, fields | ||||
| from drf_jsonschema_serializer.convert import converter, field_to_converter | ||||
| from django.db.models import Model | ||||
| from drf_jsonschema_serializer.convert import field_to_converter | ||||
| from rest_framework.fields import Field, JSONField, UUIDField | ||||
| from rest_framework.relations import PrimaryKeyRelatedField | ||||
| from rest_framework.serializers import Serializer | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik import __version__ | ||||
| from authentik.blueprints.v1.common import BlueprintEntryDesiredState | ||||
| from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT, is_model_allowed | ||||
| from authentik.blueprints.v1.meta.registry import BaseMetaModel, registry | ||||
| @ -20,23 +18,6 @@ from authentik.lib.models import SerializerModel | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| @converter | ||||
| class PrimaryKeyRelatedFieldConverter: | ||||
|     """Custom primary key field converter which is aware of non-integer based PKs | ||||
|  | ||||
|     This is not an exhaustive fix for other non-int PKs, however in authentik we either | ||||
|     use UUIDs or ints""" | ||||
|  | ||||
|     field_class = PrimaryKeyRelatedField | ||||
|  | ||||
|     def convert(self, field: PrimaryKeyRelatedField): | ||||
|         model: Model = field.queryset.model | ||||
|         pk_field = model._meta.pk | ||||
|         if isinstance(pk_field, fields.UUIDField): | ||||
|             return {"type": "string", "format": "uuid"} | ||||
|         return {"type": "integer"} | ||||
|  | ||||
|  | ||||
| class Command(BaseCommand): | ||||
|     """Generate JSON Schema for blueprints""" | ||||
|  | ||||
| @ -48,7 +29,7 @@ class Command(BaseCommand): | ||||
|             "$schema": "http://json-schema.org/draft-07/schema", | ||||
|             "$id": "https://goauthentik.io/blueprints/schema.json", | ||||
|             "type": "object", | ||||
|             "title": f"authentik {__version__} Blueprint schema", | ||||
|             "title": "authentik Blueprint schema", | ||||
|             "required": ["version", "entries"], | ||||
|             "properties": { | ||||
|                 "version": { | ||||
|  | ||||
| @ -39,7 +39,7 @@ def reconcile_app(app_name: str): | ||||
|         def wrapper(*args, **kwargs): | ||||
|             config = apps.get_app_config(app_name) | ||||
|             if isinstance(config, ManagedAppConfig): | ||||
|                 config._on_startup_callback(None) | ||||
|                 config.ready() | ||||
|             return func(*args, **kwargs) | ||||
|  | ||||
|         return wrapper | ||||
|  | ||||
| @ -556,11 +556,7 @@ class BlueprintDumper(SafeDumper): | ||||
|  | ||||
|             def factory(items): | ||||
|                 final_dict = dict(items) | ||||
|                 # Remove internal state variables | ||||
|                 final_dict.pop("_state", None) | ||||
|                 # Future-proof to only remove the ID if we don't set a value | ||||
|                 if "id" in final_dict and final_dict.get("id") is None: | ||||
|                     final_dict.pop("id") | ||||
|                 return final_dict | ||||
|  | ||||
|             data = asdict(data, dict_factory=factory) | ||||
|  | ||||
| @ -19,6 +19,8 @@ from guardian.models import UserObjectPermission | ||||
| from rest_framework.exceptions import ValidationError | ||||
| from rest_framework.serializers import BaseSerializer, Serializer | ||||
| from structlog.stdlib import BoundLogger, get_logger | ||||
| from structlog.testing import capture_logs | ||||
| from structlog.types import EventDict | ||||
| from yaml import load | ||||
|  | ||||
| from authentik.blueprints.v1.common import ( | ||||
| @ -40,7 +42,6 @@ from authentik.core.models import ( | ||||
| from authentik.enterprise.license import LicenseKey | ||||
| from authentik.enterprise.models import LicenseUsage | ||||
| from authentik.enterprise.providers.rac.models import ConnectionToken | ||||
| from authentik.events.logs import LogEvent, capture_logs | ||||
| from authentik.events.models import SystemTask | ||||
| from authentik.events.utils import cleanse_dict | ||||
| from authentik.flows.models import FlowToken, Stage | ||||
| @ -51,8 +52,6 @@ from authentik.policies.models import Policy, PolicyBindingModel | ||||
| from authentik.policies.reputation.models import Reputation | ||||
| from authentik.providers.oauth2.models import AccessToken, AuthorizationCode, RefreshToken | ||||
| from authentik.providers.scim.models import SCIMGroup, SCIMUser | ||||
| from authentik.sources.scim.models import SCIMSourceGroup, SCIMSourceUser | ||||
| from authentik.stages.authenticator_webauthn.models import WebAuthnDeviceType | ||||
| from authentik.tenants.models import Tenant | ||||
|  | ||||
| # Context set when the serializer is created in a blueprint context | ||||
| @ -97,9 +96,6 @@ def excluded_models() -> list[type[Model]]: | ||||
|         AccessToken, | ||||
|         RefreshToken, | ||||
|         Reputation, | ||||
|         WebAuthnDeviceType, | ||||
|         SCIMSourceUser, | ||||
|         SCIMSourceGroup, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| @ -165,7 +161,7 @@ class Importer: | ||||
|  | ||||
|         def updater(value) -> Any: | ||||
|             if value in self.__pk_map: | ||||
|                 self.logger.debug("Updating reference in entry", value=value) | ||||
|                 self.logger.debug("updating reference in entry", value=value) | ||||
|                 return self.__pk_map[value] | ||||
|             return value | ||||
|  | ||||
| @ -254,7 +250,7 @@ class Importer: | ||||
|         model_instance = existing_models.first() | ||||
|         if not isinstance(model(), BaseMetaModel) and model_instance: | ||||
|             self.logger.debug( | ||||
|                 "Initialise serializer with instance", | ||||
|                 "initialise serializer with instance", | ||||
|                 model=model, | ||||
|                 instance=model_instance, | ||||
|                 pk=model_instance.pk, | ||||
| @ -264,14 +260,14 @@ class Importer: | ||||
|         elif model_instance and entry.state == BlueprintEntryDesiredState.MUST_CREATED: | ||||
|             raise EntryInvalidError.from_entry( | ||||
|                 ( | ||||
|                     f"State is set to {BlueprintEntryDesiredState.MUST_CREATED} " | ||||
|                     f"state is set to {BlueprintEntryDesiredState.MUST_CREATED} " | ||||
|                     "and object exists already", | ||||
|                 ), | ||||
|                 entry, | ||||
|             ) | ||||
|         else: | ||||
|             self.logger.debug( | ||||
|                 "Initialised new serializer instance", | ||||
|                 "initialised new serializer instance", | ||||
|                 model=model, | ||||
|                 **cleanse_dict(updated_identifiers), | ||||
|             ) | ||||
| @ -328,7 +324,7 @@ class Importer: | ||||
|                 model: type[SerializerModel] = registry.get_model(model_app_label, model_name) | ||||
|             except LookupError: | ||||
|                 self.logger.warning( | ||||
|                     "App or Model does not exist", app=model_app_label, model=model_name | ||||
|                     "app or model does not exist", app=model_app_label, model=model_name | ||||
|                 ) | ||||
|                 return False | ||||
|             # Validate each single entry | ||||
| @ -340,7 +336,7 @@ class Importer: | ||||
|                 if entry.get_state(self._import) == BlueprintEntryDesiredState.ABSENT: | ||||
|                     serializer = exc.serializer | ||||
|                 else: | ||||
|                     self.logger.warning(f"Entry invalid: {exc}", entry=entry, error=exc) | ||||
|                     self.logger.warning(f"entry invalid: {exc}", entry=entry, error=exc) | ||||
|                     if raise_errors: | ||||
|                         raise exc | ||||
|                     return False | ||||
| @ -360,14 +356,14 @@ class Importer: | ||||
|                     and state == BlueprintEntryDesiredState.CREATED | ||||
|                 ): | ||||
|                     self.logger.debug( | ||||
|                         "Instance exists, skipping", | ||||
|                         "instance exists, skipping", | ||||
|                         model=model, | ||||
|                         instance=instance, | ||||
|                         pk=instance.pk, | ||||
|                     ) | ||||
|                 else: | ||||
|                     instance = serializer.save() | ||||
|                     self.logger.debug("Updated model", model=instance) | ||||
|                     self.logger.debug("updated model", model=instance) | ||||
|                 if "pk" in entry.identifiers: | ||||
|                     self.__pk_map[entry.identifiers["pk"]] = instance.pk | ||||
|                 entry._state = BlueprintEntryState(instance) | ||||
| @ -375,12 +371,12 @@ class Importer: | ||||
|                 instance: Model | None = serializer.instance | ||||
|                 if instance.pk: | ||||
|                     instance.delete() | ||||
|                     self.logger.debug("Deleted model", mode=instance) | ||||
|                     self.logger.debug("deleted model", mode=instance) | ||||
|                     continue | ||||
|                 self.logger.debug("Entry to delete with no instance, skipping") | ||||
|                 self.logger.debug("entry to delete with no instance, skipping") | ||||
|         return True | ||||
|  | ||||
|     def validate(self, raise_validation_errors=False) -> tuple[bool, list[LogEvent]]: | ||||
|     def validate(self, raise_validation_errors=False) -> tuple[bool, list[EventDict]]: | ||||
|         """Validate loaded blueprint export, ensure all models are allowed | ||||
|         and serializers have no errors""" | ||||
|         self.logger.debug("Starting blueprint import validation") | ||||
| @ -394,7 +390,9 @@ class Importer: | ||||
|         ): | ||||
|             successful = self._apply_models(raise_errors=raise_validation_errors) | ||||
|             if not successful: | ||||
|                 self.logger.warning("Blueprint validation failed") | ||||
|                 self.logger.debug("Blueprint validation failed") | ||||
|         for log in logs: | ||||
|             getattr(self.logger, log.get("log_level"))(**log) | ||||
|         self.logger.debug("Finished blueprint import validation") | ||||
|         self._import = orig_import | ||||
|         return successful, logs | ||||
|  | ||||
| @ -30,7 +30,6 @@ from authentik.blueprints.v1.common import BlueprintLoader, BlueprintMetadata, E | ||||
| from authentik.blueprints.v1.importer import Importer | ||||
| from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_INSTANTIATE | ||||
| from authentik.blueprints.v1.oci import OCI_PREFIX | ||||
| from authentik.events.logs import capture_logs | ||||
| from authentik.events.models import TaskStatus | ||||
| from authentik.events.system_tasks import SystemTask, prefill_task | ||||
| from authentik.events.utils import sanitize_dict | ||||
| @ -212,14 +211,13 @@ def apply_blueprint(self: SystemTask, instance_pk: str): | ||||
|         if not valid: | ||||
|             instance.status = BlueprintInstanceStatus.ERROR | ||||
|             instance.save() | ||||
|             self.set_status(TaskStatus.ERROR, *logs) | ||||
|             self.set_status(TaskStatus.ERROR, *[x["event"] for x in logs]) | ||||
|             return | ||||
|         with capture_logs() as logs: | ||||
|         applied = importer.apply() | ||||
|         if not applied: | ||||
|             instance.status = BlueprintInstanceStatus.ERROR | ||||
|             instance.save() | ||||
|                 self.set_status(TaskStatus.ERROR, *logs) | ||||
|             self.set_status(TaskStatus.ERROR, "Failed to apply") | ||||
|             return | ||||
|         instance.status = BlueprintInstanceStatus.SUCCESSFUL | ||||
|         instance.last_applied_hash = file_hash | ||||
|  | ||||
| @ -46,6 +46,7 @@ class BrandSerializer(ModelSerializer): | ||||
|         fields = [ | ||||
|             "brand_uuid", | ||||
|             "domain", | ||||
|             "origin", | ||||
|             "default", | ||||
|             "branding_title", | ||||
|             "branding_logo", | ||||
| @ -56,6 +57,7 @@ class BrandSerializer(ModelSerializer): | ||||
|             "flow_unenrollment", | ||||
|             "flow_user_settings", | ||||
|             "flow_device_code", | ||||
|             "default_application", | ||||
|             "web_certificate", | ||||
|             "attributes", | ||||
|         ] | ||||
|  | ||||
| @ -1,12 +1,17 @@ | ||||
| """Inject brand into current request""" | ||||
|  | ||||
| from collections.abc import Callable | ||||
| from typing import TYPE_CHECKING | ||||
|  | ||||
| from django.http.request import HttpRequest | ||||
| from django.http.response import HttpResponse | ||||
| from django.utils.translation import activate | ||||
|  | ||||
| from authentik.brands.utils import get_brand_for_request | ||||
| from authentik.lib.config import CONFIG | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from authentik.brands.models import Brand | ||||
|  | ||||
|  | ||||
| class BrandMiddleware: | ||||
| @ -25,3 +30,41 @@ class BrandMiddleware: | ||||
|             if locale != "": | ||||
|                 activate(locale) | ||||
|         return self.get_response(request) | ||||
|  | ||||
|  | ||||
| class BrandHeaderMiddleware: | ||||
|     """Add headers from currently active brand""" | ||||
|  | ||||
|     get_response: Callable[[HttpRequest], HttpResponse] | ||||
|     default_csp_elements: dict[str, list[str]] = {} | ||||
|  | ||||
|     def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]): | ||||
|         self.get_response = get_response | ||||
|         self.default_csp_elements = { | ||||
|             "style-src": ["'self'", "'unsafe-inline'"],  # Required due to Lit/ShadowDOM | ||||
|             "script-src": ["'self'", "'unsafe-inline'"],  # Required for generated scripts | ||||
|             "img-src": ["https:", "http:", "data:"], | ||||
|             "default-src": ["'self'"], | ||||
|             "object-src": ["'none'"], | ||||
|             "connect-src": ["'self'"], | ||||
|         } | ||||
|         if CONFIG.get_bool("error_reporting.enabled"): | ||||
|             self.default_csp_elements["connect-src"].append( | ||||
|                 # Required for sentry (TODO: Dynamic) | ||||
|                 "https://authentik.error-reporting.a7k.io" | ||||
|             ) | ||||
|             if CONFIG.get_bool("debug"): | ||||
|                 # Also allow spotlight sidecar connection | ||||
|                 self.default_csp_elements["connect-src"].append("http://localhost:8969") | ||||
|  | ||||
|     def get_csp(self, request: HttpRequest) -> str: | ||||
|         brand: "Brand" = request.brand | ||||
|         elements = self.default_csp_elements.copy() | ||||
|         if brand.origin != "": | ||||
|             elements["frame-ancestors"] = [brand.origin] | ||||
|         return ";".join(f"{attr} {" ".join(value)}" for attr, value in elements.items()) | ||||
|  | ||||
|     def __call__(self, request: HttpRequest) -> HttpResponse: | ||||
|         response = self.get_response(request) | ||||
|         response.headers["Content-Security-Policy"] = self.get_csp(request) | ||||
|         return response | ||||
|  | ||||
| @ -1,21 +0,0 @@ | ||||
| # Generated by Django 5.0.4 on 2024-04-18 18:56 | ||||
|  | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_brands", "0005_tenantuuid_to_branduuid"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.AddIndex( | ||||
|             model_name="brand", | ||||
|             index=models.Index(fields=["domain"], name="authentik_b_domain_b9b24a_idx"), | ||||
|         ), | ||||
|         migrations.AddIndex( | ||||
|             model_name="brand", | ||||
|             index=models.Index(fields=["default"], name="authentik_b_default_3ccf12_idx"), | ||||
|         ), | ||||
|     ] | ||||
| @ -0,0 +1,26 @@ | ||||
| # Generated by Django 5.0.3 on 2024-03-21 15:42 | ||||
|  | ||||
| import django.db.models.deletion | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_brands", "0005_tenantuuid_to_branduuid"), | ||||
|         ("authentik_core", "0033_alter_user_options"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.AddField( | ||||
|             model_name="brand", | ||||
|             name="default_application", | ||||
|             field=models.ForeignKey( | ||||
|                 default=None, | ||||
|                 help_text="When set, external users will be redirected to this application after authenticating.", | ||||
|                 null=True, | ||||
|                 on_delete=django.db.models.deletion.SET_DEFAULT, | ||||
|                 to="authentik_core.application", | ||||
|             ), | ||||
|         ), | ||||
|     ] | ||||
							
								
								
									
										21
									
								
								authentik/brands/migrations/0007_brand_origin.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								authentik/brands/migrations/0007_brand_origin.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,21 @@ | ||||
| # Generated by Django 5.0.3 on 2024-03-26 14:17 | ||||
|  | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_brands", "0006_brand_default_application"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.AddField( | ||||
|             model_name="brand", | ||||
|             name="origin", | ||||
|             field=models.TextField( | ||||
|                 blank=True, | ||||
|                 help_text="Origin domain that activates this brand. Can be left empty to not allow any origins.", | ||||
|             ), | ||||
|         ), | ||||
|     ] | ||||
| @ -23,6 +23,12 @@ class Brand(SerializerModel): | ||||
|             "Domain that activates this brand. Can be a superset, i.e. `a.b` for `aa.b` and `ba.b`" | ||||
|         ) | ||||
|     ) | ||||
|     origin = models.TextField( | ||||
|         help_text=_( | ||||
|             "Origin domain that activates this brand. Can be left empty to not allow any origins." | ||||
|         ), | ||||
|         blank=True, | ||||
|     ) | ||||
|     default = models.BooleanField( | ||||
|         default=False, | ||||
|     ) | ||||
| @ -51,6 +57,16 @@ class Brand(SerializerModel): | ||||
|         Flow, null=True, on_delete=models.SET_NULL, related_name="brand_device_code" | ||||
|     ) | ||||
|  | ||||
|     default_application = models.ForeignKey( | ||||
|         "authentik_core.Application", | ||||
|         null=True, | ||||
|         default=None, | ||||
|         on_delete=models.SET_DEFAULT, | ||||
|         help_text=_( | ||||
|             "When set, external users will be redirected to this application after authenticating." | ||||
|         ), | ||||
|     ) | ||||
|  | ||||
|     web_certificate = models.ForeignKey( | ||||
|         CertificateKeyPair, | ||||
|         null=True, | ||||
| @ -84,7 +100,3 @@ class Brand(SerializerModel): | ||||
|     class Meta: | ||||
|         verbose_name = _("Brand") | ||||
|         verbose_name_plural = _("Brands") | ||||
|         indexes = [ | ||||
|             models.Index(fields=["domain"]), | ||||
|             models.Index(fields=["default"]), | ||||
|         ] | ||||
|  | ||||
| @ -1,11 +1,15 @@ | ||||
| """Brand utilities""" | ||||
|  | ||||
| from typing import Any | ||||
| from urllib.parse import urlparse | ||||
|  | ||||
| from django.db.models import F, Q | ||||
| from django.db.models import Value as V | ||||
| from django.http import HttpResponse | ||||
| from django.http.request import HttpRequest | ||||
| from django.utils.cache import patch_vary_headers | ||||
| from sentry_sdk.hub import Hub | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik import get_full_version | ||||
| from authentik.brands.models import Brand | ||||
| @ -13,13 +17,17 @@ from authentik.tenants.models import Tenant | ||||
|  | ||||
| _q_default = Q(default=True) | ||||
| DEFAULT_BRAND = Brand(domain="fallback") | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| def get_brand_for_request(request: HttpRequest) -> Brand: | ||||
|     """Get brand object for current request""" | ||||
|     query = Q(host_domain__iendswith=F("domain")) | ||||
|     if "Origin" in request.headers: | ||||
|         query &= Q(Q(origin=request.headers.get("Origin", "")) | Q(origin="")) | ||||
|     db_brands = ( | ||||
|         Brand.objects.annotate(host_domain=V(request.get_host())) | ||||
|         .filter(Q(host_domain__iendswith=F("domain")) | _q_default) | ||||
|         .filter(Q(query) | _q_default) | ||||
|         .order_by("default") | ||||
|     ) | ||||
|     brands = list(db_brands.all()) | ||||
| @ -42,3 +50,46 @@ def context_processor(request: HttpRequest) -> dict[str, Any]: | ||||
|         "sentry_trace": trace, | ||||
|         "version": get_full_version(), | ||||
|     } | ||||
|  | ||||
|  | ||||
| def cors_allow(request: HttpRequest, response: HttpResponse, *allowed_origins: str): | ||||
|     """Add headers to permit CORS requests from allowed_origins, with or without credentials, | ||||
|     with any headers.""" | ||||
|     origin = request.META.get("HTTP_ORIGIN") | ||||
|     if not origin: | ||||
|         return response | ||||
|  | ||||
|     # OPTIONS requests don't have an authorization header -> hence | ||||
|     # we can't extract the provider this request is for | ||||
|     # so for options requests we allow the calling origin without checking | ||||
|     allowed = request.method == "OPTIONS" | ||||
|     received_origin = urlparse(origin) | ||||
|     for allowed_origin in allowed_origins: | ||||
|         url = urlparse(allowed_origin) | ||||
|         if ( | ||||
|             received_origin.scheme == url.scheme | ||||
|             and received_origin.hostname == url.hostname | ||||
|             and received_origin.port == url.port | ||||
|         ): | ||||
|             allowed = True | ||||
|     if not allowed: | ||||
|         LOGGER.warning( | ||||
|             "CORS: Origin is not an allowed origin", | ||||
|             requested=received_origin, | ||||
|             allowed=allowed_origins, | ||||
|         ) | ||||
|         return response | ||||
|  | ||||
|     # From the CORS spec: The string "*" cannot be used for a resource that supports credentials. | ||||
|     response["Access-Control-Allow-Origin"] = origin | ||||
|     patch_vary_headers(response, ["Origin"]) | ||||
|     response["Access-Control-Allow-Credentials"] = "true" | ||||
|  | ||||
|     if request.method == "OPTIONS": | ||||
|         if "HTTP_ACCESS_CONTROL_REQUEST_HEADERS" in request.META: | ||||
|             response["Access-Control-Allow-Headers"] = request.META[ | ||||
|                 "HTTP_ACCESS_CONTROL_REQUEST_HEADERS" | ||||
|             ] | ||||
|         response["Access-Control-Allow-Methods"] = "GET, POST, OPTIONS" | ||||
|  | ||||
|     return response | ||||
|  | ||||
| @ -20,15 +20,15 @@ from rest_framework.response import Response | ||||
| from rest_framework.serializers import ModelSerializer | ||||
| from rest_framework.viewsets import ModelViewSet | ||||
| from structlog.stdlib import get_logger | ||||
| from structlog.testing import capture_logs | ||||
|  | ||||
| from authentik.admin.api.metrics import CoordinateSerializer | ||||
| from authentik.api.pagination import Pagination | ||||
| from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT | ||||
| from authentik.core.api.providers import ProviderSerializer | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.core.models import Application, User | ||||
| from authentik.events.logs import LogEventSerializer, capture_logs | ||||
| from authentik.events.models import EventAction | ||||
| from authentik.events.utils import sanitize_dict | ||||
| from authentik.lib.utils.file import ( | ||||
|     FilePathSerializer, | ||||
|     FileUploadSerializer, | ||||
| @ -44,12 +44,9 @@ from authentik.rbac.filters import ObjectFilter | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| def user_app_cache_key(user_pk: str, page_number: int | None = None) -> str: | ||||
| def user_app_cache_key(user_pk: str) -> str: | ||||
|     """Cache key where application list for user is saved""" | ||||
|     key = f"{CACHE_PREFIX}/app_access/{user_pk}" | ||||
|     if page_number: | ||||
|         key += f"/{page_number}" | ||||
|     return key | ||||
|     return f"{CACHE_PREFIX}/app_access/{user_pk}" | ||||
|  | ||||
|  | ||||
| class ApplicationSerializer(ModelSerializer): | ||||
| @ -185,9 +182,9 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet): | ||||
|         if request.user.is_superuser: | ||||
|             log_messages = [] | ||||
|             for log in logs: | ||||
|                 if log.attributes.get("process", "") == "PolicyProcess": | ||||
|                 if log.get("process", "") == "PolicyProcess": | ||||
|                     continue | ||||
|                 log_messages.append(LogEventSerializer(log).data) | ||||
|                 log_messages.append(sanitize_dict(log)) | ||||
|             result.log_messages = log_messages | ||||
|             response = PolicyTestResultSerializer(result) | ||||
|         return Response(response.data) | ||||
| @ -217,8 +214,7 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet): | ||||
|             return super().list(request) | ||||
|  | ||||
|         queryset = self._filter_queryset_for_list(self.get_queryset()) | ||||
|         paginator: Pagination = self.paginator | ||||
|         paginated_apps = paginator.paginate_queryset(queryset, request) | ||||
|         paginated_apps = self.paginate_queryset(queryset) | ||||
|  | ||||
|         if "for_user" in request.query_params: | ||||
|             try: | ||||
| @ -240,14 +236,12 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet): | ||||
|         if not should_cache: | ||||
|             allowed_applications = self._get_allowed_applications(paginated_apps) | ||||
|         if should_cache: | ||||
|             allowed_applications = cache.get( | ||||
|                 user_app_cache_key(self.request.user.pk, paginator.page.number) | ||||
|             ) | ||||
|             allowed_applications = cache.get(user_app_cache_key(self.request.user.pk)) | ||||
|             if not allowed_applications: | ||||
|                 LOGGER.debug("Caching allowed application list", page=paginator.page.number) | ||||
|                 LOGGER.debug("Caching allowed application list") | ||||
|                 allowed_applications = self._get_allowed_applications(paginated_apps) | ||||
|                 cache.set( | ||||
|                     user_app_cache_key(self.request.user.pk, paginator.page.number), | ||||
|                     user_app_cache_key(self.request.user.pk), | ||||
|                     allowed_applications, | ||||
|                     timeout=86400, | ||||
|                 ) | ||||
|  | ||||
| @ -5,15 +5,10 @@ from json import loads | ||||
| from django.http import Http404 | ||||
| from django_filters.filters import CharFilter, ModelMultipleChoiceFilter | ||||
| from django_filters.filterset import FilterSet | ||||
| from drf_spectacular.utils import ( | ||||
|     OpenApiParameter, | ||||
|     OpenApiResponse, | ||||
|     extend_schema, | ||||
|     extend_schema_field, | ||||
| ) | ||||
| from drf_spectacular.utils import OpenApiResponse, extend_schema | ||||
| from guardian.shortcuts import get_objects_for_user | ||||
| from rest_framework.decorators import action | ||||
| from rest_framework.fields import CharField, IntegerField, SerializerMethodField | ||||
| from rest_framework.fields import CharField, IntegerField | ||||
| from rest_framework.request import Request | ||||
| from rest_framework.response import Response | ||||
| from rest_framework.serializers import ListSerializer, ModelSerializer, ValidationError | ||||
| @ -50,7 +45,9 @@ class GroupSerializer(ModelSerializer): | ||||
|     """Group Serializer""" | ||||
|  | ||||
|     attributes = JSONDictField(required=False) | ||||
|     users_obj = SerializerMethodField(allow_null=True) | ||||
|     users_obj = ListSerializer( | ||||
|         child=GroupMemberSerializer(), read_only=True, source="users", required=False | ||||
|     ) | ||||
|     roles_obj = ListSerializer( | ||||
|         child=RoleSerializer(), | ||||
|         read_only=True, | ||||
| @ -61,19 +58,6 @@ class GroupSerializer(ModelSerializer): | ||||
|  | ||||
|     num_pk = IntegerField(read_only=True) | ||||
|  | ||||
|     @property | ||||
|     def _should_include_users(self) -> bool: | ||||
|         request: Request = self.context.get("request", None) | ||||
|         if not request: | ||||
|             return True | ||||
|         return str(request.query_params.get("include_users", "true")).lower() == "true" | ||||
|  | ||||
|     @extend_schema_field(GroupMemberSerializer(many=True)) | ||||
|     def get_users_obj(self, instance: Group) -> list[GroupMemberSerializer] | None: | ||||
|         if not self._should_include_users: | ||||
|             return None | ||||
|         return GroupMemberSerializer(instance.users, many=True).data | ||||
|  | ||||
|     def validate_parent(self, parent: Group | None): | ||||
|         """Validate group parent (if set), ensuring the parent isn't itself""" | ||||
|         if not self.instance or not parent: | ||||
| @ -146,35 +130,22 @@ class GroupFilter(FilterSet): | ||||
|         fields = ["name", "is_superuser", "members_by_pk", "attributes", "members_by_username"] | ||||
|  | ||||
|  | ||||
| class GroupViewSet(UsedByMixin, ModelViewSet): | ||||
|     """Group Viewset""" | ||||
|  | ||||
| class UserAccountSerializer(PassiveSerializer): | ||||
|     """Account adding/removing operations""" | ||||
|  | ||||
|     pk = IntegerField(required=True) | ||||
|  | ||||
|     queryset = Group.objects.none() | ||||
|  | ||||
| class GroupViewSet(UsedByMixin, ModelViewSet): | ||||
|     """Group Viewset""" | ||||
|  | ||||
|     queryset = Group.objects.all().select_related("parent").prefetch_related("users") | ||||
|     serializer_class = GroupSerializer | ||||
|     search_fields = ["name", "is_superuser"] | ||||
|     filterset_class = GroupFilter | ||||
|     ordering = ["name"] | ||||
|  | ||||
|     def get_queryset(self): | ||||
|         base_qs = Group.objects.all().select_related("parent").prefetch_related("roles") | ||||
|         if self.serializer_class(context={"request": self.request})._should_include_users: | ||||
|             base_qs = base_qs.prefetch_related("users") | ||||
|         return base_qs | ||||
|  | ||||
|     @extend_schema( | ||||
|         parameters=[ | ||||
|             OpenApiParameter("include_users", bool, default=True), | ||||
|         ] | ||||
|     ) | ||||
|     def list(self, request, *args, **kwargs): | ||||
|         return super().list(request, *args, **kwargs) | ||||
|  | ||||
|     @permission_required("authentik_core.add_user_to_group") | ||||
|     @permission_required(None, ["authentik_core.add_user"]) | ||||
|     @extend_schema( | ||||
|         request=UserAccountSerializer, | ||||
|         responses={ | ||||
| @ -182,13 +153,7 @@ class GroupViewSet(UsedByMixin, ModelViewSet): | ||||
|             404: OpenApiResponse(description="User not found"), | ||||
|         }, | ||||
|     ) | ||||
|     @action( | ||||
|         detail=True, | ||||
|         methods=["POST"], | ||||
|         pagination_class=None, | ||||
|         filter_backends=[], | ||||
|         permission_classes=[], | ||||
|     ) | ||||
|     @action(detail=True, methods=["POST"], pagination_class=None, filter_backends=[]) | ||||
|     def add_user(self, request: Request, pk: str) -> Response: | ||||
|         """Add user to group""" | ||||
|         group: Group = self.get_object() | ||||
| @ -204,7 +169,7 @@ class GroupViewSet(UsedByMixin, ModelViewSet): | ||||
|         group.users.add(user) | ||||
|         return Response(status=204) | ||||
|  | ||||
|     @permission_required("authentik_core.remove_user_from_group") | ||||
|     @permission_required(None, ["authentik_core.add_user"]) | ||||
|     @extend_schema( | ||||
|         request=UserAccountSerializer, | ||||
|         responses={ | ||||
| @ -212,13 +177,7 @@ class GroupViewSet(UsedByMixin, ModelViewSet): | ||||
|             404: OpenApiResponse(description="User not found"), | ||||
|         }, | ||||
|     ) | ||||
|     @action( | ||||
|         detail=True, | ||||
|         methods=["POST"], | ||||
|         pagination_class=None, | ||||
|         filter_backends=[], | ||||
|         permission_classes=[], | ||||
|     ) | ||||
|     @action(detail=True, methods=["POST"], pagination_class=None, filter_backends=[]) | ||||
|     def remove_user(self, request: Request, pk: str) -> Response: | ||||
|         """Add user to group""" | ||||
|         group: Group = self.get_object() | ||||
|  | ||||
| @ -2,7 +2,6 @@ | ||||
|  | ||||
| from typing import Any | ||||
|  | ||||
| from django.utils.timezone import now | ||||
| from django_filters.rest_framework import DjangoFilterBackend | ||||
| from drf_spectacular.utils import OpenApiResponse, extend_schema, inline_serializer | ||||
| from guardian.shortcuts import assign_perm, get_anonymous_user | ||||
| @ -21,17 +20,9 @@ from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.core.api.users import UserSerializer | ||||
| from authentik.core.api.utils import PassiveSerializer | ||||
| from authentik.core.models import ( | ||||
|     USER_ATTRIBUTE_TOKEN_EXPIRING, | ||||
|     USER_ATTRIBUTE_TOKEN_MAXIMUM_LIFETIME, | ||||
|     Token, | ||||
|     TokenIntents, | ||||
|     User, | ||||
|     default_token_duration, | ||||
| ) | ||||
| from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents | ||||
| from authentik.events.models import Event, EventAction | ||||
| from authentik.events.utils import model_to_dict | ||||
| from authentik.lib.utils.time import timedelta_from_string | ||||
| from authentik.rbac.decorators import permission_required | ||||
|  | ||||
|  | ||||
| @ -45,13 +36,6 @@ class TokenSerializer(ManagedSerializer, ModelSerializer): | ||||
|         if SERIALIZER_CONTEXT_BLUEPRINT in self.context: | ||||
|             self.fields["key"] = CharField(required=False) | ||||
|  | ||||
|     def validate_user(self, user: User): | ||||
|         """Ensure user of token cannot be changed""" | ||||
|         if self.instance and self.instance.user_id: | ||||
|             if user.pk != self.instance.user_id: | ||||
|                 raise ValidationError("User cannot be changed") | ||||
|         return user | ||||
|  | ||||
|     def validate(self, attrs: dict[Any, str]) -> dict[Any, str]: | ||||
|         """Ensure only API or App password tokens are created.""" | ||||
|         request: Request = self.context.get("request") | ||||
| @ -65,32 +49,6 @@ class TokenSerializer(ManagedSerializer, ModelSerializer): | ||||
|         attrs.setdefault("intent", TokenIntents.INTENT_API) | ||||
|         if attrs.get("intent") not in [TokenIntents.INTENT_API, TokenIntents.INTENT_APP_PASSWORD]: | ||||
|             raise ValidationError({"intent": f"Invalid intent {attrs.get('intent')}"}) | ||||
|  | ||||
|         if attrs.get("intent") == TokenIntents.INTENT_APP_PASSWORD: | ||||
|             # user IS in attrs | ||||
|             user: User = attrs.get("user") | ||||
|             max_token_lifetime = user.group_attributes(request).get( | ||||
|                 USER_ATTRIBUTE_TOKEN_MAXIMUM_LIFETIME, | ||||
|             ) | ||||
|             max_token_lifetime_dt = default_token_duration() | ||||
|             if max_token_lifetime is not None: | ||||
|                 try: | ||||
|                     max_token_lifetime_dt = now() + timedelta_from_string(max_token_lifetime) | ||||
|                 except ValueError: | ||||
|                     pass | ||||
|  | ||||
|             if "expires" in attrs and attrs.get("expires") > max_token_lifetime_dt: | ||||
|                 raise ValidationError( | ||||
|                     { | ||||
|                         "expires": ( | ||||
|                             f"Token expires exceeds maximum lifetime ({max_token_lifetime_dt} UTC)." | ||||
|                         ) | ||||
|                     } | ||||
|                 ) | ||||
|         elif attrs.get("intent") == TokenIntents.INTENT_API: | ||||
|             # For API tokens, expires cannot be overridden | ||||
|             attrs["expires"] = default_token_duration() | ||||
|  | ||||
|         return attrs | ||||
|  | ||||
|     class Meta: | ||||
|  | ||||
| @ -85,7 +85,7 @@ class UserGroupSerializer(ModelSerializer): | ||||
|     """Simplified Group Serializer for user's groups""" | ||||
|  | ||||
|     attributes = JSONDictField(required=False) | ||||
|     parent_name = CharField(source="parent.name", read_only=True, allow_null=True) | ||||
|     parent_name = CharField(source="parent.name", read_only=True) | ||||
|  | ||||
|     class Meta: | ||||
|         model = Group | ||||
| @ -113,26 +113,13 @@ class UserSerializer(ModelSerializer): | ||||
|         queryset=Group.objects.all().order_by("name"), | ||||
|         default=list, | ||||
|     ) | ||||
|     groups_obj = SerializerMethodField(allow_null=True) | ||||
|     groups_obj = ListSerializer(child=UserGroupSerializer(), read_only=True, source="ak_groups") | ||||
|     uid = CharField(read_only=True) | ||||
|     username = CharField( | ||||
|         max_length=150, | ||||
|         validators=[UniqueValidator(queryset=User.objects.all().order_by("username"))], | ||||
|     ) | ||||
|  | ||||
|     @property | ||||
|     def _should_include_groups(self) -> bool: | ||||
|         request: Request = self.context.get("request", None) | ||||
|         if not request: | ||||
|             return True | ||||
|         return str(request.query_params.get("include_groups", "true")).lower() == "true" | ||||
|  | ||||
|     @extend_schema_field(UserGroupSerializer(many=True)) | ||||
|     def get_groups_obj(self, instance: User) -> list[UserGroupSerializer] | None: | ||||
|         if not self._should_include_groups: | ||||
|             return None | ||||
|         return UserGroupSerializer(instance.ak_groups, many=True).data | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         if SERIALIZER_CONTEXT_BLUEPRINT in self.context: | ||||
| @ -407,19 +394,8 @@ class UserViewSet(UsedByMixin, ModelViewSet): | ||||
|     search_fields = ["username", "name", "is_active", "email", "uuid"] | ||||
|     filterset_class = UsersFilter | ||||
|  | ||||
|     def get_queryset(self): | ||||
|         base_qs = User.objects.all().exclude_anonymous() | ||||
|         if self.serializer_class(context={"request": self.request})._should_include_groups: | ||||
|             base_qs = base_qs.prefetch_related("ak_groups") | ||||
|         return base_qs | ||||
|  | ||||
|     @extend_schema( | ||||
|         parameters=[ | ||||
|             OpenApiParameter("include_groups", bool, default=True), | ||||
|         ] | ||||
|     ) | ||||
|     def list(self, request, *args, **kwargs): | ||||
|         return super().list(request, *args, **kwargs) | ||||
|     def get_queryset(self):  # pragma: no cover | ||||
|         return User.objects.all().exclude_anonymous().prefetch_related("ak_groups") | ||||
|  | ||||
|     def _create_recovery_link(self) -> tuple[str, Token]: | ||||
|         """Create a recovery link (when the current brand has a recovery flow set), | ||||
|  | ||||
| @ -1,34 +1,10 @@ | ||||
| """custom runserver command""" | ||||
|  | ||||
| from typing import TextIO | ||||
|  | ||||
| from daphne.management.commands.runserver import Command as RunServer | ||||
| from daphne.server import Server | ||||
|  | ||||
| from authentik.root.signals import post_startup, pre_startup, startup | ||||
|  | ||||
|  | ||||
| class SignalServer(Server): | ||||
|     """Server which signals back to authentik when it finished starting up""" | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|  | ||||
|         def ready_callable(): | ||||
|             pre_startup.send(sender=self) | ||||
|             startup.send(sender=self) | ||||
|             post_startup.send(sender=self) | ||||
|  | ||||
|         self.ready_callable = ready_callable | ||||
|  | ||||
|  | ||||
| class Command(RunServer): | ||||
|     """custom runserver command, which doesn't show the misleading django startup message""" | ||||
|  | ||||
|     server_cls = SignalServer | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         # Redirect standard stdout banner from Daphne into the void | ||||
|         # as there are a couple more steps that happen before startup is fully done | ||||
|         self.stdout = TextIO() | ||||
|     def on_bind(self, server_port): | ||||
|         pass | ||||
|  | ||||
| @ -5,7 +5,6 @@ from django.db import migrations, models | ||||
| from django.db.backends.base.schema import BaseDatabaseSchemaEditor | ||||
|  | ||||
| import authentik.core.models | ||||
| from authentik.lib.generators import generate_id | ||||
|  | ||||
|  | ||||
| def set_default_token_key(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||
| @ -17,10 +16,6 @@ def set_default_token_key(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||
|         token.save() | ||||
|  | ||||
|  | ||||
| def default_token_key(): | ||||
|     return generate_id(60) | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|     replaces = [ | ||||
|         ("authentik_core", "0012_auto_20201003_1737"), | ||||
| @ -67,7 +62,7 @@ class Migration(migrations.Migration): | ||||
|         migrations.AddField( | ||||
|             model_name="token", | ||||
|             name="key", | ||||
|             field=models.TextField(default=default_token_key), | ||||
|             field=models.TextField(default=authentik.core.models.default_token_key), | ||||
|         ), | ||||
|         migrations.AlterUniqueTogether( | ||||
|             name="token", | ||||
|  | ||||
| @ -1,31 +0,0 @@ | ||||
| # Generated by Django 5.0.2 on 2024-02-29 10:15 | ||||
|  | ||||
| from django.db import migrations, models | ||||
|  | ||||
| import authentik.core.models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_core", "0033_alter_user_options"), | ||||
|         ("authentik_tenants", "0002_tenant_default_token_duration_and_more"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.AlterField( | ||||
|             model_name="authenticatedsession", | ||||
|             name="expires", | ||||
|             field=models.DateTimeField(default=None, null=True), | ||||
|         ), | ||||
|         migrations.AlterField( | ||||
|             model_name="token", | ||||
|             name="expires", | ||||
|             field=models.DateTimeField(default=None, null=True), | ||||
|         ), | ||||
|         migrations.AlterField( | ||||
|             model_name="token", | ||||
|             name="key", | ||||
|             field=models.TextField(default=authentik.core.models.default_token_key), | ||||
|         ), | ||||
|     ] | ||||
| @ -1,52 +0,0 @@ | ||||
| # Generated by Django 5.0.4 on 2024-04-15 11:28 | ||||
|  | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("auth", "0012_alter_user_first_name_max_length"), | ||||
|         ("authentik_core", "0034_alter_authenticatedsession_expires_and_more"), | ||||
|         ("authentik_rbac", "0003_alter_systempermission_options"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.AlterModelOptions( | ||||
|             name="group", | ||||
|             options={ | ||||
|                 "permissions": [ | ||||
|                     ("add_user_to_group", "Add user to group"), | ||||
|                     ("remove_user_from_group", "Remove user from group"), | ||||
|                 ], | ||||
|                 "verbose_name": "Group", | ||||
|                 "verbose_name_plural": "Groups", | ||||
|             }, | ||||
|         ), | ||||
|         migrations.AddIndex( | ||||
|             model_name="group", | ||||
|             index=models.Index(fields=["name"], name="authentik_c_name_9ba8e4_idx"), | ||||
|         ), | ||||
|         migrations.AddIndex( | ||||
|             model_name="user", | ||||
|             index=models.Index(fields=["last_login"], name="authentik_c_last_lo_f0179a_idx"), | ||||
|         ), | ||||
|         migrations.AddIndex( | ||||
|             model_name="user", | ||||
|             index=models.Index( | ||||
|                 fields=["password_change_date"], name="authentik_c_passwor_eec915_idx" | ||||
|             ), | ||||
|         ), | ||||
|         migrations.AddIndex( | ||||
|             model_name="user", | ||||
|             index=models.Index(fields=["uuid"], name="authentik_c_uuid_3dae2f_idx"), | ||||
|         ), | ||||
|         migrations.AddIndex( | ||||
|             model_name="user", | ||||
|             index=models.Index(fields=["path"], name="authentik_c_path_b1f502_idx"), | ||||
|         ), | ||||
|         migrations.AddIndex( | ||||
|             model_name="user", | ||||
|             index=models.Index(fields=["type"], name="authentik_c_type_ecf60d_idx"), | ||||
|         ), | ||||
|     ] | ||||
| @ -1,6 +1,6 @@ | ||||
| """authentik core models""" | ||||
|  | ||||
| from datetime import datetime | ||||
| from datetime import timedelta | ||||
| from hashlib import sha256 | ||||
| from typing import Any, Optional, Self | ||||
| from uuid import uuid4 | ||||
| @ -25,16 +25,15 @@ from authentik.blueprints.models import ManagedModel | ||||
| from authentik.core.exceptions import PropertyMappingExpressionException | ||||
| from authentik.core.types import UILoginButton, UserSettingSerializer | ||||
| from authentik.lib.avatars import get_avatar | ||||
| from authentik.lib.config import CONFIG | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.lib.models import ( | ||||
|     CreatedUpdatedModel, | ||||
|     DomainlessFormattedURLValidator, | ||||
|     SerializerModel, | ||||
| ) | ||||
| from authentik.lib.utils.time import timedelta_from_string | ||||
| from authentik.policies.models import PolicyBindingModel | ||||
| from authentik.tenants.models import DEFAULT_TOKEN_DURATION, DEFAULT_TOKEN_LENGTH | ||||
| from authentik.tenants.utils import get_current_tenant, get_unique_identifier | ||||
| from authentik.tenants.utils import get_unique_identifier | ||||
|  | ||||
| LOGGER = get_logger() | ||||
| USER_ATTRIBUTE_DEBUG = "goauthentik.io/user/debug" | ||||
| @ -43,42 +42,33 @@ USER_ATTRIBUTE_EXPIRES = "goauthentik.io/user/expires" | ||||
| USER_ATTRIBUTE_DELETE_ON_LOGOUT = "goauthentik.io/user/delete-on-logout" | ||||
| USER_ATTRIBUTE_SOURCES = "goauthentik.io/user/sources" | ||||
| USER_ATTRIBUTE_TOKEN_EXPIRING = "goauthentik.io/user/token-expires"  # nosec | ||||
| USER_ATTRIBUTE_TOKEN_MAXIMUM_LIFETIME = "goauthentik.io/user/token-maximum-lifetime"  # nosec | ||||
| USER_ATTRIBUTE_CHANGE_USERNAME = "goauthentik.io/user/can-change-username" | ||||
| USER_ATTRIBUTE_CHANGE_NAME = "goauthentik.io/user/can-change-name" | ||||
| USER_ATTRIBUTE_CHANGE_EMAIL = "goauthentik.io/user/can-change-email" | ||||
| USER_PATH_SYSTEM_PREFIX = "goauthentik.io" | ||||
| USER_PATH_SERVICE_ACCOUNT = USER_PATH_SYSTEM_PREFIX + "/service-accounts" | ||||
|  | ||||
|  | ||||
| options.DEFAULT_NAMES = options.DEFAULT_NAMES + ( | ||||
|     # used_by API that allows models to specify if they shadow an object | ||||
|     # for example the proxy provider which is built on top of an oauth provider | ||||
|     "authentik_used_by_shadows", | ||||
|     # List fields for which changes are not logged (due to them having dedicated objects) | ||||
|     # for example user's password and last_login | ||||
|     "authentik_signals_ignored_fields", | ||||
| ) | ||||
|  | ||||
|  | ||||
| def default_token_duration() -> datetime: | ||||
| def default_token_duration(): | ||||
|     """Default duration a Token is valid""" | ||||
|     current_tenant = get_current_tenant() | ||||
|     token_duration = ( | ||||
|         current_tenant.default_token_duration | ||||
|         if hasattr(current_tenant, "default_token_duration") | ||||
|         else DEFAULT_TOKEN_DURATION | ||||
|     ) | ||||
|     return now() + timedelta_from_string(token_duration) | ||||
|     return now() + timedelta(minutes=30) | ||||
|  | ||||
|  | ||||
| def default_token_key() -> str: | ||||
| def default_token_key(): | ||||
|     """Default token key""" | ||||
|     current_tenant = get_current_tenant() | ||||
|     token_length = ( | ||||
|         current_tenant.default_token_length | ||||
|         if hasattr(current_tenant, "default_token_length") | ||||
|         else DEFAULT_TOKEN_LENGTH | ||||
|     ) | ||||
|     # We use generate_id since the chars in the key should be easy | ||||
|     # to use in Emails (for verification) and URLs (for recovery) | ||||
|     return generate_id(token_length) | ||||
|     return generate_id(CONFIG.get_int("default_token_length")) | ||||
|  | ||||
|  | ||||
| class UserTypes(models.TextChoices): | ||||
| @ -177,13 +167,8 @@ class Group(SerializerModel): | ||||
|                 "parent", | ||||
|             ), | ||||
|         ) | ||||
|         indexes = [models.Index(fields=["name"])] | ||||
|         verbose_name = _("Group") | ||||
|         verbose_name_plural = _("Groups") | ||||
|         permissions = [ | ||||
|             ("add_user_to_group", _("Add user to group")), | ||||
|             ("remove_user_from_group", _("Remove user from group")), | ||||
|         ] | ||||
|  | ||||
|  | ||||
| class UserQuerySet(models.QuerySet): | ||||
| @ -320,12 +305,13 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser): | ||||
|             ("preview_user", _("Can preview user data sent to providers")), | ||||
|             ("view_user_applications", _("View applications the user has access to")), | ||||
|         ] | ||||
|         indexes = [ | ||||
|             models.Index(fields=["last_login"]), | ||||
|             models.Index(fields=["password_change_date"]), | ||||
|             models.Index(fields=["uuid"]), | ||||
|             models.Index(fields=["path"]), | ||||
|             models.Index(fields=["type"]), | ||||
|         authentik_signals_ignored_fields = [ | ||||
|             # Logged by the events `password_set` | ||||
|             # the `password_set` action/signal doesn't currently convey which user | ||||
|             # initiated the password change, so for now we'll log two actions | ||||
|             # ("password", "password_change_date"), | ||||
|             # Logged by `login` | ||||
|             ("last_login",), | ||||
|         ] | ||||
|  | ||||
|  | ||||
| @ -632,7 +618,7 @@ class UserSourceConnection(SerializerModel, CreatedUpdatedModel): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __str__(self) -> str: | ||||
|         return f"User-source connection (user={self.user_id}, source={self.source_id})" | ||||
|         return f"User-source connection (user={self.user.username}, source={self.source.slug})" | ||||
|  | ||||
|     class Meta: | ||||
|         unique_together = (("user", "source"),) | ||||
| @ -641,7 +627,7 @@ class UserSourceConnection(SerializerModel, CreatedUpdatedModel): | ||||
| class ExpiringModel(models.Model): | ||||
|     """Base Model which can expire, and is automatically cleaned up.""" | ||||
|  | ||||
|     expires = models.DateTimeField(default=None, null=True) | ||||
|     expires = models.DateTimeField(default=default_token_duration) | ||||
|     expiring = models.BooleanField(default=True) | ||||
|  | ||||
|     class Meta: | ||||
| @ -655,7 +641,7 @@ class ExpiringModel(models.Model): | ||||
|         return self.delete(*args, **kwargs) | ||||
|  | ||||
|     @classmethod | ||||
|     def filter_not_expired(cls, **kwargs) -> QuerySet["Token"]: | ||||
|     def filter_not_expired(cls, **kwargs) -> QuerySet: | ||||
|         """Filer for tokens which are not expired yet or are not expiring, | ||||
|         and match filters in `kwargs`""" | ||||
|         for obj in cls.objects.filter(**kwargs).filter(Q(expires__lt=now(), expiring=True)): | ||||
|  | ||||
| @ -10,14 +10,7 @@ from django.dispatch import receiver | ||||
| from django.http.request import HttpRequest | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.core.models import ( | ||||
|     Application, | ||||
|     AuthenticatedSession, | ||||
|     BackchannelProvider, | ||||
|     ExpiringModel, | ||||
|     User, | ||||
|     default_token_duration, | ||||
| ) | ||||
| from authentik.core.models import Application, AuthenticatedSession, BackchannelProvider, User | ||||
|  | ||||
| # Arguments: user: User, password: str | ||||
| password_changed = Signal() | ||||
| @ -68,12 +61,3 @@ def backchannel_provider_pre_save(sender: type[Model], instance: Model, **_): | ||||
|     if not isinstance(instance, BackchannelProvider): | ||||
|         return | ||||
|     instance.is_backchannel = True | ||||
|  | ||||
|  | ||||
| @receiver(pre_save) | ||||
| def expiring_model_pre_save(sender: type[Model], instance: Model, **_): | ||||
|     """Ensure expires is set on ExpiringModels that are set to expire""" | ||||
|     if not issubclass(sender, ExpiringModel): | ||||
|         return | ||||
|     if instance.expiring and instance.expires is None: | ||||
|         instance.expires = default_token_duration() | ||||
|  | ||||
| @ -13,7 +13,7 @@ from django.utils.translation import gettext as _ | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.core.models import Source, SourceUserMatchingModes, User, UserSourceConnection | ||||
| from authentik.core.sources.stage import PLAN_CONTEXT_SOURCES_CONNECTION, PostSourceStage | ||||
| from authentik.core.sources.stage import PLAN_CONTEXT_SOURCES_CONNECTION, PostUserEnrollmentStage | ||||
| from authentik.events.models import Event, EventAction | ||||
| from authentik.flows.exceptions import FlowNonApplicableException | ||||
| from authentik.flows.models import Flow, FlowToken, Stage, in_memory_stage | ||||
| @ -100,6 +100,8 @@ class SourceFlowManager: | ||||
|         if self.request.user.is_authenticated: | ||||
|             new_connection.user = self.request.user | ||||
|             new_connection = self.update_connection(new_connection, **kwargs) | ||||
|  | ||||
|             new_connection.save() | ||||
|             return Action.LINK, new_connection | ||||
|  | ||||
|         existing_connections = self.connection_type.objects.filter( | ||||
| @ -146,6 +148,7 @@ class SourceFlowManager: | ||||
|         ]: | ||||
|             new_connection.user = user | ||||
|             new_connection = self.update_connection(new_connection, **kwargs) | ||||
|             new_connection.save() | ||||
|             return Action.LINK, new_connection | ||||
|         if self.source.user_matching_mode in [ | ||||
|             SourceUserMatchingModes.EMAIL_DENY, | ||||
| @ -206,9 +209,13 @@ class SourceFlowManager: | ||||
|  | ||||
|     def get_stages_to_append(self, flow: Flow) -> list[Stage]: | ||||
|         """Hook to override stages which are appended to the flow""" | ||||
|         if not self.source.enrollment_flow: | ||||
|             return [] | ||||
|         if flow.slug == self.source.enrollment_flow.slug: | ||||
|             return [ | ||||
|             in_memory_stage(PostSourceStage), | ||||
|                 in_memory_stage(PostUserEnrollmentStage), | ||||
|             ] | ||||
|         return [] | ||||
|  | ||||
|     def _prepare_flow( | ||||
|         self, | ||||
| @ -262,9 +269,6 @@ class SourceFlowManager: | ||||
|             ) | ||||
|         # We run the Flow planner here so we can pass the Pending user in the context | ||||
|         planner = FlowPlanner(flow) | ||||
|         # We append some stages so the initial flow we get might be empty | ||||
|         planner.allow_empty_flows = True | ||||
|         planner.use_cache = False | ||||
|         plan = planner.plan(self.request, kwargs) | ||||
|         for stage in self.get_stages_to_append(flow): | ||||
|             plan.append_stage(stage) | ||||
| @ -323,7 +327,7 @@ class SourceFlowManager: | ||||
|             reverse( | ||||
|                 "authentik_core:if-user", | ||||
|             ) | ||||
|             + "#/settings;page-sources" | ||||
|             + f"#/settings;page-{self.source.slug}" | ||||
|         ) | ||||
|  | ||||
|     def handle_enroll( | ||||
|  | ||||
| @ -10,7 +10,7 @@ from authentik.flows.stage import StageView | ||||
| PLAN_CONTEXT_SOURCES_CONNECTION = "goauthentik.io/sources/connection" | ||||
|  | ||||
|  | ||||
| class PostSourceStage(StageView): | ||||
| class PostUserEnrollmentStage(StageView): | ||||
|     """Dynamically injected stage which saves the Connection after | ||||
|     the user has been enrolled.""" | ||||
|  | ||||
| @ -21,9 +21,7 @@ class PostSourceStage(StageView): | ||||
|         ] | ||||
|         user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] | ||||
|         connection.user = user | ||||
|         linked = connection.pk is None | ||||
|         connection.save() | ||||
|         if linked: | ||||
|         Event.new( | ||||
|             EventAction.SOURCE_LINKED, | ||||
|             message="Linked Source", | ||||
|  | ||||
| @ -2,9 +2,7 @@ | ||||
|  | ||||
| from datetime import datetime, timedelta | ||||
|  | ||||
| from django.conf import ImproperlyConfigured | ||||
| from django.contrib.sessions.backends.cache import KEY_PREFIX | ||||
| from django.contrib.sessions.backends.db import SessionStore as DBSessionStore | ||||
| from django.core.cache import cache | ||||
| from django.utils.timezone import now | ||||
| from structlog.stdlib import get_logger | ||||
| @ -17,7 +15,6 @@ from authentik.core.models import ( | ||||
|     User, | ||||
| ) | ||||
| from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task | ||||
| from authentik.lib.config import CONFIG | ||||
| from authentik.root.celery import CELERY_APP | ||||
|  | ||||
| LOGGER = get_logger() | ||||
| @ -42,8 +39,6 @@ def clean_expired_models(self: SystemTask): | ||||
|     amount = 0 | ||||
|  | ||||
|     for session in AuthenticatedSession.objects.all(): | ||||
|         match CONFIG.get("session_storage", "cache"): | ||||
|             case "cache": | ||||
|         cache_key = f"{KEY_PREFIX}{session.session_key}" | ||||
|         value = None | ||||
|         try: | ||||
| @ -54,19 +49,6 @@ def clean_expired_models(self: SystemTask): | ||||
|         if not value: | ||||
|             session.delete() | ||||
|             amount += 1 | ||||
|             case "db": | ||||
|                 if not ( | ||||
|                     DBSessionStore.get_model_class() | ||||
|                     .objects.filter(session_key=session.session_key, expire_date__gt=now()) | ||||
|                     .exists() | ||||
|                 ): | ||||
|                     session.delete() | ||||
|                     amount += 1 | ||||
|             case _: | ||||
|                 # Should never happen, as we check for other values in authentik/root/settings.py | ||||
|                 raise ImproperlyConfigured( | ||||
|                     "Invalid session_storage setting, allowed values are db and cache" | ||||
|                 ) | ||||
|     LOGGER.debug("Expired sessions", model=AuthenticatedSession, amount=amount) | ||||
|  | ||||
|     messages.append(f"Expired {amount} {AuthenticatedSession._meta.verbose_name_plural}") | ||||
|  | ||||
| @ -1,11 +1,10 @@ | ||||
| """Test Groups API""" | ||||
|  | ||||
| from django.urls.base import reverse | ||||
| from guardian.shortcuts import assign_perm | ||||
| from rest_framework.test import APITestCase | ||||
|  | ||||
| from authentik.core.models import Group, User | ||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_user | ||||
| from authentik.core.tests.utils import create_test_admin_user | ||||
| from authentik.lib.generators import generate_id | ||||
|  | ||||
|  | ||||
| @ -13,22 +12,13 @@ class TestGroupsAPI(APITestCase): | ||||
|     """Test Groups API""" | ||||
|  | ||||
|     def setUp(self) -> None: | ||||
|         self.login_user = create_test_user() | ||||
|         self.admin = create_test_admin_user() | ||||
|         self.user = User.objects.create(username="test-user") | ||||
|  | ||||
|     def test_list_with_users(self): | ||||
|         """Test listing with users""" | ||||
|         admin = create_test_admin_user() | ||||
|         self.client.force_login(admin) | ||||
|         response = self.client.get(reverse("authentik_api:group-list"), {"include_users": "true"}) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|  | ||||
|     def test_add_user(self): | ||||
|         """Test add_user""" | ||||
|         group = Group.objects.create(name=generate_id()) | ||||
|         assign_perm("authentik_core.add_user_to_group", self.login_user, group) | ||||
|         assign_perm("authentik_core.view_user", self.login_user) | ||||
|         self.client.force_login(self.login_user) | ||||
|         self.client.force_login(self.admin) | ||||
|         res = self.client.post( | ||||
|             reverse("authentik_api:group-add-user", kwargs={"pk": group.pk}), | ||||
|             data={ | ||||
| @ -42,9 +32,7 @@ class TestGroupsAPI(APITestCase): | ||||
|     def test_add_user_404(self): | ||||
|         """Test add_user""" | ||||
|         group = Group.objects.create(name=generate_id()) | ||||
|         assign_perm("authentik_core.add_user_to_group", self.login_user, group) | ||||
|         assign_perm("authentik_core.view_user", self.login_user) | ||||
|         self.client.force_login(self.login_user) | ||||
|         self.client.force_login(self.admin) | ||||
|         res = self.client.post( | ||||
|             reverse("authentik_api:group-add-user", kwargs={"pk": group.pk}), | ||||
|             data={ | ||||
| @ -56,10 +44,8 @@ class TestGroupsAPI(APITestCase): | ||||
|     def test_remove_user(self): | ||||
|         """Test remove_user""" | ||||
|         group = Group.objects.create(name=generate_id()) | ||||
|         assign_perm("authentik_core.remove_user_from_group", self.login_user, group) | ||||
|         assign_perm("authentik_core.view_user", self.login_user) | ||||
|         group.users.add(self.user) | ||||
|         self.client.force_login(self.login_user) | ||||
|         self.client.force_login(self.admin) | ||||
|         res = self.client.post( | ||||
|             reverse("authentik_api:group-remove-user", kwargs={"pk": group.pk}), | ||||
|             data={ | ||||
| @ -73,10 +59,8 @@ class TestGroupsAPI(APITestCase): | ||||
|     def test_remove_user_404(self): | ||||
|         """Test remove_user""" | ||||
|         group = Group.objects.create(name=generate_id()) | ||||
|         assign_perm("authentik_core.remove_user_from_group", self.login_user, group) | ||||
|         assign_perm("authentik_core.view_user", self.login_user) | ||||
|         group.users.add(self.user) | ||||
|         self.client.force_login(self.login_user) | ||||
|         self.client.force_login(self.admin) | ||||
|         res = self.client.post( | ||||
|             reverse("authentik_api:group-remove-user", kwargs={"pk": group.pk}), | ||||
|             data={ | ||||
| @ -88,12 +72,11 @@ class TestGroupsAPI(APITestCase): | ||||
|     def test_parent_self(self): | ||||
|         """Test parent""" | ||||
|         group = Group.objects.create(name=generate_id()) | ||||
|         assign_perm("view_group", self.login_user, group) | ||||
|         assign_perm("change_group", self.login_user, group) | ||||
|         self.client.force_login(self.login_user) | ||||
|         self.client.force_login(self.admin) | ||||
|         res = self.client.patch( | ||||
|             reverse("authentik_api:group-detail", kwargs={"pk": group.pk}), | ||||
|             data={ | ||||
|                 "pk": self.user.pk + 3, | ||||
|                 "parent": group.pk, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
| @ -2,15 +2,11 @@ | ||||
|  | ||||
| from django.contrib.auth.models import AnonymousUser | ||||
| from django.test import TestCase | ||||
| from django.urls import reverse | ||||
| from guardian.utils import get_anonymous_user | ||||
|  | ||||
| from authentik.core.models import SourceUserMatchingModes, User | ||||
| from authentik.core.sources.flow_manager import Action | ||||
| from authentik.core.sources.stage import PostSourceStage | ||||
| from authentik.core.tests.utils import create_test_flow | ||||
| from authentik.flows.planner import FlowPlan | ||||
| from authentik.flows.views.executor import SESSION_KEY_PLAN | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.lib.tests.utils import get_request | ||||
| from authentik.policies.denied import AccessDeniedResponse | ||||
| @ -25,62 +21,42 @@ class TestSourceFlowManager(TestCase): | ||||
|  | ||||
|     def setUp(self) -> None: | ||||
|         super().setUp() | ||||
|         self.authentication_flow = create_test_flow() | ||||
|         self.enrollment_flow = create_test_flow() | ||||
|         self.source: OAuthSource = OAuthSource.objects.create( | ||||
|             name=generate_id(), | ||||
|             slug=generate_id(), | ||||
|             authentication_flow=self.authentication_flow, | ||||
|             enrollment_flow=self.enrollment_flow, | ||||
|         ) | ||||
|         self.source: OAuthSource = OAuthSource.objects.create(name="test") | ||||
|         self.identifier = generate_id() | ||||
|  | ||||
|     def test_unauthenticated_enroll(self): | ||||
|         """Test un-authenticated user enrolling""" | ||||
|         request = get_request("/", user=AnonymousUser()) | ||||
|         flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {}) | ||||
|         flow_manager = OAuthSourceFlowManager( | ||||
|             self.source, get_request("/", user=AnonymousUser()), self.identifier, {} | ||||
|         ) | ||||
|         action, _ = flow_manager.get_action() | ||||
|         self.assertEqual(action, Action.ENROLL) | ||||
|         response = flow_manager.get_flow() | ||||
|         self.assertEqual(response.status_code, 302) | ||||
|         flow_plan: FlowPlan = request.session[SESSION_KEY_PLAN] | ||||
|         self.assertEqual(flow_plan.bindings[0].stage.view, PostSourceStage) | ||||
|         flow_manager.get_flow() | ||||
|  | ||||
|     def test_unauthenticated_auth(self): | ||||
|         """Test un-authenticated user authenticating""" | ||||
|         UserOAuthSourceConnection.objects.create( | ||||
|             user=get_anonymous_user(), source=self.source, identifier=self.identifier | ||||
|         ) | ||||
|         request = get_request("/", user=AnonymousUser()) | ||||
|         flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {}) | ||||
|  | ||||
|         flow_manager = OAuthSourceFlowManager( | ||||
|             self.source, get_request("/", user=AnonymousUser()), self.identifier, {} | ||||
|         ) | ||||
|         action, _ = flow_manager.get_action() | ||||
|         self.assertEqual(action, Action.AUTH) | ||||
|         response = flow_manager.get_flow() | ||||
|         self.assertEqual(response.status_code, 302) | ||||
|         flow_plan: FlowPlan = request.session[SESSION_KEY_PLAN] | ||||
|         self.assertEqual(flow_plan.bindings[0].stage.view, PostSourceStage) | ||||
|         flow_manager.get_flow() | ||||
|  | ||||
|     def test_authenticated_link(self): | ||||
|         """Test authenticated user linking""" | ||||
|         user = User.objects.create(username="foo", email="foo@bar.baz") | ||||
|         request = get_request("/", user=user) | ||||
|         flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {}) | ||||
|         action, connection = flow_manager.get_action() | ||||
|         self.assertEqual(action, Action.LINK) | ||||
|         self.assertIsNone(connection.pk) | ||||
|         response = flow_manager.get_flow() | ||||
|         self.assertEqual(response.status_code, 302) | ||||
|         self.assertEqual( | ||||
|             response.url, | ||||
|             reverse("authentik_core:if-user") + "#/settings;page-sources", | ||||
|         UserOAuthSourceConnection.objects.create( | ||||
|             user=get_anonymous_user(), source=self.source, identifier=self.identifier | ||||
|         ) | ||||
|  | ||||
|     def test_unauthenticated_link(self): | ||||
|         """Test un-authenticated user linking""" | ||||
|         flow_manager = OAuthSourceFlowManager(self.source, get_request("/"), self.identifier, {}) | ||||
|         action, connection = flow_manager.get_action() | ||||
|         user = User.objects.create(username="foo", email="foo@bar.baz") | ||||
|         flow_manager = OAuthSourceFlowManager( | ||||
|             self.source, get_request("/", user=user), self.identifier, {} | ||||
|         ) | ||||
|         action, _ = flow_manager.get_action() | ||||
|         self.assertEqual(action, Action.LINK) | ||||
|         self.assertIsNone(connection.pk) | ||||
|         flow_manager.get_flow() | ||||
|  | ||||
|     def test_unauthenticated_enroll_email(self): | ||||
|  | ||||
| @ -1,6 +1,5 @@ | ||||
| """Test token API""" | ||||
|  | ||||
| from datetime import datetime, timedelta | ||||
| from json import loads | ||||
|  | ||||
| from django.urls.base import reverse | ||||
| @ -8,13 +7,8 @@ from guardian.shortcuts import get_anonymous_user | ||||
| from rest_framework.test import APITestCase | ||||
|  | ||||
| from authentik.core.api.tokens import TokenSerializer | ||||
| from authentik.core.models import ( | ||||
|     USER_ATTRIBUTE_TOKEN_EXPIRING, | ||||
|     USER_ATTRIBUTE_TOKEN_MAXIMUM_LIFETIME, | ||||
|     Token, | ||||
|     TokenIntents, | ||||
| ) | ||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_user | ||||
| from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents, User | ||||
| from authentik.core.tests.utils import create_test_admin_user | ||||
| from authentik.lib.generators import generate_id | ||||
|  | ||||
|  | ||||
| @ -23,7 +17,7 @@ class TestTokenAPI(APITestCase): | ||||
|  | ||||
|     def setUp(self) -> None: | ||||
|         super().setUp() | ||||
|         self.user = create_test_user() | ||||
|         self.user = User.objects.create(username="testuser") | ||||
|         self.admin = create_test_admin_user() | ||||
|         self.client.force_login(self.user) | ||||
|  | ||||
| @ -82,95 +76,6 @@ class TestTokenAPI(APITestCase): | ||||
|         self.assertEqual(token.intent, TokenIntents.INTENT_API) | ||||
|         self.assertEqual(token.expiring, False) | ||||
|  | ||||
|     def test_token_create_expiring(self): | ||||
|         """Test token creation endpoint""" | ||||
|         self.user.attributes[USER_ATTRIBUTE_TOKEN_EXPIRING] = True | ||||
|         self.user.save() | ||||
|         response = self.client.post( | ||||
|             reverse("authentik_api:token-list"), {"identifier": "test-token"} | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 201) | ||||
|         token = Token.objects.get(identifier="test-token") | ||||
|         self.assertEqual(token.user, self.user) | ||||
|         self.assertEqual(token.intent, TokenIntents.INTENT_API) | ||||
|         self.assertEqual(token.expiring, True) | ||||
|  | ||||
|     def test_token_create_expiring_custom_ok(self): | ||||
|         """Test token creation endpoint""" | ||||
|         self.user.attributes[USER_ATTRIBUTE_TOKEN_EXPIRING] = True | ||||
|         self.user.attributes[USER_ATTRIBUTE_TOKEN_MAXIMUM_LIFETIME] = "hours=2" | ||||
|         self.user.save() | ||||
|         expires = datetime.now() + timedelta(hours=1) | ||||
|         response = self.client.post( | ||||
|             reverse("authentik_api:token-list"), | ||||
|             { | ||||
|                 "identifier": "test-token", | ||||
|                 "expires": expires, | ||||
|                 "intent": TokenIntents.INTENT_APP_PASSWORD, | ||||
|             }, | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 201) | ||||
|         token = Token.objects.get(identifier="test-token") | ||||
|         self.assertEqual(token.user, self.user) | ||||
|         self.assertEqual(token.intent, TokenIntents.INTENT_APP_PASSWORD) | ||||
|         self.assertEqual(token.expiring, True) | ||||
|         self.assertEqual(token.expires.timestamp(), expires.timestamp()) | ||||
|  | ||||
|     def test_token_create_expiring_custom_nok(self): | ||||
|         """Test token creation endpoint""" | ||||
|         self.user.attributes[USER_ATTRIBUTE_TOKEN_EXPIRING] = True | ||||
|         self.user.attributes[USER_ATTRIBUTE_TOKEN_MAXIMUM_LIFETIME] = "hours=2" | ||||
|         self.user.save() | ||||
|         expires = datetime.now() + timedelta(hours=3) | ||||
|         response = self.client.post( | ||||
|             reverse("authentik_api:token-list"), | ||||
|             { | ||||
|                 "identifier": "test-token", | ||||
|                 "expires": expires, | ||||
|                 "intent": TokenIntents.INTENT_APP_PASSWORD, | ||||
|             }, | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 400) | ||||
|  | ||||
|     def test_token_create_expiring_custom_api(self): | ||||
|         """Test token creation endpoint""" | ||||
|         self.user.attributes[USER_ATTRIBUTE_TOKEN_EXPIRING] = True | ||||
|         self.user.attributes[USER_ATTRIBUTE_TOKEN_MAXIMUM_LIFETIME] = "hours=2" | ||||
|         self.user.save() | ||||
|         expires = datetime.now() + timedelta(seconds=3) | ||||
|         response = self.client.post( | ||||
|             reverse("authentik_api:token-list"), | ||||
|             { | ||||
|                 "identifier": "test-token", | ||||
|                 "expires": expires, | ||||
|                 "intent": TokenIntents.INTENT_API, | ||||
|             }, | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 201) | ||||
|         token = Token.objects.get(identifier="test-token") | ||||
|         self.assertEqual(token.user, self.user) | ||||
|         self.assertEqual(token.intent, TokenIntents.INTENT_API) | ||||
|         self.assertEqual(token.expiring, True) | ||||
|         self.assertNotEqual(token.expires.timestamp(), expires.timestamp()) | ||||
|  | ||||
|     def test_token_change_user(self): | ||||
|         """Test creating a token and then changing the user""" | ||||
|         ident = generate_id() | ||||
|         response = self.client.post(reverse("authentik_api:token-list"), {"identifier": ident}) | ||||
|         self.assertEqual(response.status_code, 201) | ||||
|         token = Token.objects.get(identifier=ident) | ||||
|         self.assertEqual(token.user, self.user) | ||||
|         self.assertEqual(token.intent, TokenIntents.INTENT_API) | ||||
|         self.assertEqual(token.expiring, True) | ||||
|         self.assertTrue(self.user.has_perm("authentik_core.view_token_key", token)) | ||||
|         response = self.client.put( | ||||
|             reverse("authentik_api:token-detail", kwargs={"identifier": ident}), | ||||
|             data={"identifier": "user_token_poc_v3", "intent": "api", "user": self.admin.pk}, | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 400) | ||||
|         token.refresh_from_db() | ||||
|         self.assertEqual(token.user, self.user) | ||||
|  | ||||
|     def test_list(self): | ||||
|         """Test Token List (Test normal authentication)""" | ||||
|         Token.objects.all().delete() | ||||
|  | ||||
| @ -41,12 +41,6 @@ class TestUsersAPI(APITestCase): | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|  | ||||
|     def test_list_with_groups(self): | ||||
|         """Test listing with groups""" | ||||
|         self.client.force_login(self.admin) | ||||
|         response = self.client.get(reverse("authentik_api:user-list"), {"include_groups": "true"}) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|  | ||||
|     def test_metrics(self): | ||||
|         """Test user's metrics""" | ||||
|         self.client.force_login(self.admin) | ||||
|  | ||||
| @ -8,6 +8,7 @@ from rest_framework.test import APITestCase | ||||
|  | ||||
| from authentik.core.models import User | ||||
| from authentik.core.tests.utils import create_test_admin_user | ||||
| from authentik.lib.config import CONFIG | ||||
| from authentik.tenants.utils import get_current_tenant | ||||
|  | ||||
|  | ||||
| @ -24,6 +25,7 @@ class TestUsersAvatars(APITestCase): | ||||
|         tenant.avatars = mode | ||||
|         tenant.save() | ||||
|  | ||||
|     @CONFIG.patch("avatars", "none") | ||||
|     def test_avatars_none(self): | ||||
|         """Test avatars none""" | ||||
|         self.set_avatar_mode("none") | ||||
|  | ||||
| @ -4,7 +4,7 @@ from django.utils.text import slugify | ||||
|  | ||||
| from authentik.brands.models import Brand | ||||
| from authentik.core.models import Group, User | ||||
| from authentik.crypto.builder import CertificateBuilder, PrivateKeyAlg | ||||
| from authentik.crypto.builder import CertificateBuilder | ||||
| from authentik.crypto.models import CertificateKeyPair | ||||
| from authentik.flows.models import Flow, FlowDesignation | ||||
| from authentik.lib.generators import generate_id | ||||
| @ -50,10 +50,12 @@ def create_test_brand(**kwargs) -> Brand: | ||||
|     return Brand.objects.create(domain=uid, default=True, **kwargs) | ||||
|  | ||||
|  | ||||
| def create_test_cert(alg=PrivateKeyAlg.RSA) -> CertificateKeyPair: | ||||
| def create_test_cert(use_ec_private_key=False) -> CertificateKeyPair: | ||||
|     """Generate a certificate for testing""" | ||||
|     builder = CertificateBuilder(f"{generate_id()}.self-signed.goauthentik.io") | ||||
|     builder.alg = alg | ||||
|     builder = CertificateBuilder( | ||||
|         name=f"{generate_id()}.self-signed.goauthentik.io", | ||||
|         use_ec_private_key=use_ec_private_key, | ||||
|     ) | ||||
|     builder.build( | ||||
|         subject_alt_names=[f"{generate_id()}.self-signed.goauthentik.io"], | ||||
|         validity_days=360, | ||||
|  | ||||
| @ -6,7 +6,6 @@ from django.conf import settings | ||||
| from django.contrib.auth.decorators import login_required | ||||
| from django.urls import path | ||||
| from django.views.decorators.csrf import ensure_csrf_cookie | ||||
| from django.views.generic import RedirectView | ||||
|  | ||||
| from authentik.core.api.applications import ApplicationViewSet | ||||
| from authentik.core.api.authenticated_sessions import AuthenticatedSessionViewSet | ||||
| @ -20,7 +19,12 @@ from authentik.core.api.transactional_applications import TransactionalApplicati | ||||
| from authentik.core.api.users import UserViewSet | ||||
| from authentik.core.views import apps | ||||
| from authentik.core.views.debug import AccessDeniedView | ||||
| from authentik.core.views.interface import FlowInterfaceView, InterfaceView | ||||
| from authentik.core.views.interface import ( | ||||
|     BrandDefaultRedirectView, | ||||
|     FlowInterfaceView, | ||||
|     InterfaceView, | ||||
|     RootRedirectView, | ||||
| ) | ||||
| from authentik.core.views.session import EndSessionView | ||||
| from authentik.root.asgi_middleware import SessionMiddleware | ||||
| from authentik.root.messages.consumer import MessageConsumer | ||||
| @ -29,13 +33,11 @@ from authentik.root.middleware import ChannelsLoggingMiddleware | ||||
| urlpatterns = [ | ||||
|     path( | ||||
|         "", | ||||
|         login_required( | ||||
|             RedirectView.as_view(pattern_name="authentik_core:if-user", query_string=True) | ||||
|         ), | ||||
|         login_required(RootRedirectView.as_view()), | ||||
|         name="root-redirect", | ||||
|     ), | ||||
|     path( | ||||
|         # We have to use this format since everything else uses applications/o or applications/saml | ||||
|         # We have to use this format since everything else uses application/o or application/saml | ||||
|         "application/launch/<slug:application_slug>/", | ||||
|         apps.RedirectToAppLaunch.as_view(), | ||||
|         name="application-launch", | ||||
| @ -43,12 +45,12 @@ urlpatterns = [ | ||||
|     # Interfaces | ||||
|     path( | ||||
|         "if/admin/", | ||||
|         ensure_csrf_cookie(InterfaceView.as_view(template_name="if/admin.html")), | ||||
|         ensure_csrf_cookie(BrandDefaultRedirectView.as_view(template_name="if/admin.html")), | ||||
|         name="if-admin", | ||||
|     ), | ||||
|     path( | ||||
|         "if/user/", | ||||
|         ensure_csrf_cookie(InterfaceView.as_view(template_name="if/user.html")), | ||||
|         ensure_csrf_cookie(BrandDefaultRedirectView.as_view(template_name="if/user.html")), | ||||
|         name="if-user", | ||||
|     ), | ||||
|     path( | ||||
|  | ||||
| @ -3,15 +3,43 @@ | ||||
| from json import dumps | ||||
| from typing import Any | ||||
|  | ||||
| from django.shortcuts import get_object_or_404 | ||||
| from django.views.generic.base import TemplateView | ||||
| from django.http import HttpRequest | ||||
| from django.http.response import HttpResponse | ||||
| from django.shortcuts import get_object_or_404, redirect | ||||
| from django.utils.translation import gettext as _ | ||||
| from django.views.generic.base import RedirectView, TemplateView | ||||
| from rest_framework.request import Request | ||||
|  | ||||
| from authentik import get_build_hash | ||||
| from authentik.admin.tasks import LOCAL_VERSION | ||||
| from authentik.api.v3.config import ConfigView | ||||
| from authentik.brands.api import CurrentBrandSerializer | ||||
| from authentik.brands.models import Brand | ||||
| from authentik.core.models import UserTypes | ||||
| from authentik.flows.models import Flow | ||||
| from authentik.policies.denied import AccessDeniedResponse | ||||
|  | ||||
|  | ||||
| class RootRedirectView(RedirectView): | ||||
|     """Root redirect view, redirect to brand's default application if set""" | ||||
|  | ||||
|     pattern_name = "authentik_core:if-user" | ||||
|     query_string = True | ||||
|  | ||||
|     def redirect_to_app(self, request: HttpRequest): | ||||
|         if request.user.is_authenticated and request.user.type == UserTypes.EXTERNAL: | ||||
|             brand: Brand = request.brand | ||||
|             if brand.default_application: | ||||
|                 return redirect( | ||||
|                     "authentik_core:application-launch", | ||||
|                     application_slug=brand.default_application.slug, | ||||
|                 ) | ||||
|         return None | ||||
|  | ||||
|     def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: | ||||
|         if redirect_response := RootRedirectView().redirect_to_app(request): | ||||
|             return redirect_response | ||||
|         return super().dispatch(request, *args, **kwargs) | ||||
|  | ||||
|  | ||||
| class InterfaceView(TemplateView): | ||||
| @ -27,6 +55,22 @@ class InterfaceView(TemplateView): | ||||
|         return super().get_context_data(**kwargs) | ||||
|  | ||||
|  | ||||
| class BrandDefaultRedirectView(InterfaceView): | ||||
|     """By default redirect to default app""" | ||||
|  | ||||
|     def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: | ||||
|         if request.user.is_authenticated and request.user.type == UserTypes.EXTERNAL: | ||||
|             brand: Brand = request.brand | ||||
|             if brand.default_application: | ||||
|                 return redirect( | ||||
|                     "authentik_core:application-launch", | ||||
|                     application_slug=brand.default_application.slug, | ||||
|                 ) | ||||
|             response = AccessDeniedResponse(self.request) | ||||
|             response.error_message = _("Interface can only be accessed by internal users.") | ||||
|         return super().dispatch(request, *args, **kwargs) | ||||
|  | ||||
|  | ||||
| class FlowInterfaceView(InterfaceView): | ||||
|     """Flow interface""" | ||||
|  | ||||
|  | ||||
| @ -14,13 +14,7 @@ from drf_spectacular.types import OpenApiTypes | ||||
| from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema | ||||
| from rest_framework.decorators import action | ||||
| from rest_framework.exceptions import ValidationError | ||||
| from rest_framework.fields import ( | ||||
|     CharField, | ||||
|     ChoiceField, | ||||
|     DateTimeField, | ||||
|     IntegerField, | ||||
|     SerializerMethodField, | ||||
| ) | ||||
| from rest_framework.fields import CharField, DateTimeField, IntegerField, SerializerMethodField | ||||
| from rest_framework.filters import OrderingFilter, SearchFilter | ||||
| from rest_framework.request import Request | ||||
| from rest_framework.response import Response | ||||
| @ -32,7 +26,7 @@ from authentik.api.authorization import SecretKeyFilter | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.core.api.utils import PassiveSerializer | ||||
| from authentik.crypto.apps import MANAGED_KEY | ||||
| from authentik.crypto.builder import CertificateBuilder, PrivateKeyAlg | ||||
| from authentik.crypto.builder import CertificateBuilder | ||||
| from authentik.crypto.models import CertificateKeyPair | ||||
| from authentik.events.models import Event, EventAction | ||||
| from authentik.rbac.decorators import permission_required | ||||
| @ -184,7 +178,6 @@ class CertificateGenerationSerializer(PassiveSerializer): | ||||
|     common_name = CharField() | ||||
|     subject_alt_name = CharField(required=False, allow_blank=True, label=_("Subject-alt name")) | ||||
|     validity_days = IntegerField(initial=365) | ||||
|     alg = ChoiceField(default=PrivateKeyAlg.RSA, choices=PrivateKeyAlg.choices) | ||||
|  | ||||
|  | ||||
| class CertificateKeyPairFilter(FilterSet): | ||||
| @ -247,7 +240,6 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet): | ||||
|         raw_san = data.validated_data.get("subject_alt_name", "") | ||||
|         sans = raw_san.split(",") if raw_san != "" else [] | ||||
|         builder = CertificateBuilder(data.validated_data["common_name"]) | ||||
|         builder.alg = data.validated_data["alg"] | ||||
|         builder.build( | ||||
|             subject_alt_names=sans, | ||||
|             validity_days=int(data.validated_data["validity_days"]), | ||||
|  | ||||
| @ -9,28 +9,20 @@ from cryptography.hazmat.primitives import hashes, serialization | ||||
| from cryptography.hazmat.primitives.asymmetric import ec, rsa | ||||
| from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes | ||||
| from cryptography.x509.oid import NameOID | ||||
| from django.db import models | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
|  | ||||
| from authentik import __version__ | ||||
| from authentik.crypto.models import CertificateKeyPair | ||||
|  | ||||
|  | ||||
| class PrivateKeyAlg(models.TextChoices): | ||||
|     """Algorithm to create private key with""" | ||||
|  | ||||
|     RSA = "rsa", _("rsa") | ||||
|     ECDSA = "ecdsa", _("ecdsa") | ||||
|  | ||||
|  | ||||
| class CertificateBuilder: | ||||
|     """Build self-signed certificates""" | ||||
|  | ||||
|     common_name: str | ||||
|     alg: PrivateKeyAlg | ||||
|  | ||||
|     def __init__(self, name: str): | ||||
|         self.alg = PrivateKeyAlg.RSA | ||||
|     _use_ec_private_key: bool | ||||
|  | ||||
|     def __init__(self, name: str, use_ec_private_key=False): | ||||
|         self._use_ec_private_key = use_ec_private_key | ||||
|         self.__public_key = None | ||||
|         self.__private_key = None | ||||
|         self.__builder = None | ||||
| @ -50,13 +42,11 @@ class CertificateBuilder: | ||||
|  | ||||
|     def generate_private_key(self) -> PrivateKeyTypes: | ||||
|         """Generate private key""" | ||||
|         if self.alg == PrivateKeyAlg.ECDSA: | ||||
|         if self._use_ec_private_key: | ||||
|             return ec.generate_private_key(curve=ec.SECP256R1()) | ||||
|         if self.alg == PrivateKeyAlg.RSA: | ||||
|         return rsa.generate_private_key( | ||||
|             public_exponent=65537, key_size=4096, backend=default_backend() | ||||
|         ) | ||||
|         raise ValueError(f"Invalid alg: {self.alg}") | ||||
|  | ||||
|     def build( | ||||
|         self, | ||||
|  | ||||
| @ -13,9 +13,9 @@ class AuthentikEnterpriseAuditConfig(EnterpriseConfig): | ||||
|     verbose_name = "authentik Enterprise.Audit" | ||||
|     default = True | ||||
|  | ||||
|     def ready(self): | ||||
|     @EnterpriseConfig.reconcile_global | ||||
|     def install_middleware(self): | ||||
|         """Install enterprise audit middleware""" | ||||
|         orig_import = "authentik.events.middleware.AuditMiddleware" | ||||
|         new_import = "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware" | ||||
|         settings.MIDDLEWARE = [new_import if x == orig_import else x for x in settings.MIDDLEWARE] | ||||
|         return super().ready() | ||||
|  | ||||
| @ -2,16 +2,16 @@ | ||||
|  | ||||
| from copy import deepcopy | ||||
| from functools import partial | ||||
| from typing import Any | ||||
|  | ||||
| from django.apps.registry import apps | ||||
| from django.core.files import File | ||||
| from django.db import connection | ||||
| from django.db.models import ManyToManyRel, Model | ||||
| from django.db.models import Model | ||||
| from django.db.models.expressions import BaseExpression, Combinable | ||||
| from django.db.models.signals import post_init | ||||
| from django.http import HttpRequest | ||||
|  | ||||
| from authentik.core.models import User | ||||
| from authentik.events.middleware import AuditMiddleware, should_log_model | ||||
| from authentik.events.utils import cleanse_dict, sanitize_item | ||||
|  | ||||
| @ -28,10 +28,13 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | ||||
|         super().connect(request) | ||||
|         if not self.enabled: | ||||
|             return | ||||
|         user = getattr(request, "user", self.anonymous_user) | ||||
|         if not user.is_authenticated: | ||||
|             user = self.anonymous_user | ||||
|         if not hasattr(request, "request_id"): | ||||
|             return | ||||
|         post_init.connect( | ||||
|             partial(self.post_init_handler, request=request), | ||||
|             partial(self.post_init_handler, user=user, request=request), | ||||
|             dispatch_uid=request.request_id, | ||||
|             weak=False, | ||||
|         ) | ||||
| @ -45,7 +48,7 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | ||||
|         post_init.disconnect(dispatch_uid=request.request_id) | ||||
|  | ||||
|     def serialize_simple(self, model: Model) -> dict: | ||||
|         """Serialize a model in a very simple way. No ForeignKeys or other relationships are | ||||
|         """Serialize a model in a very simple way. No ForeginKeys or other relationships are | ||||
|         resolved""" | ||||
|         data = {} | ||||
|         deferred_fields = model.get_deferred_fields() | ||||
| @ -71,12 +74,9 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | ||||
|         for key, value in before.items(): | ||||
|             if after.get(key) != value: | ||||
|                 diff[key] = {"previous_value": value, "new_value": after.get(key)} | ||||
|         for key, value in after.items(): | ||||
|             if key not in before and key not in diff and before.get(key) != value: | ||||
|                 diff[key] = {"previous_value": before.get(key), "new_value": value} | ||||
|         return sanitize_item(diff) | ||||
|  | ||||
|     def post_init_handler(self, request: HttpRequest, sender, instance: Model, **_): | ||||
|     def post_init_handler(self, user: User, request: HttpRequest, sender, instance: Model, **_): | ||||
|         """post_init django model handler""" | ||||
|         if not should_log_model(instance): | ||||
|             return | ||||
| @ -90,6 +90,7 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | ||||
|  | ||||
|     def post_save_handler( | ||||
|         self, | ||||
|         user: User, | ||||
|         request: HttpRequest, | ||||
|         sender, | ||||
|         instance: Model, | ||||
| @ -102,37 +103,15 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | ||||
|         thread_kwargs = {} | ||||
|         if hasattr(instance, "_previous_state") or created: | ||||
|             prev_state = getattr(instance, "_previous_state", {}) | ||||
|             if created: | ||||
|                 prev_state = {} | ||||
|             # Get current state | ||||
|             new_state = self.serialize_simple(instance) | ||||
|             diff = self.diff(prev_state, new_state) | ||||
|             thread_kwargs["diff"] = diff | ||||
|         return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_) | ||||
|  | ||||
|     def m2m_changed_handler(  # noqa: PLR0913 | ||||
|         self, | ||||
|         request: HttpRequest, | ||||
|         sender, | ||||
|         instance: Model, | ||||
|         action: str, | ||||
|         pk_set: set[Any], | ||||
|         thread_kwargs: dict | None = None, | ||||
|         **_, | ||||
|     ): | ||||
|         thread_kwargs = {} | ||||
|         m2m_field = None | ||||
|         # For the audit log we don't care about `pre_` or `post_` so we trim that part off | ||||
|         _, _, action_direction = action.partition("_") | ||||
|         # resolve the "through" model to an actual field | ||||
|         for field in instance._meta.get_fields(): | ||||
|             if not isinstance(field, ManyToManyRel): | ||||
|                 continue | ||||
|             if field.through == sender: | ||||
|                 m2m_field = field | ||||
|         if m2m_field: | ||||
|             # If we're clearing we just set the "flag" to True | ||||
|             if action_direction == "clear": | ||||
|                 pk_set = True | ||||
|             thread_kwargs["diff"] = {m2m_field.related_name: {action_direction: pk_set}} | ||||
|         return super().m2m_changed_handler(request, sender, instance, action, thread_kwargs) | ||||
|             if not created: | ||||
|                 ignored_field_sets = getattr(instance._meta, "authentik_signals_ignored_fields", []) | ||||
|                 for field_set in ignored_field_sets: | ||||
|                     if set(diff.keys()) == set(field_set): | ||||
|                         return None | ||||
|         return super().post_save_handler( | ||||
|             user, request, sender, instance, created, thread_kwargs, **_ | ||||
|         ) | ||||
|  | ||||
| @ -1,210 +0,0 @@ | ||||
| from unittest.mock import PropertyMock, patch | ||||
|  | ||||
| from django.apps import apps | ||||
| from django.conf import settings | ||||
| from django.urls import reverse | ||||
| from rest_framework.test import APITestCase | ||||
|  | ||||
| from authentik.core.models import Group, User | ||||
| from authentik.core.tests.utils import create_test_admin_user | ||||
| from authentik.events.models import Event, EventAction | ||||
| from authentik.events.utils import sanitize_item | ||||
| from authentik.lib.generators import generate_id | ||||
|  | ||||
|  | ||||
| class TestEnterpriseAudit(APITestCase): | ||||
|     """Test audit middleware""" | ||||
|  | ||||
|     def setUp(self) -> None: | ||||
|         self.user = create_test_admin_user() | ||||
|  | ||||
|     def test_import(self): | ||||
|         """Ensure middleware is imported when app.ready is called""" | ||||
|         # Revert import swap | ||||
|         orig_import = "authentik.events.middleware.AuditMiddleware" | ||||
|         new_import = "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware" | ||||
|         settings.MIDDLEWARE = [orig_import if x == new_import else x for x in settings.MIDDLEWARE] | ||||
|         # Re-call ready() | ||||
|         apps.get_app_config("authentik_enterprise_audit").ready() | ||||
|         self.assertIn( | ||||
|             "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware", settings.MIDDLEWARE | ||||
|         ) | ||||
|  | ||||
|     @patch( | ||||
|         "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled", | ||||
|         PropertyMock(return_value=True), | ||||
|     ) | ||||
|     def test_create(self): | ||||
|         """Test create audit log""" | ||||
|         self.client.force_login(self.user) | ||||
|         username = generate_id() | ||||
|         response = self.client.post( | ||||
|             reverse("authentik_api:user-list"), | ||||
|             data={"name": generate_id(), "username": username, "groups": [], "path": "foo"}, | ||||
|         ) | ||||
|         user = User.objects.get(username=username) | ||||
|         self.assertEqual(response.status_code, 201) | ||||
|         events = Event.objects.filter( | ||||
|             action=EventAction.MODEL_CREATED, | ||||
|             context__model__model_name="user", | ||||
|             context__model__app="authentik_core", | ||||
|             context__model__pk=user.pk, | ||||
|         ) | ||||
|         event = events.first() | ||||
|         self.assertIsNotNone(event) | ||||
|         self.assertIsNotNone(event.context["diff"]) | ||||
|         diff = event.context["diff"] | ||||
|         self.assertEqual( | ||||
|             diff, | ||||
|             { | ||||
|                 "name": { | ||||
|                     "new_value": user.name, | ||||
|                     "previous_value": None, | ||||
|                 }, | ||||
|                 "path": {"new_value": "foo", "previous_value": None}, | ||||
|                 "type": {"new_value": "internal", "previous_value": None}, | ||||
|                 "uuid": { | ||||
|                     "new_value": user.uuid.hex, | ||||
|                     "previous_value": None, | ||||
|                 }, | ||||
|                 "email": {"new_value": "", "previous_value": None}, | ||||
|                 "username": { | ||||
|                     "new_value": user.username, | ||||
|                     "previous_value": None, | ||||
|                 }, | ||||
|                 "is_active": {"new_value": True, "previous_value": None}, | ||||
|                 "attributes": {"new_value": {}, "previous_value": None}, | ||||
|                 "date_joined": { | ||||
|                     "new_value": sanitize_item(user.date_joined), | ||||
|                     "previous_value": None, | ||||
|                 }, | ||||
|                 "first_name": {"new_value": "", "previous_value": None}, | ||||
|                 "id": {"new_value": user.pk, "previous_value": None}, | ||||
|                 "last_name": {"new_value": "", "previous_value": None}, | ||||
|                 "password": {"new_value": "********************", "previous_value": None}, | ||||
|                 "password_change_date": { | ||||
|                     "new_value": sanitize_item(user.password_change_date), | ||||
|                     "previous_value": None, | ||||
|                 }, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     @patch( | ||||
|         "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled", | ||||
|         PropertyMock(return_value=True), | ||||
|     ) | ||||
|     def test_update(self): | ||||
|         """Test update audit log""" | ||||
|         self.client.force_login(self.user) | ||||
|         user = create_test_admin_user() | ||||
|         current_name = user.name | ||||
|         new_name = generate_id() | ||||
|         response = self.client.patch( | ||||
|             reverse("authentik_api:user-detail", kwargs={"pk": user.id}), | ||||
|             data={"name": new_name}, | ||||
|         ) | ||||
|         user.refresh_from_db() | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         events = Event.objects.filter( | ||||
|             action=EventAction.MODEL_UPDATED, | ||||
|             context__model__model_name="user", | ||||
|             context__model__app="authentik_core", | ||||
|             context__model__pk=user.pk, | ||||
|         ) | ||||
|         event = events.first() | ||||
|         self.assertIsNotNone(event) | ||||
|         self.assertIsNotNone(event.context["diff"]) | ||||
|         diff = event.context["diff"] | ||||
|         self.assertEqual( | ||||
|             diff, | ||||
|             { | ||||
|                 "name": { | ||||
|                     "new_value": new_name, | ||||
|                     "previous_value": current_name, | ||||
|                 }, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     @patch( | ||||
|         "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled", | ||||
|         PropertyMock(return_value=True), | ||||
|     ) | ||||
|     def test_delete(self): | ||||
|         """Test delete audit log""" | ||||
|         self.client.force_login(self.user) | ||||
|         user = create_test_admin_user() | ||||
|         response = self.client.delete( | ||||
|             reverse("authentik_api:user-detail", kwargs={"pk": user.id}), | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 204) | ||||
|         events = Event.objects.filter( | ||||
|             action=EventAction.MODEL_DELETED, | ||||
|             context__model__model_name="user", | ||||
|             context__model__app="authentik_core", | ||||
|             context__model__pk=user.pk, | ||||
|         ) | ||||
|         event = events.first() | ||||
|         self.assertIsNotNone(event) | ||||
|         self.assertNotIn("diff", event.context) | ||||
|  | ||||
|     @patch( | ||||
|         "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled", | ||||
|         PropertyMock(return_value=True), | ||||
|     ) | ||||
|     def test_m2m_add(self): | ||||
|         """Test m2m add audit log""" | ||||
|         self.client.force_login(self.user) | ||||
|         user = create_test_admin_user() | ||||
|         group = Group.objects.create(name=generate_id()) | ||||
|         response = self.client.post( | ||||
|             reverse("authentik_api:group-add-user", kwargs={"pk": group.group_uuid}), | ||||
|             data={ | ||||
|                 "pk": user.pk, | ||||
|             }, | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 204) | ||||
|         events = Event.objects.filter( | ||||
|             action=EventAction.MODEL_UPDATED, | ||||
|             context__model__model_name="group", | ||||
|             context__model__app="authentik_core", | ||||
|             context__model__pk=group.pk.hex, | ||||
|         ) | ||||
|         event = events.first() | ||||
|         self.assertIsNotNone(event) | ||||
|         self.assertIsNotNone(event.context["diff"]) | ||||
|         diff = event.context["diff"] | ||||
|         self.assertEqual( | ||||
|             diff, | ||||
|             {"users": {"add": [user.pk]}}, | ||||
|         ) | ||||
|  | ||||
|     @patch( | ||||
|         "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled", | ||||
|         PropertyMock(return_value=True), | ||||
|     ) | ||||
|     def test_m2m_remove(self): | ||||
|         """Test m2m remove audit log""" | ||||
|         self.client.force_login(self.user) | ||||
|         user = create_test_admin_user() | ||||
|         group = Group.objects.create(name=generate_id()) | ||||
|         response = self.client.post( | ||||
|             reverse("authentik_api:group-remove-user", kwargs={"pk": group.group_uuid}), | ||||
|             data={ | ||||
|                 "pk": user.pk, | ||||
|             }, | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 204) | ||||
|         events = Event.objects.filter( | ||||
|             action=EventAction.MODEL_UPDATED, | ||||
|             context__model__model_name="group", | ||||
|             context__model__app="authentik_core", | ||||
|             context__model__pk=group.pk.hex, | ||||
|         ) | ||||
|         event = events.first() | ||||
|         self.assertIsNotNone(event) | ||||
|         self.assertIsNotNone(event.context["diff"]) | ||||
|         diff = event.context["diff"] | ||||
|         self.assertEqual( | ||||
|             diff, | ||||
|             {"users": {"remove": [user.pk]}}, | ||||
|         ) | ||||
| @ -1,18 +0,0 @@ | ||||
| # Generated by Django 5.0.2 on 2024-02-29 10:15 | ||||
|  | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_providers_rac", "0001_squashed_0003_alter_connectiontoken_options_and_more"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.AlterField( | ||||
|             model_name="connectiontoken", | ||||
|             name="expires", | ||||
|             field=models.DateTimeField(default=None, null=True), | ||||
|         ), | ||||
|     ] | ||||
| @ -201,7 +201,10 @@ class ConnectionToken(ExpiringModel): | ||||
|         return settings | ||||
|  | ||||
|     def __str__(self): | ||||
|         return f"RAC Connection token {self.session_id} to {self.provider_id}/{self.endpoint_id}" | ||||
|         return ( | ||||
|             f"RAC Connection token {self.session.user} to " | ||||
|             f"{self.endpoint.provider.name}/{self.endpoint.name}" | ||||
|         ) | ||||
|  | ||||
|     class Meta: | ||||
|         verbose_name = _("RAC Connection token") | ||||
|  | ||||
| @ -12,6 +12,7 @@ from rest_framework.fields import ( | ||||
|     ChoiceField, | ||||
|     DateTimeField, | ||||
|     FloatField, | ||||
|     ListField, | ||||
|     SerializerMethodField, | ||||
| ) | ||||
| from rest_framework.request import Request | ||||
| @ -20,7 +21,6 @@ from rest_framework.serializers import ModelSerializer | ||||
| from rest_framework.viewsets import ReadOnlyModelViewSet | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.events.logs import LogEventSerializer | ||||
| from authentik.events.models import SystemTask, TaskStatus | ||||
| from authentik.rbac.decorators import permission_required | ||||
|  | ||||
| @ -39,7 +39,7 @@ class SystemTaskSerializer(ModelSerializer): | ||||
|     duration = FloatField(read_only=True) | ||||
|  | ||||
|     status = ChoiceField(choices=[(x.value, x.name) for x in TaskStatus]) | ||||
|     messages = LogEventSerializer(many=True) | ||||
|     messages = ListField(child=CharField()) | ||||
|  | ||||
|     def get_full_name(self, instance: SystemTask) -> str: | ||||
|         """Get full name with UID""" | ||||
|  | ||||
| @ -1,82 +0,0 @@ | ||||
| from collections.abc import Generator | ||||
| from contextlib import contextmanager | ||||
| from dataclasses import dataclass, field | ||||
| from datetime import datetime | ||||
| from typing import Any | ||||
|  | ||||
| from django.utils.timezone import now | ||||
| from rest_framework.fields import CharField, ChoiceField, DateTimeField, DictField | ||||
| from structlog import configure, get_config | ||||
| from structlog.stdlib import NAME_TO_LEVEL, ProcessorFormatter | ||||
| from structlog.testing import LogCapture | ||||
| from structlog.types import EventDict | ||||
|  | ||||
| from authentik.core.api.utils import PassiveSerializer | ||||
| from authentik.events.utils import sanitize_dict | ||||
|  | ||||
|  | ||||
| @dataclass() | ||||
| class LogEvent: | ||||
|  | ||||
|     event: str | ||||
|     log_level: str | ||||
|     logger: str | ||||
|     timestamp: datetime = field(default_factory=now) | ||||
|     attributes: dict[str, Any] = field(default_factory=dict) | ||||
|  | ||||
|     @staticmethod | ||||
|     def from_event_dict(item: EventDict) -> "LogEvent": | ||||
|         event = item.pop("event") | ||||
|         log_level = item.pop("level").lower() | ||||
|         timestamp = datetime.fromisoformat(item.pop("timestamp")) | ||||
|         item.pop("pid", None) | ||||
|         # Sometimes log entries have both `level` and `log_level` set, but `level` is always set | ||||
|         item.pop("log_level", None) | ||||
|         return LogEvent( | ||||
|             event, log_level, item.pop("logger"), timestamp, attributes=sanitize_dict(item) | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class LogEventSerializer(PassiveSerializer): | ||||
|     """Single log message with all context logged.""" | ||||
|  | ||||
|     timestamp = DateTimeField() | ||||
|     log_level = ChoiceField(choices=tuple((x, x) for x in NAME_TO_LEVEL.keys())) | ||||
|     logger = CharField() | ||||
|     event = CharField() | ||||
|     attributes = DictField() | ||||
|  | ||||
|     # TODO(2024.6?): This is a migration helper to return a correct API response for logs that | ||||
|     # have been saved in an older format (mostly just list[str] with just the messages) | ||||
|     def to_representation(self, instance): | ||||
|         if isinstance(instance, str): | ||||
|             instance = LogEvent(instance, "", "") | ||||
|         elif isinstance(instance, list): | ||||
|             instance = [LogEvent(x, "", "") for x in instance] | ||||
|         return super().to_representation(instance) | ||||
|  | ||||
|  | ||||
| @contextmanager | ||||
| def capture_logs(log_default_output=True) -> Generator[list[LogEvent], None, None]: | ||||
|     """Capture log entries created""" | ||||
|     logs = [] | ||||
|     cap = LogCapture() | ||||
|     # Modify `_Configuration.default_processors` set via `configure` but always | ||||
|     # keep the list instance intact to not break references held by bound | ||||
|     # loggers. | ||||
|     processors: list = get_config()["processors"] | ||||
|     old_processors = processors.copy() | ||||
|     try: | ||||
|         # clear processors list and use LogCapture for testing | ||||
|         if ProcessorFormatter.wrap_for_formatter in processors: | ||||
|             processors.remove(ProcessorFormatter.wrap_for_formatter) | ||||
|         processors.append(cap) | ||||
|         configure(processors=processors) | ||||
|         yield logs | ||||
|         for raw_log in cap.entries: | ||||
|             logs.append(LogEvent.from_event_dict(raw_log)) | ||||
|     finally: | ||||
|         # remove LogCapture and restore original processors | ||||
|         processors.clear() | ||||
|         processors.extend(old_processors) | ||||
|         configure(processors=processors) | ||||
| @ -1,8 +1,6 @@ | ||||
| """Events middleware""" | ||||
|  | ||||
| from collections.abc import Callable | ||||
| from contextlib import contextmanager | ||||
| from contextvars import ContextVar | ||||
| from functools import partial | ||||
| from threading import Thread | ||||
| from typing import Any | ||||
| @ -33,9 +31,6 @@ IGNORED_MODELS = tuple( | ||||
|     ) | ||||
| ) | ||||
|  | ||||
| _CTX_OVERWRITE_USER = ContextVar[User | None]("authentik_events_log_overwrite_user", default=None) | ||||
| _CTX_IGNORE = ContextVar[bool]("authentik_events_log_ignore", default=False) | ||||
|  | ||||
|  | ||||
| def should_log_model(model: Model) -> bool: | ||||
|     """Return true if operation on `model` should be logged""" | ||||
| @ -49,28 +44,6 @@ def should_log_m2m(model: Model) -> bool: | ||||
|     return False | ||||
|  | ||||
|  | ||||
| @contextmanager | ||||
| def audit_overwrite_user(user: User): | ||||
|     """Overwrite user being logged for model AuditMiddleware. Commonly used | ||||
|     for example in flows where a pending user is given, but the request is not authenticated yet""" | ||||
|     _CTX_OVERWRITE_USER.set(user) | ||||
|     try: | ||||
|         yield | ||||
|     finally: | ||||
|         _CTX_OVERWRITE_USER.set(None) | ||||
|  | ||||
|  | ||||
| @contextmanager | ||||
| def audit_ignore(): | ||||
|     """Ignore model operations in the block. Useful for objects which need to be modified | ||||
|     but are not excluded (e.g. WebAuthn devices)""" | ||||
|     _CTX_IGNORE.set(True) | ||||
|     try: | ||||
|         yield | ||||
|     finally: | ||||
|         _CTX_IGNORE.set(False) | ||||
|  | ||||
|  | ||||
| class EventNewThread(Thread): | ||||
|     """Create Event in background thread""" | ||||
|  | ||||
| @ -110,32 +83,26 @@ class AuditMiddleware: | ||||
|  | ||||
|         self.anonymous_user = get_anonymous_user() | ||||
|  | ||||
|     def get_user(self, request: HttpRequest) -> User: | ||||
|         user = _CTX_OVERWRITE_USER.get() | ||||
|         if user: | ||||
|             return user | ||||
|         user = getattr(request, "user", self.anonymous_user) | ||||
|         if not user.is_authenticated: | ||||
|             self._ensure_fallback_user() | ||||
|             return self.anonymous_user | ||||
|         return user | ||||
|  | ||||
|     def connect(self, request: HttpRequest): | ||||
|         """Connect signal for automatic logging""" | ||||
|         self._ensure_fallback_user() | ||||
|         user = getattr(request, "user", self.anonymous_user) | ||||
|         if not user.is_authenticated: | ||||
|             user = self.anonymous_user | ||||
|         if not hasattr(request, "request_id"): | ||||
|             return | ||||
|         post_save.connect( | ||||
|             partial(self.post_save_handler, request=request), | ||||
|             partial(self.post_save_handler, user=user, request=request), | ||||
|             dispatch_uid=request.request_id, | ||||
|             weak=False, | ||||
|         ) | ||||
|         pre_delete.connect( | ||||
|             partial(self.pre_delete_handler, request=request), | ||||
|             partial(self.pre_delete_handler, user=user, request=request), | ||||
|             dispatch_uid=request.request_id, | ||||
|             weak=False, | ||||
|         ) | ||||
|         m2m_changed.connect( | ||||
|             partial(self.m2m_changed_handler, request=request), | ||||
|             partial(self.m2m_changed_handler, user=user, request=request), | ||||
|             dispatch_uid=request.request_id, | ||||
|             weak=False, | ||||
|         ) | ||||
| @ -180,6 +147,7 @@ class AuditMiddleware: | ||||
|  | ||||
|     def post_save_handler( | ||||
|         self, | ||||
|         user: User, | ||||
|         request: HttpRequest, | ||||
|         sender, | ||||
|         instance: Model, | ||||
| @ -190,22 +158,16 @@ class AuditMiddleware: | ||||
|         """Signal handler for all object's post_save""" | ||||
|         if not should_log_model(instance): | ||||
|             return | ||||
|         if _CTX_IGNORE.get(): | ||||
|             return | ||||
|         user = self.get_user(request) | ||||
|  | ||||
|         action = EventAction.MODEL_CREATED if created else EventAction.MODEL_UPDATED | ||||
|         thread = EventNewThread(action, request, user=user, model=model_to_dict(instance)) | ||||
|         thread.kwargs.update(thread_kwargs or {}) | ||||
|         thread.run() | ||||
|  | ||||
|     def pre_delete_handler(self, request: HttpRequest, sender, instance: Model, **_): | ||||
|     def pre_delete_handler(self, user: User, request: HttpRequest, sender, instance: Model, **_): | ||||
|         """Signal handler for all object's pre_delete""" | ||||
|         if not should_log_model(instance):  # pragma: no cover | ||||
|             return | ||||
|         if _CTX_IGNORE.get(): | ||||
|             return | ||||
|         user = self.get_user(request) | ||||
|  | ||||
|         EventNewThread( | ||||
|             EventAction.MODEL_DELETED, | ||||
| @ -215,27 +177,17 @@ class AuditMiddleware: | ||||
|         ).run() | ||||
|  | ||||
|     def m2m_changed_handler( | ||||
|         self, | ||||
|         request: HttpRequest, | ||||
|         sender, | ||||
|         instance: Model, | ||||
|         action: str, | ||||
|         thread_kwargs: dict | None = None, | ||||
|         **_, | ||||
|         self, user: User, request: HttpRequest, sender, instance: Model, action: str, **_ | ||||
|     ): | ||||
|         """Signal handler for all object's m2m_changed""" | ||||
|         if action not in ["pre_add", "pre_remove", "post_clear"]: | ||||
|             return | ||||
|         if not should_log_m2m(instance): | ||||
|             return | ||||
|         if _CTX_IGNORE.get(): | ||||
|             return | ||||
|         user = self.get_user(request) | ||||
|  | ||||
|         EventNewThread( | ||||
|             EventAction.MODEL_UPDATED, | ||||
|             request, | ||||
|             user=user, | ||||
|             model=model_to_dict(instance), | ||||
|             **thread_kwargs, | ||||
|         ).run() | ||||
|  | ||||
| @ -1,21 +0,0 @@ | ||||
| # Generated by Django 5.0.2 on 2024-02-29 10:15 | ||||
|  | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ( | ||||
|             "authentik_events", | ||||
|             "0004_systemtask_squashed_0005_remove_systemtask_finish_timestamp_and_more", | ||||
|         ), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.AlterField( | ||||
|             model_name="systemtask", | ||||
|             name="expires", | ||||
|             field=models.DateTimeField(default=None, null=True), | ||||
|         ), | ||||
|     ] | ||||
| @ -1,39 +0,0 @@ | ||||
| # Generated by Django 5.0.4 on 2024-04-15 16:17 | ||||
|  | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_events", "0006_alter_systemtask_expires"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.AddIndex( | ||||
|             model_name="event", | ||||
|             index=models.Index(fields=["action"], name="authentik_e_action_9a9dd9_idx"), | ||||
|         ), | ||||
|         migrations.AddIndex( | ||||
|             model_name="event", | ||||
|             index=models.Index(fields=["user"], name="authentik_e_user_1be48d_idx"), | ||||
|         ), | ||||
|         migrations.AddIndex( | ||||
|             model_name="event", | ||||
|             index=models.Index(fields=["app"], name="authentik_e_app_6a05ce_idx"), | ||||
|         ), | ||||
|         migrations.AddIndex( | ||||
|             model_name="event", | ||||
|             index=models.Index(fields=["created"], name="authentik_e_created_6f0834_idx"), | ||||
|         ), | ||||
|         migrations.AddIndex( | ||||
|             model_name="event", | ||||
|             index=models.Index(fields=["client_ip"], name="authentik_e_client__51f4dd_idx"), | ||||
|         ), | ||||
|         migrations.AddIndex( | ||||
|             model_name="event", | ||||
|             index=models.Index( | ||||
|                 models.F("context__authorized_application"), name="authentik_e_ctx_app__idx" | ||||
|             ), | ||||
|         ), | ||||
|     ] | ||||
| @ -305,16 +305,6 @@ class Event(SerializerModel, ExpiringModel): | ||||
|     class Meta: | ||||
|         verbose_name = _("Event") | ||||
|         verbose_name_plural = _("Events") | ||||
|         indexes = [ | ||||
|             models.Index(fields=["action"]), | ||||
|             models.Index(fields=["user"]), | ||||
|             models.Index(fields=["app"]), | ||||
|             models.Index(fields=["created"]), | ||||
|             models.Index(fields=["client_ip"]), | ||||
|             models.Index( | ||||
|                 models.F("context__authorized_application"), name="authentik_e_ctx_app__idx" | ||||
|             ), | ||||
|         ] | ||||
|  | ||||
|  | ||||
| class TransportMode(models.TextChoices): | ||||
| @ -556,7 +546,7 @@ class Notification(SerializerModel): | ||||
|             if len(self.body) > NOTIFICATION_SUMMARY_LENGTH | ||||
|             else self.body | ||||
|         ) | ||||
|         return f"Notification for user {self.user_id}: {body_trunc}" | ||||
|         return f"Notification for user {self.user}: {body_trunc}" | ||||
|  | ||||
|     class Meta: | ||||
|         verbose_name = _("Notification") | ||||
|  | ||||
| @ -9,7 +9,6 @@ from django.utils.translation import gettext_lazy as _ | ||||
| from structlog.stdlib import get_logger | ||||
| from tenant_schemas_celery.task import TenantTask | ||||
|  | ||||
| from authentik.events.logs import LogEvent | ||||
| from authentik.events.models import Event, EventAction, TaskStatus | ||||
| from authentik.events.models import SystemTask as DBSystemTask | ||||
| from authentik.events.utils import sanitize_item | ||||
| @ -25,7 +24,7 @@ class SystemTask(TenantTask): | ||||
|     save_on_success: bool | ||||
|  | ||||
|     _status: TaskStatus | ||||
|     _messages: list[LogEvent] | ||||
|     _messages: list[str] | ||||
|  | ||||
|     _uid: str | None | ||||
|     # Precise start time from perf_counter | ||||
| @ -45,20 +44,15 @@ class SystemTask(TenantTask): | ||||
|         """Set UID, so in the case of an unexpected error its saved correctly""" | ||||
|         self._uid = uid | ||||
|  | ||||
|     def set_status(self, status: TaskStatus, *messages: LogEvent): | ||||
|     def set_status(self, status: TaskStatus, *messages: str): | ||||
|         """Set result for current run, will overwrite previous result.""" | ||||
|         self._status = status | ||||
|         self._messages = list(messages) | ||||
|         for idx, msg in enumerate(self._messages): | ||||
|             if not isinstance(msg, LogEvent): | ||||
|                 self._messages[idx] = LogEvent(msg, logger=self.__name__, log_level="info") | ||||
|         self._messages = messages | ||||
|  | ||||
|     def set_error(self, exception: Exception): | ||||
|         """Set result to error and save exception""" | ||||
|         self._status = TaskStatus.ERROR | ||||
|         self._messages = [ | ||||
|             LogEvent(exception_to_string(exception), logger=self.__name__, log_level="error") | ||||
|         ] | ||||
|         self._messages = [exception_to_string(exception)] | ||||
|  | ||||
|     def before_start(self, task_id, args, kwargs): | ||||
|         self._start_precise = perf_counter() | ||||
| @ -104,7 +98,8 @@ class SystemTask(TenantTask): | ||||
|     def on_failure(self, exc, task_id, args, kwargs, einfo): | ||||
|         super().on_failure(exc, task_id, args, kwargs, einfo=einfo) | ||||
|         if not self._status: | ||||
|             self.set_error(exc) | ||||
|             self._status = TaskStatus.ERROR | ||||
|             self._messages = exception_to_string(exc) | ||||
|         DBSystemTask.objects.update_or_create( | ||||
|             name=self.__name__, | ||||
|             uid=self._uid, | ||||
| @ -119,7 +114,7 @@ class SystemTask(TenantTask): | ||||
|                 "task_call_kwargs": sanitize_item(kwargs), | ||||
|                 "status": self._status, | ||||
|                 "messages": sanitize_item(self._messages), | ||||
|                 "expires": now() + timedelta(hours=self.result_timeout_hours + 3), | ||||
|                 "expires": now() + timedelta(hours=self.result_timeout_hours), | ||||
|                 "expiring": True, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
| @ -3,11 +3,9 @@ | ||||
| from django.urls import reverse | ||||
| from rest_framework.test import APITestCase | ||||
|  | ||||
| from authentik.core.models import Application, Token, TokenIntents | ||||
| from authentik.core.models import Application | ||||
| from authentik.core.tests.utils import create_test_admin_user | ||||
| from authentik.events.middleware import audit_ignore, audit_overwrite_user | ||||
| from authentik.events.models import Event, EventAction | ||||
| from authentik.lib.generators import generate_id | ||||
|  | ||||
|  | ||||
| class TestEventsMiddleware(APITestCase): | ||||
| @ -17,100 +15,35 @@ class TestEventsMiddleware(APITestCase): | ||||
|         super().setUp() | ||||
|         self.user = create_test_admin_user() | ||||
|         self.client.force_login(self.user) | ||||
|         Event.objects.all().delete() | ||||
|  | ||||
|     def test_create(self): | ||||
|         """Test model creation event""" | ||||
|         uid = generate_id() | ||||
|         self.client.post( | ||||
|             reverse("authentik_api:application-list"), | ||||
|             data={"name": uid, "slug": uid}, | ||||
|             data={"name": "test-create", "slug": "test-create"}, | ||||
|         ) | ||||
|         self.assertTrue(Application.objects.filter(name=uid).exists()) | ||||
|         event = Event.objects.filter( | ||||
|         self.assertTrue(Application.objects.filter(name="test-create").exists()) | ||||
|         self.assertTrue( | ||||
|             Event.objects.filter( | ||||
|                 action=EventAction.MODEL_CREATED, | ||||
|                 context__model__model_name="application", | ||||
|                 context__model__app="authentik_core", | ||||
|             context__model__name=uid, | ||||
|         ).first() | ||||
|         self.assertIsNotNone(event) | ||||
|                 context__model__name="test-create", | ||||
|             ).exists() | ||||
|         ) | ||||
|  | ||||
|     def test_delete(self): | ||||
|         """Test model creation event""" | ||||
|         uid = generate_id() | ||||
|         Application.objects.create(name=uid, slug=uid) | ||||
|         self.client.delete(reverse("authentik_api:application-detail", kwargs={"slug": uid})) | ||||
|         Application.objects.create(name="test-delete", slug="test-delete") | ||||
|         self.client.delete( | ||||
|             reverse("authentik_api:application-detail", kwargs={"slug": "test-delete"}) | ||||
|         ) | ||||
|         self.assertFalse(Application.objects.filter(name="test").exists()) | ||||
|         self.assertTrue( | ||||
|             Event.objects.filter( | ||||
|                 action=EventAction.MODEL_DELETED, | ||||
|                 context__model__model_name="application", | ||||
|                 context__model__app="authentik_core", | ||||
|                 context__model__name=uid, | ||||
|                 context__model__name="test-delete", | ||||
|             ).exists() | ||||
|         ) | ||||
|  | ||||
|     def test_audit_ignore(self): | ||||
|         """Test audit_ignore context manager""" | ||||
|         uid = generate_id() | ||||
|         with audit_ignore(): | ||||
|             self.client.post( | ||||
|                 reverse("authentik_api:application-list"), | ||||
|                 data={"name": uid, "slug": uid}, | ||||
|             ) | ||||
|         self.assertTrue(Application.objects.filter(name=uid).exists()) | ||||
|         self.assertFalse( | ||||
|             Event.objects.filter( | ||||
|                 action=EventAction.MODEL_CREATED, | ||||
|                 context__model__model_name="application", | ||||
|                 context__model__app="authentik_core", | ||||
|                 context__model__name=uid, | ||||
|             ).exists() | ||||
|         ) | ||||
|  | ||||
|     def test_audit_overwrite_user(self): | ||||
|         """Test audit_overwrite_user context manager""" | ||||
|         uid = generate_id() | ||||
|         new_user = create_test_admin_user() | ||||
|         with audit_overwrite_user(new_user): | ||||
|             self.client.post( | ||||
|                 reverse("authentik_api:application-list"), | ||||
|                 data={"name": uid, "slug": uid}, | ||||
|             ) | ||||
|         self.assertTrue(Application.objects.filter(name=uid).exists()) | ||||
|         self.assertTrue( | ||||
|             Event.objects.filter( | ||||
|                 action=EventAction.MODEL_CREATED, | ||||
|                 context__model__model_name="application", | ||||
|                 context__model__app="authentik_core", | ||||
|                 context__model__name=uid, | ||||
|                 user__username=new_user.username, | ||||
|             ).exists() | ||||
|         ) | ||||
|  | ||||
|     def test_create_with_api(self): | ||||
|         """Test model creation event (with API token auth)""" | ||||
|         self.client.logout() | ||||
|         token = Token.objects.create(user=self.user, intent=TokenIntents.INTENT_API, expiring=False) | ||||
|         uid = generate_id() | ||||
|         self.client.post( | ||||
|             reverse("authentik_api:application-list"), | ||||
|             data={"name": uid, "slug": uid}, | ||||
|             HTTP_AUTHORIZATION=f"Bearer {token.key}", | ||||
|         ) | ||||
|         self.assertTrue(Application.objects.filter(name=uid).exists()) | ||||
|         event = Event.objects.filter( | ||||
|             action=EventAction.MODEL_CREATED, | ||||
|             context__model__model_name="application", | ||||
|             context__model__app="authentik_core", | ||||
|             context__model__name=uid, | ||||
|         ).first() | ||||
|         self.assertIsNotNone(event) | ||||
|         self.assertEqual( | ||||
|             event.user, | ||||
|             { | ||||
|                 "pk": self.user.pk, | ||||
|                 "email": self.user.email, | ||||
|                 "username": self.user.username, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
| @ -1,35 +0,0 @@ | ||||
| """authentik event models tests""" | ||||
|  | ||||
| from collections.abc import Callable | ||||
|  | ||||
| from django.db.models import Model | ||||
| from django.test import TestCase | ||||
|  | ||||
| from authentik.core.models import default_token_key | ||||
| from authentik.lib.utils.reflection import get_apps | ||||
|  | ||||
|  | ||||
| class TestModels(TestCase): | ||||
|     """Test Models""" | ||||
|  | ||||
|  | ||||
| def model_tester_factory(test_model: type[Model]) -> Callable: | ||||
|     """Test models' __str__ and __repr__""" | ||||
|  | ||||
|     def tester(self: TestModels): | ||||
|         allowed = 0 | ||||
|         # Token-like objects need to lookup the current tenant to get the default token length | ||||
|         for field in test_model._meta.fields: | ||||
|             if field.default == default_token_key: | ||||
|                 allowed += 1 | ||||
|         with self.assertNumQueries(allowed): | ||||
|             str(test_model()) | ||||
|         with self.assertNumQueries(allowed): | ||||
|             repr(test_model()) | ||||
|  | ||||
|     return tester | ||||
|  | ||||
|  | ||||
| for app in get_apps(): | ||||
|     for model in app.get_models(): | ||||
|         setattr(TestModels, f"test_{app.label}_{model.__name__}", model_tester_factory(model)) | ||||
| @ -47,4 +47,3 @@ class FlowStageBindingViewSet(UsedByMixin, ModelViewSet): | ||||
|     filterset_fields = "__all__" | ||||
|     search_fields = ["stage__name"] | ||||
|     ordering = ["order"] | ||||
|     ordering_fields = ["order", "stage__name"] | ||||
|  | ||||
| @ -7,7 +7,7 @@ from django.utils.translation import gettext as _ | ||||
| from drf_spectacular.types import OpenApiTypes | ||||
| from drf_spectacular.utils import OpenApiResponse, extend_schema | ||||
| from rest_framework.decorators import action | ||||
| from rest_framework.fields import BooleanField, CharField, ReadOnlyField | ||||
| from rest_framework.fields import BooleanField, CharField, DictField, ListField, ReadOnlyField | ||||
| from rest_framework.parsers import MultiPartParser | ||||
| from rest_framework.request import Request | ||||
| from rest_framework.response import Response | ||||
| @ -19,7 +19,7 @@ from authentik.blueprints.v1.exporter import FlowExporter | ||||
| from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT, Importer | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.core.api.utils import CacheSerializer, LinkSerializer, PassiveSerializer | ||||
| from authentik.events.logs import LogEventSerializer | ||||
| from authentik.events.utils import sanitize_dict | ||||
| from authentik.flows.api.flows_diagram import FlowDiagram, FlowDiagramSerializer | ||||
| from authentik.flows.exceptions import FlowNonApplicableException | ||||
| from authentik.flows.models import Flow | ||||
| @ -107,7 +107,7 @@ class FlowSetSerializer(FlowSerializer): | ||||
| class FlowImportResultSerializer(PassiveSerializer): | ||||
|     """Logs of an attempted flow import""" | ||||
|  | ||||
|     logs = LogEventSerializer(many=True, read_only=True) | ||||
|     logs = ListField(child=DictField(), read_only=True) | ||||
|     success = BooleanField(read_only=True) | ||||
|  | ||||
|  | ||||
| @ -184,7 +184,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet): | ||||
|  | ||||
|         importer = Importer.from_string(file.read().decode()) | ||||
|         valid, logs = importer.validate() | ||||
|         import_response.initial_data["logs"] = [LogEventSerializer(log).data for log in logs] | ||||
|         import_response.initial_data["logs"] = [sanitize_dict(log) for log in logs] | ||||
|         import_response.initial_data["success"] = valid | ||||
|         import_response.is_valid() | ||||
|         if not valid: | ||||
| @ -278,7 +278,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet): | ||||
|         }, | ||||
|     ) | ||||
|     @action(detail=True, pagination_class=None, filter_backends=[]) | ||||
|     def execute(self, request: Request, slug: str): | ||||
|     def execute(self, request: Request, _slug: str): | ||||
|         """Execute flow for current user""" | ||||
|         # Because we pre-plan the flow here, and not in the planner, we need to manually clear | ||||
|         # the history of the inspector | ||||
|  | ||||
| @ -31,9 +31,10 @@ class AuthentikFlowsConfig(ManagedAppConfig): | ||||
|     verbose_name = "authentik Flows" | ||||
|     default = True | ||||
|  | ||||
|     def import_related(self): | ||||
|     @ManagedAppConfig.reconcile_global | ||||
|     def load_stages(self): | ||||
|         """Ensure all stages are loaded""" | ||||
|         from authentik.flows.models import Stage | ||||
|  | ||||
|         for stage in all_subclasses(Stage): | ||||
|             _ = stage().view | ||||
|         return super().import_related() | ||||
|  | ||||
| @ -59,11 +59,11 @@ class FlowPlan: | ||||
|     markers: list[StageMarker] = field(default_factory=list) | ||||
|  | ||||
|     def append_stage(self, stage: Stage, marker: StageMarker | None = None): | ||||
|         """Append `stage` to the end of the plan, optionally with stage marker""" | ||||
|         """Append `stage` to all stages, optionally with stage marker""" | ||||
|         return self.append(FlowStageBinding(stage=stage), marker) | ||||
|  | ||||
|     def append(self, binding: FlowStageBinding, marker: StageMarker | None = None): | ||||
|         """Append `stage` to the end of the plan, optionally with stage marker""" | ||||
|         """Append `stage` to all stages, optionally with stage marker""" | ||||
|         self.bindings.append(binding) | ||||
|         self.markers.append(marker or StageMarker()) | ||||
|  | ||||
| @ -203,7 +203,6 @@ class FlowPlanner: | ||||
|                 "f(plan): building plan", | ||||
|             ) | ||||
|             plan = self._build_plan(user, request, default_context) | ||||
|             if self.use_cache: | ||||
|             cache.set(cache_key(self.flow, user), plan, CACHE_TIMEOUT) | ||||
|             if not plan.bindings and not self.allow_empty_flows: | ||||
|                 raise EmptyFlowException() | ||||
|  | ||||
| @ -6,7 +6,6 @@ from rest_framework.test import APITestCase | ||||
| from authentik.core.tests.utils import create_test_admin_user | ||||
| from authentik.flows.api.stages import StageSerializer, StageViewSet | ||||
| from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding, Stage | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.policies.dummy.models import DummyPolicy | ||||
| from authentik.policies.models import PolicyBinding | ||||
| from authentik.stages.dummy.models import DummyStage | ||||
| @ -102,21 +101,3 @@ class TestFlowsAPI(APITestCase): | ||||
|             reverse("authentik_api:stage-types"), | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|  | ||||
|     def test_execute(self): | ||||
|         """Test execute endpoint""" | ||||
|         user = create_test_admin_user() | ||||
|         self.client.force_login(user) | ||||
|  | ||||
|         flow = Flow.objects.create( | ||||
|             name=generate_id(), | ||||
|             slug=generate_id(), | ||||
|             designation=FlowDesignation.AUTHENTICATION, | ||||
|         ) | ||||
|         FlowStageBinding.objects.create( | ||||
|             target=flow, stage=DummyStage.objects.create(name=generate_id()), order=0 | ||||
|         ) | ||||
|         response = self.client.get( | ||||
|             reverse("authentik_api:flow-execute", kwargs={"slug": flow.slug}) | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|  | ||||
| @ -24,6 +24,7 @@ from sentry_sdk.hub import Hub | ||||
| from structlog.stdlib import BoundLogger, get_logger | ||||
|  | ||||
| from authentik.brands.models import Brand | ||||
| from authentik.brands.utils import cors_allow | ||||
| from authentik.core.models import Application | ||||
| from authentik.events.models import Event, EventAction, cleanse_dict | ||||
| from authentik.flows.apps import HIST_FLOW_EXECUTION_STAGE_TIME | ||||
| @ -155,6 +156,14 @@ class FlowExecutorView(APIView): | ||||
|         return plan | ||||
|  | ||||
|     def dispatch(self, request: HttpRequest, flow_slug: str) -> HttpResponse: | ||||
|         response = self.dispatch_wrapper(request, flow_slug) | ||||
|         origins = [] | ||||
|         if request.brand.origin != "": | ||||
|             origins.append(request.brand.origin) | ||||
|         cors_allow(request, response, *origins) | ||||
|         return response | ||||
|  | ||||
|     def dispatch_wrapper(self, request: HttpRequest, flow_slug: str) -> HttpResponse: | ||||
|         with Hub.current.start_span( | ||||
|             op="authentik.flow.executor.dispatch", description=self.flow.slug | ||||
|         ) as span: | ||||
| @ -450,7 +459,7 @@ class FlowExecutorView(APIView): | ||||
|         return to_stage_response(self.request, challenge_view.get(self.request)) | ||||
|  | ||||
|     def cancel(self): | ||||
|         """Cancel current flow execution""" | ||||
|         """Cancel current execution and return a redirect""" | ||||
|         keys_to_delete = [ | ||||
|             SESSION_KEY_APPLICATION_PRE, | ||||
|             SESSION_KEY_PLAN, | ||||
|  | ||||
| @ -11,7 +11,7 @@ from django.http import HttpRequest, HttpResponseNotFound | ||||
| from django.templatetags.static import static | ||||
| from lxml import etree  # nosec | ||||
| from lxml.etree import Element, SubElement  # nosec | ||||
| from requests.exceptions import ConnectionError, HTTPError, RequestException, Timeout | ||||
| from requests.exceptions import RequestException | ||||
|  | ||||
| from authentik.lib.config import get_path_from_dict | ||||
| from authentik.lib.utils.http import get_http_session | ||||
| @ -23,8 +23,6 @@ if TYPE_CHECKING: | ||||
| GRAVATAR_URL = "https://secure.gravatar.com" | ||||
| DEFAULT_AVATAR = static("dist/assets/images/user_default.png") | ||||
| CACHE_KEY_GRAVATAR = "goauthentik.io/lib/avatars/" | ||||
| CACHE_KEY_GRAVATAR_AVAILABLE = "goauthentik.io/lib/avatars/gravatar_available" | ||||
| GRAVATAR_STATUS_TTL_SECONDS = 60 * 60 * 8  # 8 Hours | ||||
|  | ||||
| SVG_XML_NS = "http://www.w3.org/2000/svg" | ||||
| SVG_NS_MAP = {None: SVG_XML_NS} | ||||
| @ -52,9 +50,6 @@ def avatar_mode_attribute(user: "User", mode: str) -> str | None: | ||||
|  | ||||
| def avatar_mode_gravatar(user: "User", mode: str) -> str | None: | ||||
|     """Gravatar avatars""" | ||||
|     if not cache.get(CACHE_KEY_GRAVATAR_AVAILABLE, True): | ||||
|         return None | ||||
|  | ||||
|     # gravatar uses md5 for their URLs, so md5 can't be avoided | ||||
|     mail_hash = md5(user.email.lower().encode("utf-8")).hexdigest()  # nosec | ||||
|     parameters = [("size", "158"), ("rating", "g"), ("default", "404")] | ||||
| @ -74,8 +69,6 @@ def avatar_mode_gravatar(user: "User", mode: str) -> str | None: | ||||
|             cache.set(full_key, None) | ||||
|             return None | ||||
|         res.raise_for_status() | ||||
|     except (Timeout, ConnectionError, HTTPError): | ||||
|         cache.set(CACHE_KEY_GRAVATAR_AVAILABLE, False, timeout=GRAVATAR_STATUS_TTL_SECONDS) | ||||
|     except RequestException: | ||||
|         return gravatar_url | ||||
|     cache.set(full_key, gravatar_url) | ||||
|  | ||||
| @ -14,7 +14,7 @@ from pathlib import Path | ||||
| from sys import argv, stderr | ||||
| from time import time | ||||
| from typing import Any | ||||
| from urllib.parse import quote_plus, urlparse | ||||
| from urllib.parse import urlparse | ||||
|  | ||||
| import yaml | ||||
| from django.conf import ImproperlyConfigured | ||||
| @ -331,26 +331,6 @@ class ConfigLoader: | ||||
| CONFIG = ConfigLoader() | ||||
|  | ||||
|  | ||||
| def redis_url(db: int) -> str: | ||||
|     """Helper to create a Redis URL for a specific database""" | ||||
|     _redis_protocol_prefix = "redis://" | ||||
|     _redis_tls_requirements = "" | ||||
|     if CONFIG.get_bool("redis.tls", False): | ||||
|         _redis_protocol_prefix = "rediss://" | ||||
|         _redis_tls_requirements = f"?ssl_cert_reqs={CONFIG.get('redis.tls_reqs')}" | ||||
|         if _redis_ca := CONFIG.get("redis.tls_ca_cert", None): | ||||
|             _redis_tls_requirements += f"&ssl_ca_certs={_redis_ca}" | ||||
|     _redis_url = ( | ||||
|         f"{_redis_protocol_prefix}" | ||||
|         f"{quote_plus(CONFIG.get('redis.username'))}:" | ||||
|         f"{quote_plus(CONFIG.get('redis.password'))}@" | ||||
|         f"{quote_plus(CONFIG.get('redis.host'))}:" | ||||
|         f"{CONFIG.get_int('redis.port')}" | ||||
|         f"/{db}{_redis_tls_requirements}" | ||||
|     ) | ||||
|     return _redis_url | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     if len(argv) < 2:  # noqa: PLR2004 | ||||
|         print(dumps(CONFIG.raw, indent=4, cls=AttrEncoder)) | ||||
|  | ||||
| @ -35,7 +35,6 @@ redis: | ||||
|   password: "" | ||||
|   tls: false | ||||
|   tls_reqs: "none" | ||||
|   tls_ca_cert: null | ||||
|  | ||||
| # broker: | ||||
| #   url: "" | ||||
| @ -53,15 +52,12 @@ cache: | ||||
|  | ||||
| # result_backend: | ||||
| #   url: "" | ||||
| #   transport_options: "" | ||||
|  | ||||
| debug: false | ||||
| remote_debug: false | ||||
|  | ||||
| log_level: info | ||||
|  | ||||
| session_storage: cache | ||||
|  | ||||
| error_reporting: | ||||
|   enabled: false | ||||
|   sentry_dsn: https://151ba72610234c4c97c5bcff4e1cffd8@authentik.error-reporting.a7k.io/4504163677503489 | ||||
| @ -114,6 +110,7 @@ events: | ||||
|     asn: "/geoip/GeoLite2-ASN.mmdb" | ||||
|  | ||||
| cert_discovery_dir: /certs | ||||
| default_token_length: 60 | ||||
|  | ||||
| tenants: | ||||
|   enabled: false | ||||
|  | ||||
| @ -2,11 +2,11 @@ | ||||
|  | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from django.conf import settings | ||||
| from requests.sessions import PreparedRequest, Session | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik import get_full_version | ||||
| from authentik.lib.config import CONFIG | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
| @ -35,6 +35,6 @@ class DebugSession(Session): | ||||
|  | ||||
| def get_http_session() -> Session: | ||||
|     """Get a requests session with common headers""" | ||||
|     session = DebugSession() if CONFIG.get_bool("debug") else Session() | ||||
|     session = DebugSession() if settings.DEBUG else Session() | ||||
|     session.headers["User-Agent"] = authentik_user_agent() | ||||
|     return session | ||||
|  | ||||
| @ -3,14 +3,12 @@ | ||||
| import os | ||||
| from importlib import import_module | ||||
| from pathlib import Path | ||||
| from tempfile import gettempdir | ||||
|  | ||||
| from django.conf import settings | ||||
| from kubernetes.config.incluster_config import SERVICE_HOST_ENV_NAME | ||||
|  | ||||
| from authentik.lib.config import CONFIG | ||||
|  | ||||
| SERVICE_HOST_ENV_NAME = "KUBERNETES_SERVICE_HOST" | ||||
|  | ||||
|  | ||||
| def all_subclasses(cls, sort=True): | ||||
|     """Recursively return all subclassess of cls""" | ||||
| @ -57,7 +55,7 @@ def get_env() -> str: | ||||
|         return "dev" | ||||
|     if SERVICE_HOST_ENV_NAME in os.environ: | ||||
|         return "kubernetes" | ||||
|     if (Path(gettempdir()) / "authentik-mode").exists(): | ||||
|     if Path("/tmp/authentik-mode").exists():  # nosec | ||||
|         return "compose" | ||||
|     if "AK_APPLIANCE" in os.environ: | ||||
|         return os.environ["AK_APPLIANCE"] | ||||
|  | ||||
| @ -45,14 +45,14 @@ class AuthentikOutpostConfig(ManagedAppConfig): | ||||
|                 outpost.managed = MANAGED_OUTPOST | ||||
|                 outpost.save() | ||||
|                 return | ||||
|             outpost, created = Outpost.objects.update_or_create( | ||||
|             outpost, updated = Outpost.objects.update_or_create( | ||||
|                 defaults={ | ||||
|                     "type": OutpostType.PROXY, | ||||
|                     "name": MANAGED_OUTPOST_NAME, | ||||
|                 }, | ||||
|                 managed=MANAGED_OUTPOST, | ||||
|             ) | ||||
|             if created: | ||||
|             if updated: | ||||
|                 if KubernetesServiceConnection.objects.exists(): | ||||
|                     outpost.service_connection = KubernetesServiceConnection.objects.first() | ||||
|                 elif DockerServiceConnection.objects.exists(): | ||||
|  | ||||
| @ -3,9 +3,9 @@ | ||||
| from dataclasses import dataclass | ||||
|  | ||||
| from structlog.stdlib import get_logger | ||||
| from structlog.testing import capture_logs | ||||
|  | ||||
| from authentik import __version__, get_build_hash | ||||
| from authentik.events.logs import LogEvent, capture_logs | ||||
| from authentik.lib.config import CONFIG | ||||
| from authentik.lib.sentry import SentryIgnoredException | ||||
| from authentik.outposts.models import ( | ||||
| @ -63,21 +63,21 @@ class BaseController: | ||||
|         """Called by scheduled task to reconcile deployment/service/etc""" | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def up_with_logs(self) -> list[LogEvent]: | ||||
|     def up_with_logs(self) -> list[str]: | ||||
|         """Call .up() but capture all log output and return it.""" | ||||
|         with capture_logs() as logs: | ||||
|             self.up() | ||||
|         return logs | ||||
|         return [x["event"] for x in logs] | ||||
|  | ||||
|     def down(self): | ||||
|         """Handler to delete everything we've created""" | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def down_with_logs(self) -> list[LogEvent]: | ||||
|     def down_with_logs(self) -> list[str]: | ||||
|         """Call .down() but capture all log output and return it.""" | ||||
|         with capture_logs() as logs: | ||||
|             self.down() | ||||
|         return logs | ||||
|         return [x["event"] for x in logs] | ||||
|  | ||||
|     def __enter__(self): | ||||
|         return self | ||||
|  | ||||
| @ -9,10 +9,10 @@ from kubernetes.client.exceptions import OpenApiException | ||||
| from kubernetes.config.config_exception import ConfigException | ||||
| from kubernetes.config.incluster_config import load_incluster_config | ||||
| from kubernetes.config.kube_config import load_kube_config_from_dict | ||||
| from structlog.testing import capture_logs | ||||
| from urllib3.exceptions import HTTPError | ||||
| from yaml import dump_all | ||||
|  | ||||
| from authentik.events.logs import LogEvent, capture_logs | ||||
| from authentik.outposts.controllers.base import BaseClient, BaseController, ControllerException | ||||
| from authentik.outposts.controllers.k8s.base import KubernetesObjectReconciler | ||||
| from authentik.outposts.controllers.k8s.deployment import DeploymentReconciler | ||||
| @ -91,7 +91,7 @@ class KubernetesController(BaseController): | ||||
|         except (OpenApiException, HTTPError, ServiceConnectionInvalid) as exc: | ||||
|             raise ControllerException(str(exc)) from exc | ||||
|  | ||||
|     def up_with_logs(self) -> list[LogEvent]: | ||||
|     def up_with_logs(self) -> list[str]: | ||||
|         try: | ||||
|             all_logs = [] | ||||
|             for reconcile_key in self.reconcile_order: | ||||
| @ -104,9 +104,7 @@ class KubernetesController(BaseController): | ||||
|                         continue | ||||
|                     reconciler = reconciler_cls(self) | ||||
|                     reconciler.up() | ||||
|                 for log in logs: | ||||
|                     log.logger = reconcile_key.title() | ||||
|                 all_logs.extend(logs) | ||||
|                 all_logs += [f"{reconcile_key.title()}: {x['event']}" for x in logs] | ||||
|             return all_logs | ||||
|         except (OpenApiException, HTTPError, ServiceConnectionInvalid) as exc: | ||||
|             raise ControllerException(str(exc)) from exc | ||||
| @ -124,7 +122,7 @@ class KubernetesController(BaseController): | ||||
|         except (OpenApiException, HTTPError, ServiceConnectionInvalid) as exc: | ||||
|             raise ControllerException(str(exc)) from exc | ||||
|  | ||||
|     def down_with_logs(self) -> list[LogEvent]: | ||||
|     def down_with_logs(self) -> list[str]: | ||||
|         try: | ||||
|             all_logs = [] | ||||
|             for reconcile_key in self.reconcile_order: | ||||
| @ -137,9 +135,7 @@ class KubernetesController(BaseController): | ||||
|                         continue | ||||
|                     reconciler = reconciler_cls(self) | ||||
|                     reconciler.down() | ||||
|                 for log in logs: | ||||
|                     log.logger = reconcile_key.title() | ||||
|                 all_logs.extend(logs) | ||||
|                 all_logs += [f"{reconcile_key.title()}: {x['event']}" for x in logs] | ||||
|             return all_logs | ||||
|         except (OpenApiException, HTTPError, ServiceConnectionInvalid) as exc: | ||||
|             raise ControllerException(str(exc)) from exc | ||||
|  | ||||
| @ -149,8 +149,10 @@ def outpost_controller( | ||||
|         if not controller_type: | ||||
|             return | ||||
|         with controller_type(outpost, outpost.service_connection) as controller: | ||||
|             LOGGER.debug("---------------Outpost Controller logs starting----------------") | ||||
|             logs = getattr(controller, f"{action}_with_logs")() | ||||
|             LOGGER.debug("---------------Outpost Controller logs starting----------------") | ||||
|             for log in logs: | ||||
|                 LOGGER.debug(log) | ||||
|             LOGGER.debug("-----------------Outpost Controller logs end-------------------") | ||||
|     except (ControllerException, ServiceConnectionInvalid) as exc: | ||||
|         self.set_error(exc) | ||||
|  | ||||
| @ -1,11 +1,10 @@ | ||||
| """Serializer for policy execution""" | ||||
|  | ||||
| from rest_framework.fields import BooleanField, CharField, ListField | ||||
| from rest_framework.fields import BooleanField, CharField, DictField, ListField | ||||
| from rest_framework.relations import PrimaryKeyRelatedField | ||||
|  | ||||
| from authentik.core.api.utils import JSONDictField, PassiveSerializer | ||||
| from authentik.core.models import User | ||||
| from authentik.events.logs import LogEventSerializer | ||||
|  | ||||
|  | ||||
| class PolicyTestSerializer(PassiveSerializer): | ||||
| @ -20,4 +19,4 @@ class PolicyTestResultSerializer(PassiveSerializer): | ||||
|  | ||||
|     passing = BooleanField() | ||||
|     messages = ListField(child=CharField(), read_only=True) | ||||
|     log_messages = LogEventSerializer(many=True, read_only=True) | ||||
|     log_messages = ListField(child=DictField(), read_only=True) | ||||
|  | ||||
| @ -11,11 +11,12 @@ from rest_framework.response import Response | ||||
| from rest_framework.serializers import ModelSerializer, SerializerMethodField | ||||
| from rest_framework.viewsets import GenericViewSet | ||||
| from structlog.stdlib import get_logger | ||||
| from structlog.testing import capture_logs | ||||
|  | ||||
| from authentik.core.api.applications import user_app_cache_key | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.core.api.utils import CacheSerializer, MetaNameSerializer, TypeCreateSerializer | ||||
| from authentik.events.logs import LogEventSerializer, capture_logs | ||||
| from authentik.events.utils import sanitize_dict | ||||
| from authentik.lib.utils.reflection import all_subclasses | ||||
| from authentik.policies.api.exec import PolicyTestResultSerializer, PolicyTestSerializer | ||||
| from authentik.policies.models import Policy, PolicyBinding | ||||
| @ -165,9 +166,9 @@ class PolicyViewSet( | ||||
|             result = proc.execute() | ||||
|         log_messages = [] | ||||
|         for log in logs: | ||||
|             if log.attributes.get("process", "") == "PolicyProcess": | ||||
|             if log.get("process", "") == "PolicyProcess": | ||||
|                 continue | ||||
|             log_messages.append(LogEventSerializer(log).data) | ||||
|             log_messages.append(sanitize_dict(log)) | ||||
|         result.log_messages = log_messages | ||||
|         response = PolicyTestResultSerializer(result) | ||||
|         return Response(response.data) | ||||
|  | ||||
| @ -39,7 +39,6 @@ class Migration(migrations.Migration): | ||||
|                     ("authentik.sources.oauth", "authentik Sources.OAuth"), | ||||
|                     ("authentik.sources.plex", "authentik Sources.Plex"), | ||||
|                     ("authentik.sources.saml", "authentik Sources.SAML"), | ||||
|                     ("authentik.sources.scim", "authentik Sources.SCIM"), | ||||
|                     ("authentik.stages.authenticator_duo", "authentik Stages.Authenticator.Duo"), | ||||
|                     ("authentik.stages.authenticator_sms", "authentik Stages.Authenticator.SMS"), | ||||
|                     ( | ||||
|  | ||||
| @ -13,7 +13,6 @@ from authentik.events.context_processors.base import get_context_processors | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from authentik.core.models import User | ||||
|     from authentik.events.logs import LogEvent | ||||
|     from authentik.policies.models import PolicyBinding | ||||
|  | ||||
| LOGGER = get_logger() | ||||
| @ -75,7 +74,7 @@ class PolicyResult: | ||||
|     source_binding: PolicyBinding | None | ||||
|     source_results: list[PolicyResult] | None | ||||
|  | ||||
|     log_messages: list[LogEvent] | None | ||||
|     log_messages: list[dict] | None | ||||
|  | ||||
|     def __init__(self, passing: bool, *messages: str): | ||||
|         self.passing = passing | ||||
|  | ||||
| @ -1,9 +1,9 @@ | ||||
| """authentik oauth provider app config""" | ||||
|  | ||||
| from django.apps import AppConfig | ||||
| from authentik.blueprints.apps import ManagedAppConfig | ||||
|  | ||||
|  | ||||
| class AuthentikProviderOAuth2Config(AppConfig): | ||||
| class AuthentikProviderOAuth2Config(ManagedAppConfig): | ||||
|     """authentik oauth provider app config""" | ||||
|  | ||||
|     name = "authentik.providers.oauth2" | ||||
| @ -13,3 +13,4 @@ class AuthentikProviderOAuth2Config(AppConfig): | ||||
|         "authentik.providers.oauth2.urls_root": "", | ||||
|         "authentik.providers.oauth2.urls": "application/o/", | ||||
|     } | ||||
|     default = True | ||||
|  | ||||
| @ -8,7 +8,6 @@ from django.http import HttpRequest | ||||
| from django.utils import timezone | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
|  | ||||
| from authentik.core.models import default_token_duration | ||||
| from authentik.events.signals import get_login_event | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.providers.oauth2.constants import ( | ||||
| @ -88,9 +87,7 @@ class IDToken: | ||||
|     ) -> "IDToken": | ||||
|         """Create ID Token""" | ||||
|         id_token = IDToken(provider, token, **kwargs) | ||||
|         id_token.exp = int( | ||||
|             (token.expires if token.expires is not None else default_token_duration()).timestamp() | ||||
|         ) | ||||
|         id_token.exp = int(token.expires.timestamp()) | ||||
|         id_token.iss = provider.get_issuer(request) | ||||
|         id_token.aud = provider.client_id | ||||
|         id_token.claims = {} | ||||
|  | ||||
| @ -1,36 +0,0 @@ | ||||
| # Generated by Django 5.0.2 on 2024-02-29 10:15 | ||||
|  | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ( | ||||
|             "authentik_providers_oauth2", | ||||
|             "0017_accesstoken_session_id_authorizationcode_session_id_and_more", | ||||
|         ), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.AlterField( | ||||
|             model_name="accesstoken", | ||||
|             name="expires", | ||||
|             field=models.DateTimeField(default=None, null=True), | ||||
|         ), | ||||
|         migrations.AlterField( | ||||
|             model_name="authorizationcode", | ||||
|             name="expires", | ||||
|             field=models.DateTimeField(default=None, null=True), | ||||
|         ), | ||||
|         migrations.AlterField( | ||||
|             model_name="devicetoken", | ||||
|             name="expires", | ||||
|             field=models.DateTimeField(default=None, null=True), | ||||
|         ), | ||||
|         migrations.AlterField( | ||||
|             model_name="refreshtoken", | ||||
|             name="expires", | ||||
|             field=models.DateTimeField(default=None, null=True), | ||||
|         ), | ||||
|     ] | ||||
| @ -326,7 +326,7 @@ class AuthorizationCode(SerializerModel, ExpiringModel, BaseGrantModel): | ||||
|         verbose_name_plural = _("Authorization Codes") | ||||
|  | ||||
|     def __str__(self): | ||||
|         return f"Authorization code for {self.provider_id} for user {self.user_id}" | ||||
|         return f"Authorization code for {self.provider} for user {self.user}" | ||||
|  | ||||
|     @property | ||||
|     def serializer(self) -> Serializer: | ||||
| @ -356,7 +356,7 @@ class AccessToken(SerializerModel, ExpiringModel, BaseGrantModel): | ||||
|         verbose_name_plural = _("OAuth2 Access Tokens") | ||||
|  | ||||
|     def __str__(self): | ||||
|         return f"Access Token for {self.provider_id} for user {self.user_id}" | ||||
|         return f"Access Token for {self.provider} for user {self.user}" | ||||
|  | ||||
|     @property | ||||
|     def id_token(self) -> IDToken: | ||||
| @ -399,7 +399,7 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel): | ||||
|         verbose_name_plural = _("OAuth2 Refresh Tokens") | ||||
|  | ||||
|     def __str__(self): | ||||
|         return f"Refresh Token for {self.provider_id} for user {self.user_id}" | ||||
|         return f"Refresh Token for {self.provider} for user {self.user}" | ||||
|  | ||||
|     @property | ||||
|     def id_token(self) -> IDToken: | ||||
| @ -443,4 +443,4 @@ class DeviceToken(ExpiringModel): | ||||
|         verbose_name_plural = _("Device Tokens") | ||||
|  | ||||
|     def __str__(self): | ||||
|         return f"Device Token for {self.provider_id}" | ||||
|         return f"Device Token for {self.provider}" | ||||
|  | ||||
							
								
								
									
										15
									
								
								authentik/providers/oauth2/signals.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								authentik/providers/oauth2/signals.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,15 @@ | ||||
| from hashlib import sha256 | ||||
|  | ||||
| from django.contrib.auth.signals import user_logged_out | ||||
| from django.dispatch import receiver | ||||
| from django.http import HttpRequest | ||||
|  | ||||
| from authentik.core.models import User | ||||
| from authentik.providers.oauth2.models import AccessToken | ||||
|  | ||||
|  | ||||
| @receiver(user_logged_out) | ||||
| def user_logged_out_oauth_access_token(sender, request: HttpRequest, user: User, **_): | ||||
|     """Revoke access tokens upon user logout""" | ||||
|     hashed_session_key = sha256(request.session.session_key.encode("ascii")).hexdigest() | ||||
|     AccessToken.objects.filter(user=user, session_id=hashed_session_key).delete() | ||||
| @ -4,10 +4,9 @@ from urllib.parse import urlencode | ||||
|  | ||||
| from django.urls import reverse | ||||
|  | ||||
| from authentik.core.models import Application, Group | ||||
| from authentik.core.models import Application | ||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_brand, create_test_flow | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.policies.models import PolicyBinding | ||||
| from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider | ||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||
| from authentik.providers.oauth2.views.device_init import QS_KEY_CODE | ||||
| @ -78,23 +77,3 @@ class TesOAuth2DeviceInit(OAuthTestCase): | ||||
|             + "?" | ||||
|             + urlencode({QS_KEY_CODE: token.user_code}), | ||||
|         ) | ||||
|  | ||||
|     def test_device_init_denied(self): | ||||
|         """Test device init""" | ||||
|         group = Group.objects.create(name="foo") | ||||
|         PolicyBinding.objects.create( | ||||
|             group=group, | ||||
|             target=self.application, | ||||
|             order=0, | ||||
|         ) | ||||
|         token = DeviceToken.objects.create( | ||||
|             user_code="foo", | ||||
|             provider=self.provider, | ||||
|         ) | ||||
|         res = self.client.get( | ||||
|             reverse("authentik_providers_oauth2_root:device-login") | ||||
|             + "?" | ||||
|             + urlencode({QS_KEY_CODE: token.user_code}) | ||||
|         ) | ||||
|         self.assertEqual(res.status_code, 200) | ||||
|         self.assertIn(b"Permission denied", res.content) | ||||
|  | ||||
| @ -10,7 +10,6 @@ from jwt import PyJWKSet | ||||
|  | ||||
| from authentik.core.models import Application | ||||
| from authentik.core.tests.utils import create_test_cert, create_test_flow | ||||
| from authentik.crypto.builder import PrivateKeyAlg | ||||
| from authentik.crypto.models import CertificateKeyPair | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.providers.oauth2.models import OAuth2Provider | ||||
| @ -83,7 +82,7 @@ class TestJWKS(OAuthTestCase): | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://local.invalid", | ||||
|             signing_key=create_test_cert(PrivateKeyAlg.ECDSA), | ||||
|             signing_key=create_test_cert(use_ec_private_key=True), | ||||
|         ) | ||||
|         app = Application.objects.create(name="test", slug="test", provider=provider) | ||||
|         response = self.client.get( | ||||
|  | ||||
| @ -208,7 +208,7 @@ class TestToken(OAuthTestCase): | ||||
|                 "token_type": TOKEN_TYPE, | ||||
|                 "expires_in": 3600, | ||||
|                 "id_token": provider.encode( | ||||
|                     access.id_token.to_dict(), | ||||
|                     refresh.id_token.to_dict(), | ||||
|                 ), | ||||
|             }, | ||||
|         ) | ||||
| @ -267,7 +267,7 @@ class TestToken(OAuthTestCase): | ||||
|                 "token_type": TOKEN_TYPE, | ||||
|                 "expires_in": 3600, | ||||
|                 "id_token": provider.encode( | ||||
|                     access.id_token.to_dict(), | ||||
|                     refresh.id_token.to_dict(), | ||||
|                 ), | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
| @ -4,11 +4,9 @@ import re | ||||
| from base64 import b64decode | ||||
| from binascii import Error | ||||
| from typing import Any | ||||
| from urllib.parse import urlparse | ||||
|  | ||||
| from django.http import HttpRequest, HttpResponse, JsonResponse | ||||
| from django.http.response import HttpResponseRedirect | ||||
| from django.utils.cache import patch_vary_headers | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.core.middleware import CTX_AUTH_VIA, KEY_USER | ||||
| @ -30,49 +28,6 @@ class TokenResponse(JsonResponse): | ||||
|         self["Pragma"] = "no-cache" | ||||
|  | ||||
|  | ||||
| def cors_allow(request: HttpRequest, response: HttpResponse, *allowed_origins: str): | ||||
|     """Add headers to permit CORS requests from allowed_origins, with or without credentials, | ||||
|     with any headers.""" | ||||
|     origin = request.META.get("HTTP_ORIGIN") | ||||
|     if not origin: | ||||
|         return response | ||||
|  | ||||
|     # OPTIONS requests don't have an authorization header -> hence | ||||
|     # we can't extract the provider this request is for | ||||
|     # so for options requests we allow the calling origin without checking | ||||
|     allowed = request.method == "OPTIONS" | ||||
|     received_origin = urlparse(origin) | ||||
|     for allowed_origin in allowed_origins: | ||||
|         url = urlparse(allowed_origin) | ||||
|         if ( | ||||
|             received_origin.scheme == url.scheme | ||||
|             and received_origin.hostname == url.hostname | ||||
|             and received_origin.port == url.port | ||||
|         ): | ||||
|             allowed = True | ||||
|     if not allowed: | ||||
|         LOGGER.warning( | ||||
|             "CORS: Origin is not an allowed origin", | ||||
|             requested=received_origin, | ||||
|             allowed=allowed_origins, | ||||
|         ) | ||||
|         return response | ||||
|  | ||||
|     # From the CORS spec: The string "*" cannot be used for a resource that supports credentials. | ||||
|     response["Access-Control-Allow-Origin"] = origin | ||||
|     patch_vary_headers(response, ["Origin"]) | ||||
|     response["Access-Control-Allow-Credentials"] = "true" | ||||
|  | ||||
|     if request.method == "OPTIONS": | ||||
|         if "HTTP_ACCESS_CONTROL_REQUEST_HEADERS" in request.META: | ||||
|             response["Access-Control-Allow-Headers"] = request.META[ | ||||
|                 "HTTP_ACCESS_CONTROL_REQUEST_HEADERS" | ||||
|             ] | ||||
|         response["Access-Control-Allow-Methods"] = "GET, POST, OPTIONS" | ||||
|  | ||||
|     return response | ||||
|  | ||||
|  | ||||
| def extract_access_token(request: HttpRequest) -> str | None: | ||||
|     """ | ||||
|     Get the access token using Authorization Request Header Field method. | ||||
|  | ||||
| @ -11,11 +11,10 @@ from django.views.decorators.csrf import csrf_exempt | ||||
| from rest_framework.throttling import AnonRateThrottle | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.core.models import Application | ||||
| from authentik.lib.config import CONFIG | ||||
| from authentik.lib.utils.time import timedelta_from_string | ||||
| from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider | ||||
| from authentik.providers.oauth2.views.device_init import QS_KEY_CODE | ||||
| from authentik.providers.oauth2.views.device_init import QS_KEY_CODE, get_application | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
| @ -38,9 +37,7 @@ class DeviceView(View): | ||||
|         ).first() | ||||
|         if not provider: | ||||
|             return HttpResponseBadRequest() | ||||
|         try: | ||||
|             _ = provider.application | ||||
|         except Application.DoesNotExist: | ||||
|         if not get_application(provider): | ||||
|             return HttpResponseBadRequest() | ||||
|         self.provider = provider | ||||
|         self.client_id = client_id | ||||
|  | ||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	