Compare commits
	
		
			245 Commits
		
	
	
		
			version/20
			...
			version/20
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| e9910732bc | |||
| 246dd4b062 | |||
| 4425f8d183 | |||
| c410bb8c36 | |||
| 44f62a4773 | |||
| b6ff04694f | |||
| d4ce0e8e41 | |||
| 362d72da8c | |||
| 88d0f8d8a8 | |||
| 61097b9400 | |||
| 7a73ddfb60 | |||
| d66f13c249 | |||
| 8cc3cb6a42 | |||
| 4c5537ddfe | |||
| a95779157d | |||
| 70256727fd | |||
| ac6afb2b82 | |||
| 2ea7bd86e8 | |||
| 95bce9c9e7 | |||
| 71a22c2a34 | |||
| f3eb85877d | |||
| 273f5211a0 | |||
| db06428ab9 | |||
| 109d8e48d4 | |||
| 2ca115285c | |||
| f5459645a5 | |||
| 14c159500d | |||
| 03da87991f | |||
| e38ee9c580 | |||
| 3bf53b2db1 | |||
| f33190caa5 | |||
| 741822424a | |||
| 0ca6fbb224 | |||
| f72b652b24 | |||
| 0a2c1eb419 | |||
| eb9593a847 | |||
| 7c71c52791 | |||
| 59493c02c4 | |||
| 83089b47d3 | |||
| 103e723d8c | |||
| 7d6e88061f | |||
| f8aab40e3e | |||
| 5123bc1316 | |||
| 30e8408e85 | |||
| bb34474101 | |||
| a105760123 | |||
| f410a77010 | |||
| 6ff8fdcc49 | |||
| 50ca3dc772 | |||
| 2a09fc0ae2 | |||
| fbb6756488 | |||
| f45fb2eac0 | |||
| 7b8cde17e6 | |||
| 186634fc67 | |||
| c84b1b7997 | |||
| 6e83467481 | |||
| 72db17f23b | |||
| ee4e176039 | |||
| e18e681c2b | |||
| 10fe67e08d | |||
| fc1db83be7 | |||
| 3740e65906 | |||
| 30386cd899 | |||
| 64a10e9a46 | |||
| 77d6242cce | |||
| 9a86dcaec3 | |||
| 0b00768b84 | |||
| d162c79373 | |||
| 05db352a0f | |||
| 5bf3d7fe02 | |||
| 1ae1cbebf4 | |||
| 8c16dfc478 | |||
| c6a3286e4c | |||
| 44cfd7e5b0 | |||
| 210d4c5058 | |||
| 6b39d616b1 | |||
| 32ace1bece | |||
| 54f893b84f | |||
| b5685ec072 | |||
| 5854833240 | |||
| 4b2437a6f1 | |||
| 2981ac7b10 | |||
| 59a51c859a | |||
| 47bab6c182 | |||
| 4e6714fffe | |||
| aa6b595545 | |||
| 0131b1f6cc | |||
| 9f53c359dd | |||
| 28e4dba3e8 | |||
| 2afd46e1df | |||
| f5991b19be | |||
| 5cc75cb25c | |||
| 68c1df2d39 | |||
| c83724f45c | |||
| 5f91c150df | |||
| 0bfe999442 | |||
| 58440b16c4 | |||
| 57757a2ff5 | |||
| 2993f506a7 | |||
| e4841d54a1 | |||
| 4f05dcec89 | |||
| ede6bcd31e | |||
| 728c8e994d | |||
| 5290b64415 | |||
| fec6de1ba2 | |||
| 69678dcfa6 | |||
| 4911a243ff | |||
| 70316b37da | |||
| 307cb94e3b | |||
| ace53a8fa5 | |||
| 0544dc3f83 | |||
| 708ff300a3 | |||
| 4e63f0f215 | |||
| 141481df3a | |||
| 29241cc287 | |||
| e81e97d404 | |||
| a5182e5c24 | |||
| cf5ff6e160 | |||
| f2b3a2ec91 | |||
| 69780c67a9 | |||
| ac9cf590bc | |||
| cb6edcb198 | |||
| 8eecc28c3c | |||
| 10b16bc36a | |||
| 2fe88cfea9 | |||
| caab396b56 | |||
| 5f0f4284a2 | |||
| c11be2284d | |||
| aa321196d7 | |||
| ff03db61a8 | |||
| f3b3ce6572 | |||
| 09b02e1aec | |||
| 451a9aaf01 | |||
| eaee7cb562 | |||
| a010c91a52 | |||
| 709194330f | |||
| 5914bbf173 | |||
| 5e9166f859 | |||
| 35b8ef6592 | |||
| 772a939f17 | |||
| 24971801cf | |||
| 43aebe8cb2 | |||
| 19cfc87c84 | |||
| f920f183c8 | |||
| 97f979c81e | |||
| e61411d396 | |||
| c4f985f542 | |||
| 302dee7ab2 | |||
| 83c12ad483 | |||
| 4224fd5c6f | |||
| 597ce1eb42 | |||
| 5ef385f0bb | |||
| cda4be3d47 | |||
| 8cdf22fc94 | |||
| 6efc7578ef | |||
| 4e2457560d | |||
| 2ddf122d27 | |||
| a24651437a | |||
| 30bb7acb17 | |||
| 7859145138 | |||
| 8a8aafec81 | |||
| deebdf2bcc | |||
| 4982c4abcb | |||
| 1486f90077 | |||
| f4988bc45e | |||
| 8abc9cc031 | |||
| 534689895c | |||
| 8a0dd6be24 | |||
| 65d2eed82d | |||
| e450e7b107 | |||
| 552ddda909 | |||
| bafeff7306 | |||
| 6791436302 | |||
| 7eda794070 | |||
| e3129c1067 | |||
| ff481ba6e7 | |||
| a106bad2db | |||
| 3a1c311d02 | |||
| 6465333f4f | |||
| b761659227 | |||
| 9321c355f8 | |||
| 86c8e79ea1 | |||
| 8916b1f8ab | |||
| 41fcf2aba6 | |||
| 87e72b08a9 | |||
| b2fcd42e3c | |||
| fc1b47a80f | |||
| af14e3502e | |||
| a2faa5ceb5 | |||
| 63a19a1381 | |||
| b472dcb7e7 | |||
| 6303909031 | |||
| 4bdc06865b | |||
| 2ee48cd039 | |||
| 893d5f452b | |||
| 340a9bc8ee | |||
| cb3d9f83f1 | |||
| 4ba55aa8e9 | |||
| bab6f501ec | |||
| 7327939684 | |||
| ffb0135f06 | |||
| ee0ddc3d17 | |||
| 5dd979d66c | |||
| a9bd34f3c5 | |||
| db316b59c5 | |||
| 6209714f87 | |||
| 1ed2bddba7 | |||
| 26b35c9b7b | |||
| 86a9271f75 | |||
| 402ed9bd20 | |||
| 68a0684569 | |||
| bd2e453218 | |||
| 1f31c63e57 | |||
| 480410efa2 | |||
| e9bfee52ed | |||
| 326b574d54 | |||
| 0a7abcf2ad | |||
| 9e5019881e | |||
| 8071750681 | |||
| f2f0931904 | |||
| a91204e5b9 | |||
| b14c22cbff | |||
| b3e40c6aed | |||
| 873aa4bb22 | |||
| c1ea78c422 | |||
| 3c8bbc2621 | |||
| 42a9979d91 | |||
| b7f94df4d9 | |||
| 4143d3fe28 | |||
| f95c06b76f | |||
| e3e9178ccc | |||
| b694816e7b | |||
| e046000f36 | |||
| edb5caae9b | |||
| 02d27651f3 | |||
| 44cd4d847d | |||
| 472256794d | |||
| cbb6887983 | |||
| 317e9ec605 | |||
| ada2a16412 | |||
| 61f6b0f122 | |||
| 6a3f7e45cf | |||
| 2b78c4ba86 | |||
| 680ef641fb | |||
| 6c23fc4b2b | 
| @ -1,5 +1,5 @@ | |||||||
| [bumpversion] | [bumpversion] | ||||||
| current_version = 2021.12.1-rc2 | current_version = 2021.12.2 | ||||||
| tag = True | tag = True | ||||||
| commit = True | commit = True | ||||||
| parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)\-?(?P<release>.*) | parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)\-?(?P<release>.*) | ||||||
|  | |||||||
							
								
								
									
										1
									
								
								.github/stale.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/stale.yml
									
									
									
									
										vendored
									
									
								
							| @ -7,6 +7,7 @@ exemptLabels: | |||||||
|   - pinned |   - pinned | ||||||
|   - security |   - security | ||||||
|   - pr_wanted |   - pr_wanted | ||||||
|  |   - enhancement/confirmed | ||||||
| # Comment to post when marking an issue as stale. Set to `false` to disable | # Comment to post when marking an issue as stale. Set to `false` to disable | ||||||
| markComment: > | markComment: > | ||||||
|   This issue has been automatically marked as stale because it has not had |   This issue has been automatically marked as stale because it has not had | ||||||
|  | |||||||
							
								
								
									
										72
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										72
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							| @ -89,9 +89,11 @@ jobs: | |||||||
|         run: | |         run: | | ||||||
|           # Copy current, latest config to local |           # Copy current, latest config to local | ||||||
|           cp authentik/lib/default.yml local.env.yml |           cp authentik/lib/default.yml local.env.yml | ||||||
|  |           cp -R .github .. | ||||||
|  |           cp -R scripts .. | ||||||
|           git checkout $(git describe --abbrev=0 --match 'version/*') |           git checkout $(git describe --abbrev=0 --match 'version/*') | ||||||
|           git checkout $GITHUB_HEAD_REF -- .github |           rm -rf .github/ scripts/ | ||||||
|           git checkout $GITHUB_HEAD_REF -- scripts |           mv ../.github ../scripts . | ||||||
|       - name: prepare |       - name: prepare | ||||||
|         env: |         env: | ||||||
|           INSTALL: ${{ steps.cache-pipenv.outputs.cache-hit }} |           INSTALL: ${{ steps.cache-pipenv.outputs.cache-hit }} | ||||||
| @ -105,6 +107,7 @@ jobs: | |||||||
|         run: | |         run: | | ||||||
|           set -x |           set -x | ||||||
|           git fetch |           git fetch | ||||||
|  |           git reset --hard HEAD | ||||||
|           git checkout $GITHUB_HEAD_REF |           git checkout $GITHUB_HEAD_REF | ||||||
|           pipenv sync --dev |           pipenv sync --dev | ||||||
|       - name: prepare |       - name: prepare | ||||||
| @ -173,7 +176,7 @@ jobs: | |||||||
|           testspace [integration]unittest.xml --link=codecov |           testspace [integration]unittest.xml --link=codecov | ||||||
|       - if: ${{ always() }} |       - if: ${{ always() }} | ||||||
|         uses: codecov/codecov-action@v2 |         uses: codecov/codecov-action@v2 | ||||||
|   test-e2e: |   test-e2e-provider: | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
|     steps: |     steps: | ||||||
|       - uses: actions/checkout@v2 |       - uses: actions/checkout@v2 | ||||||
| @ -212,22 +215,75 @@ jobs: | |||||||
|           npm run build |           npm run build | ||||||
|       - name: run e2e |       - name: run e2e | ||||||
|         run: | |         run: | | ||||||
|           pipenv run make test-e2e |           pipenv run make test-e2e-provider | ||||||
|           pipenv run coverage xml |           pipenv run coverage xml | ||||||
|       - name: run testspace |       - name: run testspace | ||||||
|         if: ${{ always() }} |         if: ${{ always() }} | ||||||
|         run: | |         run: | | ||||||
|           testspace [e2e]unittest.xml --link=codecov |           testspace [e2e-provider]unittest.xml --link=codecov | ||||||
|       - if: ${{ always() }} |       - if: ${{ always() }} | ||||||
|         uses: codecov/codecov-action@v2 |         uses: codecov/codecov-action@v2 | ||||||
|   build: |   test-e2e-rest: | ||||||
|  |     runs-on: ubuntu-latest | ||||||
|  |     steps: | ||||||
|  |       - uses: actions/checkout@v2 | ||||||
|  |       - uses: actions/setup-python@v2 | ||||||
|  |         with: | ||||||
|  |           python-version: '3.9' | ||||||
|  |       - uses: actions/setup-node@v2 | ||||||
|  |         with: | ||||||
|  |           node-version: '16' | ||||||
|  |           cache: 'npm' | ||||||
|  |           cache-dependency-path: web/package-lock.json | ||||||
|  |       - uses: testspace-com/setup-testspace@v1 | ||||||
|  |         with: | ||||||
|  |           domain: ${{github.repository_owner}} | ||||||
|  |       - id: cache-pipenv | ||||||
|  |         uses: actions/cache@v2.1.7 | ||||||
|  |         with: | ||||||
|  |           path: ~/.local/share/virtualenvs | ||||||
|  |           key: ${{ runner.os }}-pipenv-v2-${{ hashFiles('**/Pipfile.lock') }} | ||||||
|  |       - name: prepare | ||||||
|  |         env: | ||||||
|  |           INSTALL: ${{ steps.cache-pipenv.outputs.cache-hit }} | ||||||
|  |         run: | | ||||||
|  |           scripts/ci_prepare.sh | ||||||
|  |           docker-compose -f tests/e2e/docker-compose.yml up -d | ||||||
|  |       - id: cache-web | ||||||
|  |         uses: actions/cache@v2.1.7 | ||||||
|  |         with: | ||||||
|  |           path: web/dist | ||||||
|  |           key: ${{ runner.os }}-web-${{ hashFiles('web/package-lock.json', 'web/**') }} | ||||||
|  |       - name: prepare web ui | ||||||
|  |         if: steps.cache-web.outputs.cache-hit != 'true' | ||||||
|  |         run: | | ||||||
|  |           cd web | ||||||
|  |           npm i | ||||||
|  |           npm run build | ||||||
|  |       - name: run e2e | ||||||
|  |         run: | | ||||||
|  |           pipenv run make test-e2e-rest | ||||||
|  |           pipenv run coverage xml | ||||||
|  |       - name: run testspace | ||||||
|  |         if: ${{ always() }} | ||||||
|  |         run: | | ||||||
|  |           testspace [e2e-rest]unittest.xml --link=codecov | ||||||
|  |       - if: ${{ always() }} | ||||||
|  |         uses: codecov/codecov-action@v2 | ||||||
|  |   ci-core-mark: | ||||||
|     needs: |     needs: | ||||||
|       - lint |       - lint | ||||||
|       - test-migrations |       - test-migrations | ||||||
|       - test-migrations-from-stable |       - test-migrations-from-stable | ||||||
|       - test-unittest |       - test-unittest | ||||||
|       - test-integration |       - test-integration | ||||||
|       - test-e2e |       - test-e2e-rest | ||||||
|  |       - test-e2e-provider | ||||||
|  |     runs-on: ubuntu-latest | ||||||
|  |     steps: | ||||||
|  |       - run: echo mark | ||||||
|  |   build: | ||||||
|  |     needs: ci-core-mark | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
|     timeout-minutes: 120 |     timeout-minutes: 120 | ||||||
|     strategy: |     strategy: | ||||||
| @ -244,7 +300,7 @@ jobs: | |||||||
|       - name: prepare variables |       - name: prepare variables | ||||||
|         id: ev |         id: ev | ||||||
|         env: |         env: | ||||||
|           DOCKER_USERNAME: ${{ secrets.HARBOR_USERNAME }} |           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} | ||||||
|         run: | |         run: | | ||||||
|           python ./scripts/gh_env.py |           python ./scripts/gh_env.py | ||||||
|       - name: Login to Container Registry |       - name: Login to Container Registry | ||||||
|  | |||||||
							
								
								
									
										50
									
								
								.github/workflows/ci-outpost.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										50
									
								
								.github/workflows/ci-outpost.yml
									
									
									
									
										vendored
									
									
								
							| @ -17,7 +17,7 @@ jobs: | |||||||
|       - uses: actions/checkout@v2 |       - uses: actions/checkout@v2 | ||||||
|       - uses: actions/setup-go@v2 |       - uses: actions/setup-go@v2 | ||||||
|         with: |         with: | ||||||
|           go-version: '^1.16.3' |           go-version: "^1.17" | ||||||
|       - name: Run linter |       - name: Run linter | ||||||
|         run: | |         run: | | ||||||
|           # Create folder structure for go embeds |           # Create folder structure for go embeds | ||||||
| @ -30,10 +30,16 @@ jobs: | |||||||
|             -w /app \ |             -w /app \ | ||||||
|             golangci/golangci-lint:v1.39.0 \ |             golangci/golangci-lint:v1.39.0 \ | ||||||
|             golangci-lint run -v --timeout 200s |             golangci-lint run -v --timeout 200s | ||||||
|  |   ci-outpost-mark: | ||||||
|  |     needs: | ||||||
|  |       - lint-golint | ||||||
|  |     runs-on: ubuntu-latest | ||||||
|  |     steps: | ||||||
|  |       - run: echo mark | ||||||
|   build: |   build: | ||||||
|     timeout-minutes: 120 |     timeout-minutes: 120 | ||||||
|     needs: |     needs: | ||||||
|       - lint-golint |       - ci-outpost-mark | ||||||
|     strategy: |     strategy: | ||||||
|       fail-fast: false |       fail-fast: false | ||||||
|       matrix: |       matrix: | ||||||
| @ -52,7 +58,7 @@ jobs: | |||||||
|       - name: prepare variables |       - name: prepare variables | ||||||
|         id: ev |         id: ev | ||||||
|         env: |         env: | ||||||
|           DOCKER_USERNAME: ${{ secrets.HARBOR_USERNAME }} |           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} | ||||||
|         run: | |         run: | | ||||||
|           python ./scripts/gh_env.py |           python ./scripts/gh_env.py | ||||||
|       - name: Login to Container Registry |       - name: Login to Container Registry | ||||||
| @ -74,3 +80,41 @@ jobs: | |||||||
|           build-args: | |           build-args: | | ||||||
|             GIT_BUILD_HASH=${{ steps.ev.outputs.sha }} |             GIT_BUILD_HASH=${{ steps.ev.outputs.sha }} | ||||||
|           platforms: ${{ matrix.arch }} |           platforms: ${{ matrix.arch }} | ||||||
|  |   build-outpost-binary: | ||||||
|  |     timeout-minutes: 120 | ||||||
|  |     needs: | ||||||
|  |       - ci-outpost-mark | ||||||
|  |     runs-on: ubuntu-latest | ||||||
|  |     strategy: | ||||||
|  |       fail-fast: false | ||||||
|  |       matrix: | ||||||
|  |         type: | ||||||
|  |           - proxy | ||||||
|  |           - ldap | ||||||
|  |         goos: [linux] | ||||||
|  |         goarch: [amd64, arm64] | ||||||
|  |     steps: | ||||||
|  |       - uses: actions/checkout@v2 | ||||||
|  |       - uses: actions/setup-go@v2 | ||||||
|  |         with: | ||||||
|  |           go-version: "^1.17" | ||||||
|  |       - uses: actions/setup-node@v2 | ||||||
|  |         with: | ||||||
|  |           node-version: '16' | ||||||
|  |           cache: 'npm' | ||||||
|  |           cache-dependency-path: web/package-lock.json | ||||||
|  |       - name: Build web | ||||||
|  |         run: | | ||||||
|  |           cd web | ||||||
|  |           npm install | ||||||
|  |           npm run build-proxy | ||||||
|  |       - name: Build outpost | ||||||
|  |         run: | | ||||||
|  |           set -x | ||||||
|  |           export GOOS=${{ matrix.goos }} | ||||||
|  |           export GOARCH=${{ matrix.goarch }} | ||||||
|  |           go build -tags=outpost_static_embed -v -o ./authentik-outpost-${{ matrix.type }}_${{ matrix.goos }}_${{ matrix.goarch }} ./cmd/${{ matrix.type }} | ||||||
|  |       - uses: actions/upload-artifact@v2 | ||||||
|  |         with: | ||||||
|  |           name: authentik-outpost-${{ matrix.type }}_${{ matrix.goos }}_${{ matrix.goarch }} | ||||||
|  |           path: ./authentik-outpost-${{ matrix.type }}_${{ matrix.goos }}_${{ matrix.goarch }} | ||||||
|  | |||||||
							
								
								
									
										8
									
								
								.github/workflows/ci-web.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/ci-web.yml
									
									
									
									
										vendored
									
									
								
							| @ -65,12 +65,18 @@ jobs: | |||||||
|         run: | |         run: | | ||||||
|           cd web |           cd web | ||||||
|           npm run lit-analyse |           npm run lit-analyse | ||||||
|   build: |   ci-web-mark: | ||||||
|     needs: |     needs: | ||||||
|       - lint-eslint |       - lint-eslint | ||||||
|       - lint-prettier |       - lint-prettier | ||||||
|       - lint-lit-analyse |       - lint-lit-analyse | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
|  |     steps: | ||||||
|  |       - run: echo mark | ||||||
|  |   build: | ||||||
|  |     needs: | ||||||
|  |       - ci-web-mark | ||||||
|  |     runs-on: ubuntu-latest | ||||||
|     steps: |     steps: | ||||||
|       - uses: actions/checkout@v2 |       - uses: actions/checkout@v2 | ||||||
|       - uses: actions/setup-node@v2 |       - uses: actions/setup-node@v2 | ||||||
|  | |||||||
							
								
								
									
										69
									
								
								.github/workflows/release-publish.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										69
									
								
								.github/workflows/release-publish.yml
									
									
									
									
										vendored
									
									
								
							| @ -30,14 +30,14 @@ jobs: | |||||||
|         with: |         with: | ||||||
|           push: ${{ github.event_name == 'release' }} |           push: ${{ github.event_name == 'release' }} | ||||||
|           tags: | |           tags: | | ||||||
|             beryju/authentik:2021.12.1-rc2, |             beryju/authentik:2021.12.2, | ||||||
|             beryju/authentik:latest, |             beryju/authentik:latest, | ||||||
|             ghcr.io/goauthentik/server:2021.12.1-rc2, |             ghcr.io/goauthentik/server:2021.12.2, | ||||||
|             ghcr.io/goauthentik/server:latest |             ghcr.io/goauthentik/server:latest | ||||||
|           platforms: linux/amd64,linux/arm64 |           platforms: linux/amd64,linux/arm64 | ||||||
|           context: . |           context: . | ||||||
|       - name: Building Docker Image (stable) |       - name: Building Docker Image (stable) | ||||||
|         if: ${{ github.event_name == 'release' && !contains('2021.12.1-rc2', 'rc') }} |         if: ${{ github.event_name == 'release' && !contains('2021.12.2', 'rc') }} | ||||||
|         run: | |         run: | | ||||||
|           docker pull beryju/authentik:latest |           docker pull beryju/authentik:latest | ||||||
|           docker tag beryju/authentik:latest beryju/authentik:stable |           docker tag beryju/authentik:latest beryju/authentik:stable | ||||||
| @ -57,7 +57,7 @@ jobs: | |||||||
|       - uses: actions/checkout@v2 |       - uses: actions/checkout@v2 | ||||||
|       - uses: actions/setup-go@v2 |       - uses: actions/setup-go@v2 | ||||||
|         with: |         with: | ||||||
|           go-version: "^1.15" |           go-version: "^1.17" | ||||||
|       - name: Set up QEMU |       - name: Set up QEMU | ||||||
|         uses: docker/setup-qemu-action@v1.2.0 |         uses: docker/setup-qemu-action@v1.2.0 | ||||||
|       - name: Set up Docker Buildx |       - name: Set up Docker Buildx | ||||||
| @ -78,14 +78,14 @@ jobs: | |||||||
|         with: |         with: | ||||||
|           push: ${{ github.event_name == 'release' }} |           push: ${{ github.event_name == 'release' }} | ||||||
|           tags: | |           tags: | | ||||||
|             beryju/authentik-${{ matrix.type }}:2021.12.1-rc2, |             beryju/authentik-${{ matrix.type }}:2021.12.2, | ||||||
|             beryju/authentik-${{ matrix.type }}:latest, |             beryju/authentik-${{ matrix.type }}:latest, | ||||||
|             ghcr.io/goauthentik/${{ matrix.type }}:2021.12.1-rc2, |             ghcr.io/goauthentik/${{ matrix.type }}:2021.12.2, | ||||||
|             ghcr.io/goauthentik/${{ matrix.type }}:latest |             ghcr.io/goauthentik/${{ matrix.type }}:latest | ||||||
|           file: ${{ matrix.type }}.Dockerfile |           file: ${{ matrix.type }}.Dockerfile | ||||||
|           platforms: linux/amd64,linux/arm64 |           platforms: linux/amd64,linux/arm64 | ||||||
|       - name: Building Docker Image (stable) |       - name: Building Docker Image (stable) | ||||||
|         if: ${{ github.event_name == 'release' && !contains('2021.12.1-rc2', 'rc') }} |         if: ${{ github.event_name == 'release' && !contains('2021.12.2', 'rc') }} | ||||||
|         run: | |         run: | | ||||||
|           docker pull beryju/authentik-${{ matrix.type }}:latest |           docker pull beryju/authentik-${{ matrix.type }}:latest | ||||||
|           docker tag beryju/authentik-${{ matrix.type }}:latest beryju/authentik-${{ matrix.type }}:stable |           docker tag beryju/authentik-${{ matrix.type }}:latest beryju/authentik-${{ matrix.type }}:stable | ||||||
| @ -93,10 +93,50 @@ jobs: | |||||||
|           docker pull ghcr.io/goauthentik/${{ matrix.type }}:latest |           docker pull ghcr.io/goauthentik/${{ matrix.type }}:latest | ||||||
|           docker tag ghcr.io/goauthentik/${{ matrix.type }}:latest ghcr.io/goauthentik/${{ matrix.type }}:stable |           docker tag ghcr.io/goauthentik/${{ matrix.type }}:latest ghcr.io/goauthentik/${{ matrix.type }}:stable | ||||||
|           docker push ghcr.io/goauthentik/${{ matrix.type }}:stable |           docker push ghcr.io/goauthentik/${{ matrix.type }}:stable | ||||||
|  |   build-outpost-binary: | ||||||
|  |     timeout-minutes: 120 | ||||||
|  |     runs-on: ubuntu-latest | ||||||
|  |     strategy: | ||||||
|  |       fail-fast: false | ||||||
|  |       matrix: | ||||||
|  |         type: | ||||||
|  |           - proxy | ||||||
|  |           - ldap | ||||||
|  |         goos: [linux, darwin] | ||||||
|  |         goarch: [amd64, arm64] | ||||||
|  |     steps: | ||||||
|  |       - uses: actions/checkout@v2 | ||||||
|  |       - uses: actions/setup-go@v2 | ||||||
|  |         with: | ||||||
|  |           go-version: "^1.17" | ||||||
|  |       - uses: actions/setup-node@v2 | ||||||
|  |         with: | ||||||
|  |           node-version: '16' | ||||||
|  |           cache: 'npm' | ||||||
|  |           cache-dependency-path: web/package-lock.json | ||||||
|  |       - name: Build web | ||||||
|  |         run: | | ||||||
|  |           cd web | ||||||
|  |           npm install | ||||||
|  |           npm run build-proxy | ||||||
|  |       - name: Build outpost | ||||||
|  |         run: | | ||||||
|  |           set -x | ||||||
|  |           export GOOS=${{ matrix.goos }} | ||||||
|  |           export GOARCH=${{ matrix.goarch }} | ||||||
|  |           go build -tags=outpost_static_embed -v -o ./authentik-outpost-${{ matrix.type }}_${{ matrix.goos }}_${{ matrix.goarch }} ./cmd/${{ matrix.type }} | ||||||
|  |       - name: Upload binaries to release | ||||||
|  |         uses: svenstaro/upload-release-action@v2 | ||||||
|  |         with: | ||||||
|  |           repo_token: ${{ secrets.GITHUB_TOKEN }} | ||||||
|  |           file: ./authentik-outpost-${{ matrix.type }}_${{ matrix.goos }}_${{ matrix.goarch }} | ||||||
|  |           asset_name: authentik-outpost-${{ matrix.type }}_${{ matrix.goos }}_${{ matrix.goarch }} | ||||||
|  |           tag: ${{ github.ref }} | ||||||
|   test-release: |   test-release: | ||||||
|     needs: |     needs: | ||||||
|       - build-server |       - build-server | ||||||
|       - build-outpost |       - build-outpost | ||||||
|  |       - build-outpost-binary | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
|     steps: |     steps: | ||||||
|       - uses: actions/checkout@v2 |       - uses: actions/checkout@v2 | ||||||
| @ -114,16 +154,11 @@ jobs: | |||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
|     steps: |     steps: | ||||||
|       - uses: actions/checkout@v2 |       - uses: actions/checkout@v2 | ||||||
|       - name: Setup Node.js environment |       - name: Get static files from docker image | ||||||
|         uses: actions/setup-node@v2 |  | ||||||
|         with: |  | ||||||
|           node-version: '16' |  | ||||||
|       - name: Build web api client and web ui |  | ||||||
|         run: | |         run: | | ||||||
|           export NODE_ENV=production |           docker pull ghcr.io/goauthentik/server:latest | ||||||
|           cd web |           container=$(docker container create ghcr.io/goauthentik/server:latest) | ||||||
|           npm i |           docker cp ${container}:web/ . | ||||||
|           npm run build |  | ||||||
|       - name: Create a Sentry.io release |       - name: Create a Sentry.io release | ||||||
|         uses: getsentry/action-release@v1 |         uses: getsentry/action-release@v1 | ||||||
|         if: ${{ github.event_name == 'release' }} |         if: ${{ github.event_name == 'release' }} | ||||||
| @ -133,7 +168,7 @@ jobs: | |||||||
|           SENTRY_PROJECT: authentik |           SENTRY_PROJECT: authentik | ||||||
|           SENTRY_URL: https://sentry.beryju.org |           SENTRY_URL: https://sentry.beryju.org | ||||||
|         with: |         with: | ||||||
|           version: authentik@2021.12.1-rc2 |           version: authentik@2021.12.2 | ||||||
|           environment: beryjuorg-prod |           environment: beryjuorg-prod | ||||||
|           sourcemaps: './web/dist' |           sourcemaps: './web/dist' | ||||||
|           url_prefix: '~/static/dist' |           url_prefix: '~/static/dist' | ||||||
|  | |||||||
							
								
								
									
										1
									
								
								.python-version
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								.python-version
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1 @@ | |||||||
|  | 3.9.7 | ||||||
							
								
								
									
										10
									
								
								Dockerfile
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								Dockerfile
									
									
									
									
									
								
							| @ -1,5 +1,5 @@ | |||||||
| # Stage 1: Lock python dependencies | # Stage 1: Lock python dependencies | ||||||
| FROM docker.io/python:3.9-slim-bullseye as locker | FROM docker.io/python:3.10.1-slim-bullseye as locker | ||||||
|  |  | ||||||
| COPY ./Pipfile /app/ | COPY ./Pipfile /app/ | ||||||
| COPY ./Pipfile.lock /app/ | COPY ./Pipfile.lock /app/ | ||||||
| @ -28,7 +28,7 @@ ENV NODE_ENV=production | |||||||
| RUN cd /work/web && npm i && npm run build | RUN cd /work/web && npm i && npm run build | ||||||
|  |  | ||||||
| # Stage 4: Build go proxy | # Stage 4: Build go proxy | ||||||
| FROM docker.io/golang:1.17.3-bullseye AS builder | FROM docker.io/golang:1.17.5-bullseye AS builder | ||||||
|  |  | ||||||
| WORKDIR /work | WORKDIR /work | ||||||
|  |  | ||||||
| @ -44,7 +44,7 @@ COPY ./go.sum /work/go.sum | |||||||
| RUN go build -o /work/authentik ./cmd/server/main.go | RUN go build -o /work/authentik ./cmd/server/main.go | ||||||
|  |  | ||||||
| # Stage 5: Run | # Stage 5: Run | ||||||
| FROM docker.io/python:3.9-slim-bullseye | FROM docker.io/python:3.10.1-slim-bullseye | ||||||
|  |  | ||||||
| WORKDIR / | WORKDIR / | ||||||
| COPY --from=locker /app/requirements.txt / | COPY --from=locker /app/requirements.txt / | ||||||
| @ -64,8 +64,8 @@ RUN apt-get update && \ | |||||||
|     apt-get clean && \ |     apt-get clean && \ | ||||||
|     rm -rf /tmp/* /var/lib/apt/lists/* /var/tmp/ && \ |     rm -rf /tmp/* /var/lib/apt/lists/* /var/tmp/ && \ | ||||||
|     adduser --system --no-create-home --uid 1000 --group --home /authentik authentik && \ |     adduser --system --no-create-home --uid 1000 --group --home /authentik authentik && \ | ||||||
|     mkdir /backups /certs && \ |     mkdir -p /backups /certs /media && \ | ||||||
|     chown authentik:authentik /backups /certs |     chown authentik:authentik /backups /certs /media | ||||||
|  |  | ||||||
| COPY ./authentik/ /authentik | COPY ./authentik/ /authentik | ||||||
| COPY ./pyproject.toml / | COPY ./pyproject.toml / | ||||||
|  | |||||||
							
								
								
									
										12
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								Makefile
									
									
									
									
									
								
							| @ -4,13 +4,16 @@ UID = $(shell id -u) | |||||||
| GID = $(shell id -g) | GID = $(shell id -g) | ||||||
| NPM_VERSION = $(shell python -m scripts.npm_version) | NPM_VERSION = $(shell python -m scripts.npm_version) | ||||||
|  |  | ||||||
| all: lint-fix lint test gen | all: lint-fix lint test gen web | ||||||
|  |  | ||||||
| test-integration: | test-integration: | ||||||
| 	coverage run manage.py test tests/integration | 	coverage run manage.py test tests/integration | ||||||
|  |  | ||||||
| test-e2e: | test-e2e-provider: | ||||||
| 	coverage run manage.py test tests/e2e | 	coverage run manage.py test tests/e2e/test_provider* | ||||||
|  |  | ||||||
|  | test-e2e-rest: | ||||||
|  | 	coverage run manage.py test tests/e2e/test_flows* tests/e2e/test_source* | ||||||
|  |  | ||||||
| test: | test: | ||||||
| 	coverage run manage.py test authentik | 	coverage run manage.py test authentik | ||||||
| @ -84,6 +87,9 @@ migrate: | |||||||
| run: | run: | ||||||
| 	go run -v cmd/server/main.go | 	go run -v cmd/server/main.go | ||||||
|  |  | ||||||
|  | web-watch: | ||||||
|  | 	cd web && npm run watch | ||||||
|  |  | ||||||
| web: web-lint-fix web-lint web-extract | web: web-lint-fix web-lint web-extract | ||||||
|  |  | ||||||
| web-lint-fix: | web-lint-fix: | ||||||
|  | |||||||
							
								
								
									
										4
									
								
								Pipfile
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								Pipfile
									
									
									
									
									
								
							| @ -39,7 +39,7 @@ pycryptodome = "*" | |||||||
| pyjwt = "*" | pyjwt = "*" | ||||||
| pyyaml = "*" | pyyaml = "*" | ||||||
| requests-oauthlib = "*" | requests-oauthlib = "*" | ||||||
| sentry-sdk = "*" | sentry-sdk = { git = 'https://github.com/beryju/sentry-python.git', ref = '379aee28b15d3b87b381317746c4efd24b3d7bc3' } | ||||||
| service_identity = "*" | service_identity = "*" | ||||||
| structlog = "*" | structlog = "*" | ||||||
| swagger-spec-validator = "*" | swagger-spec-validator = "*" | ||||||
| @ -49,6 +49,8 @@ urllib3 = {extras = ["secure"],version = "*"} | |||||||
| uvicorn = {extras = ["standard"],version = "*"} | uvicorn = {extras = ["standard"],version = "*"} | ||||||
| webauthn = "*" | webauthn = "*" | ||||||
| xmlsec = "*" | xmlsec = "*" | ||||||
|  | flower = "*" | ||||||
|  | wsproto = "*" | ||||||
|  |  | ||||||
| [dev-packages] | [dev-packages] | ||||||
| bandit = "*" | bandit = "*" | ||||||
|  | |||||||
							
								
								
									
										741
									
								
								Pipfile.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										741
									
								
								Pipfile.lock
									
									
									
										generated
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										20
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								README.md
									
									
									
									
									
								
							| @ -38,3 +38,23 @@ See [Development Documentation](https://goauthentik.io/developer-docs/?utm_sourc | |||||||
| ## Security | ## Security | ||||||
|  |  | ||||||
| See [SECURITY.md](SECURITY.md) | See [SECURITY.md](SECURITY.md) | ||||||
|  |  | ||||||
|  | ## Sponsors | ||||||
|  |  | ||||||
|  | This project is proudly sponsored by: | ||||||
|  |  | ||||||
|  | <p> | ||||||
|  |     <a href="https://www.digitalocean.com/?utm_medium=opensource&utm_source=goauthentik.io"> | ||||||
|  |         <img src="https://opensource.nyc3.cdn.digitaloceanspaces.com/attribution/assets/SVG/DO_Logo_horizontal_blue.svg" width="201px"> | ||||||
|  |     </a> | ||||||
|  | </p> | ||||||
|  |  | ||||||
|  | DigitalOcean provides development and testing resources for authentik. | ||||||
|  |  | ||||||
|  | <p> | ||||||
|  |     <a href="https://www.netlify.com"> | ||||||
|  |         <img src="https://www.netlify.com/img/global/badges/netlify-color-accent.svg" alt="Deploys by Netlify" /> | ||||||
|  |     </a> | ||||||
|  | </p> | ||||||
|  |  | ||||||
|  | Netlify hosts the [goauthentik.io](goauthentik.io) site. | ||||||
|  | |||||||
| @ -6,8 +6,8 @@ | |||||||
|  |  | ||||||
| | Version    | Supported          | | | Version    | Supported          | | ||||||
| | ---------- | ------------------ | | | ---------- | ------------------ | | ||||||
| | 2021.9.x   | :white_check_mark: | |  | ||||||
| | 2021.10.x  | :white_check_mark: | | | 2021.10.x  | :white_check_mark: | | ||||||
|  | | 2021.12.x  | :white_check_mark: | | ||||||
|  |  | ||||||
| ## Reporting a Vulnerability | ## Reporting a Vulnerability | ||||||
|  |  | ||||||
|  | |||||||
| @ -1,3 +1,3 @@ | |||||||
| """authentik""" | """authentik""" | ||||||
| __version__ = "2021.12.1-rc2" | __version__ = "2021.12.2" | ||||||
| ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ||||||
|  | |||||||
| @ -1,13 +1,6 @@ | |||||||
| """authentik administration metrics""" | """authentik administration metrics""" | ||||||
| import time |  | ||||||
| from collections import Counter |  | ||||||
| from datetime import timedelta |  | ||||||
|  |  | ||||||
| from django.db.models import Count, ExpressionWrapper, F |  | ||||||
| from django.db.models.fields import DurationField |  | ||||||
| from django.db.models.functions import ExtractHour |  | ||||||
| from django.utils.timezone import now |  | ||||||
| from drf_spectacular.utils import extend_schema, extend_schema_field | from drf_spectacular.utils import extend_schema, extend_schema_field | ||||||
|  | from guardian.shortcuts import get_objects_for_user | ||||||
| from rest_framework.fields import IntegerField, SerializerMethodField | from rest_framework.fields import IntegerField, SerializerMethodField | ||||||
| from rest_framework.permissions import IsAdminUser | from rest_framework.permissions import IsAdminUser | ||||||
| from rest_framework.request import Request | from rest_framework.request import Request | ||||||
| @ -15,31 +8,7 @@ from rest_framework.response import Response | |||||||
| from rest_framework.views import APIView | from rest_framework.views import APIView | ||||||
|  |  | ||||||
| from authentik.core.api.utils import PassiveSerializer | from authentik.core.api.utils import PassiveSerializer | ||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import EventAction | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_events_per_1h(**filter_kwargs) -> list[dict[str, int]]: |  | ||||||
|     """Get event count by hour in the last day, fill with zeros""" |  | ||||||
|     date_from = now() - timedelta(days=1) |  | ||||||
|     result = ( |  | ||||||
|         Event.objects.filter(created__gte=date_from, **filter_kwargs) |  | ||||||
|         .annotate(age=ExpressionWrapper(now() - F("created"), output_field=DurationField())) |  | ||||||
|         .annotate(age_hours=ExtractHour("age")) |  | ||||||
|         .values("age_hours") |  | ||||||
|         .annotate(count=Count("pk")) |  | ||||||
|         .order_by("age_hours") |  | ||||||
|     ) |  | ||||||
|     data = Counter({int(d["age_hours"]): d["count"] for d in result}) |  | ||||||
|     results = [] |  | ||||||
|     _now = now() |  | ||||||
|     for hour in range(0, -24, -1): |  | ||||||
|         results.append( |  | ||||||
|             { |  | ||||||
|                 "x_cord": time.mktime((_now + timedelta(hours=hour)).timetuple()) * 1000, |  | ||||||
|                 "y_cord": data[hour * -1], |  | ||||||
|             } |  | ||||||
|         ) |  | ||||||
|     return results |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class CoordinateSerializer(PassiveSerializer): | class CoordinateSerializer(PassiveSerializer): | ||||||
| @ -58,12 +27,22 @@ class LoginMetricsSerializer(PassiveSerializer): | |||||||
|     @extend_schema_field(CoordinateSerializer(many=True)) |     @extend_schema_field(CoordinateSerializer(many=True)) | ||||||
|     def get_logins_per_1h(self, _): |     def get_logins_per_1h(self, _): | ||||||
|         """Get successful logins per hour for the last 24 hours""" |         """Get successful logins per hour for the last 24 hours""" | ||||||
|         return get_events_per_1h(action=EventAction.LOGIN) |         user = self.context["user"] | ||||||
|  |         return ( | ||||||
|  |             get_objects_for_user(user, "authentik_events.view_event") | ||||||
|  |             .filter(action=EventAction.LOGIN) | ||||||
|  |             .get_events_per_hour() | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     @extend_schema_field(CoordinateSerializer(many=True)) |     @extend_schema_field(CoordinateSerializer(many=True)) | ||||||
|     def get_logins_failed_per_1h(self, _): |     def get_logins_failed_per_1h(self, _): | ||||||
|         """Get failed logins per hour for the last 24 hours""" |         """Get failed logins per hour for the last 24 hours""" | ||||||
|         return get_events_per_1h(action=EventAction.LOGIN_FAILED) |         user = self.context["user"] | ||||||
|  |         return ( | ||||||
|  |             get_objects_for_user(user, "authentik_events.view_event") | ||||||
|  |             .filter(action=EventAction.LOGIN_FAILED) | ||||||
|  |             .get_events_per_hour() | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class AdministrationMetricsViewSet(APIView): | class AdministrationMetricsViewSet(APIView): | ||||||
| @ -75,4 +54,5 @@ class AdministrationMetricsViewSet(APIView): | |||||||
|     def get(self, request: Request) -> Response: |     def get(self, request: Request) -> Response: | ||||||
|         """Login Metrics per 1h""" |         """Login Metrics per 1h""" | ||||||
|         serializer = LoginMetricsSerializer(True) |         serializer = LoginMetricsSerializer(True) | ||||||
|  |         serializer.context["user"] = request.user | ||||||
|         return Response(serializer.data) |         return Response(serializer.data) | ||||||
|  | |||||||
| @ -11,7 +11,12 @@ from structlog.stdlib import get_logger | |||||||
|  |  | ||||||
| from authentik import ENV_GIT_HASH_KEY, __version__ | from authentik import ENV_GIT_HASH_KEY, __version__ | ||||||
| from authentik.events.models import Event, EventAction, Notification | from authentik.events.models import Event, EventAction, Notification | ||||||
| from authentik.events.monitored_tasks import PrefilledMonitoredTask, TaskResult, TaskResultStatus | from authentik.events.monitored_tasks import ( | ||||||
|  |     MonitoredTask, | ||||||
|  |     TaskResult, | ||||||
|  |     TaskResultStatus, | ||||||
|  |     prefill_task, | ||||||
|  | ) | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
| from authentik.lib.utils.http import get_http_session | from authentik.lib.utils.http import get_http_session | ||||||
| from authentik.root.celery import CELERY_APP | from authentik.root.celery import CELERY_APP | ||||||
| @ -48,8 +53,9 @@ def clear_update_notifications(): | |||||||
|             notification.delete() |             notification.delete() | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(bind=True, base=PrefilledMonitoredTask) | @CELERY_APP.task(bind=True, base=MonitoredTask) | ||||||
| def update_latest_version(self: PrefilledMonitoredTask): | @prefill_task | ||||||
|  | def update_latest_version(self: MonitoredTask): | ||||||
|     """Update latest version info""" |     """Update latest version info""" | ||||||
|     if CONFIG.y_bool("disable_update_check"): |     if CONFIG.y_bool("disable_update_check"): | ||||||
|         cache.set(VERSION_CACHE_KEY, "0.0.0", VERSION_CACHE_TIMEOUT) |         cache.set(VERSION_CACHE_KEY, "0.0.0", VERSION_CACHE_TIMEOUT) | ||||||
|  | |||||||
| @ -5,6 +5,7 @@ from django.http.response import HttpResponseBadRequest | |||||||
| from django.shortcuts import get_object_or_404 | from django.shortcuts import get_object_or_404 | ||||||
| from drf_spectacular.types import OpenApiTypes | from drf_spectacular.types import OpenApiTypes | ||||||
| from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema | from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema | ||||||
|  | from guardian.shortcuts import get_objects_for_user | ||||||
| from rest_framework.decorators import action | from rest_framework.decorators import action | ||||||
| from rest_framework.fields import ReadOnlyField | from rest_framework.fields import ReadOnlyField | ||||||
| from rest_framework.parsers import MultiPartParser | from rest_framework.parsers import MultiPartParser | ||||||
| @ -15,7 +16,7 @@ from rest_framework.viewsets import ModelViewSet | |||||||
| from rest_framework_guardian.filters import ObjectPermissionsFilter | from rest_framework_guardian.filters import ObjectPermissionsFilter | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.admin.api.metrics import CoordinateSerializer, get_events_per_1h | from authentik.admin.api.metrics import CoordinateSerializer | ||||||
| from authentik.api.decorators import permission_required | from authentik.api.decorators import permission_required | ||||||
| from authentik.core.api.providers import ProviderSerializer | from authentik.core.api.providers import ProviderSerializer | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| @ -239,8 +240,10 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet): | |||||||
|         """Metrics for application logins""" |         """Metrics for application logins""" | ||||||
|         app = self.get_object() |         app = self.get_object() | ||||||
|         return Response( |         return Response( | ||||||
|             get_events_per_1h( |             get_objects_for_user(request.user, "authentik_events.view_event") | ||||||
|  |             .filter( | ||||||
|                 action=EventAction.AUTHORIZE_APPLICATION, |                 action=EventAction.AUTHORIZE_APPLICATION, | ||||||
|                 context__authorized_application__pk=app.pk.hex, |                 context__authorized_application__pk=app.pk.hex, | ||||||
|             ) |             ) | ||||||
|  |             .get_events_per_hour() | ||||||
|         ) |         ) | ||||||
|  | |||||||
| @ -1,9 +1,11 @@ | |||||||
| """Groups API Viewset""" | """Groups API Viewset""" | ||||||
|  | from json import loads | ||||||
|  |  | ||||||
| from django.db.models.query import QuerySet | from django.db.models.query import QuerySet | ||||||
| from django_filters.filters import ModelMultipleChoiceFilter | from django_filters.filters import CharFilter, ModelMultipleChoiceFilter | ||||||
| from django_filters.filterset import FilterSet | from django_filters.filterset import FilterSet | ||||||
| from rest_framework.fields import CharField, JSONField | from rest_framework.fields import CharField, JSONField | ||||||
| from rest_framework.serializers import ListSerializer, ModelSerializer | from rest_framework.serializers import ListSerializer, ModelSerializer, ValidationError | ||||||
| from rest_framework.viewsets import ModelViewSet | from rest_framework.viewsets import ModelViewSet | ||||||
| from rest_framework_guardian.filters import ObjectPermissionsFilter | from rest_framework_guardian.filters import ObjectPermissionsFilter | ||||||
|  |  | ||||||
| @ -62,6 +64,13 @@ class GroupSerializer(ModelSerializer): | |||||||
| class GroupFilter(FilterSet): | class GroupFilter(FilterSet): | ||||||
|     """Filter for groups""" |     """Filter for groups""" | ||||||
|  |  | ||||||
|  |     attributes = CharFilter( | ||||||
|  |         field_name="attributes", | ||||||
|  |         lookup_expr="", | ||||||
|  |         label="Attributes", | ||||||
|  |         method="filter_attributes", | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     members_by_username = ModelMultipleChoiceFilter( |     members_by_username = ModelMultipleChoiceFilter( | ||||||
|         field_name="users__username", |         field_name="users__username", | ||||||
|         to_field_name="username", |         to_field_name="username", | ||||||
| @ -72,10 +81,28 @@ class GroupFilter(FilterSet): | |||||||
|         queryset=User.objects.all(), |         queryset=User.objects.all(), | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |     # pylint: disable=unused-argument | ||||||
|  |     def filter_attributes(self, queryset, name, value): | ||||||
|  |         """Filter attributes by query args""" | ||||||
|  |         try: | ||||||
|  |             value = loads(value) | ||||||
|  |         except ValueError: | ||||||
|  |             raise ValidationError(detail="filter: failed to parse JSON") | ||||||
|  |         if not isinstance(value, dict): | ||||||
|  |             raise ValidationError(detail="filter: value must be key:value mapping") | ||||||
|  |         qs = {} | ||||||
|  |         for key, _value in value.items(): | ||||||
|  |             qs[f"attributes__{key}"] = _value | ||||||
|  |         try: | ||||||
|  |             _ = len(queryset.filter(**qs)) | ||||||
|  |             return queryset.filter(**qs) | ||||||
|  |         except ValueError: | ||||||
|  |             return queryset | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|  |  | ||||||
|         model = Group |         model = Group | ||||||
|         fields = ["name", "is_superuser", "members_by_pk", "members_by_username"] |         fields = ["name", "is_superuser", "members_by_pk", "attributes", "members_by_username"] | ||||||
|  |  | ||||||
|  |  | ||||||
| class GroupViewSet(UsedByMixin, ModelViewSet): | class GroupViewSet(UsedByMixin, ModelViewSet): | ||||||
|  | |||||||
| @ -104,14 +104,14 @@ class SourceViewSet( | |||||||
|         ) |         ) | ||||||
|         matching_sources: list[UserSettingSerializer] = [] |         matching_sources: list[UserSettingSerializer] = [] | ||||||
|         for source in _all_sources: |         for source in _all_sources: | ||||||
|             user_settings = source.ui_user_settings |             user_settings = source.ui_user_settings() | ||||||
|             if not user_settings: |             if not user_settings: | ||||||
|                 continue |                 continue | ||||||
|             policy_engine = PolicyEngine(source, request.user, request) |             policy_engine = PolicyEngine(source, request.user, request) | ||||||
|             policy_engine.build() |             policy_engine.build() | ||||||
|             if not policy_engine.passing: |             if not policy_engine.passing: | ||||||
|                 continue |                 continue | ||||||
|             source_settings = source.ui_user_settings |             source_settings = source.ui_user_settings() | ||||||
|             source_settings.initial_data["object_uid"] = source.slug |             source_settings.initial_data["object_uid"] = source.slug | ||||||
|             if not source_settings.is_valid(): |             if not source_settings.is_valid(): | ||||||
|                 LOGGER.warning(source_settings.errors) |                 LOGGER.warning(source_settings.errors) | ||||||
|  | |||||||
| @ -38,7 +38,7 @@ from rest_framework.viewsets import ModelViewSet | |||||||
| from rest_framework_guardian.filters import ObjectPermissionsFilter | from rest_framework_guardian.filters import ObjectPermissionsFilter | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.admin.api.metrics import CoordinateSerializer, get_events_per_1h | from authentik.admin.api.metrics import CoordinateSerializer | ||||||
| from authentik.api.decorators import permission_required | from authentik.api.decorators import permission_required | ||||||
| from authentik.core.api.groups import GroupSerializer | from authentik.core.api.groups import GroupSerializer | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| @ -184,19 +184,31 @@ class UserMetricsSerializer(PassiveSerializer): | |||||||
|     def get_logins_per_1h(self, _): |     def get_logins_per_1h(self, _): | ||||||
|         """Get successful logins per hour for the last 24 hours""" |         """Get successful logins per hour for the last 24 hours""" | ||||||
|         user = self.context["user"] |         user = self.context["user"] | ||||||
|         return get_events_per_1h(action=EventAction.LOGIN, user__pk=user.pk) |         return ( | ||||||
|  |             get_objects_for_user(user, "authentik_events.view_event") | ||||||
|  |             .filter(action=EventAction.LOGIN, user__pk=user.pk) | ||||||
|  |             .get_events_per_hour() | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     @extend_schema_field(CoordinateSerializer(many=True)) |     @extend_schema_field(CoordinateSerializer(many=True)) | ||||||
|     def get_logins_failed_per_1h(self, _): |     def get_logins_failed_per_1h(self, _): | ||||||
|         """Get failed logins per hour for the last 24 hours""" |         """Get failed logins per hour for the last 24 hours""" | ||||||
|         user = self.context["user"] |         user = self.context["user"] | ||||||
|         return get_events_per_1h(action=EventAction.LOGIN_FAILED, context__username=user.username) |         return ( | ||||||
|  |             get_objects_for_user(user, "authentik_events.view_event") | ||||||
|  |             .filter(action=EventAction.LOGIN_FAILED, context__username=user.username) | ||||||
|  |             .get_events_per_hour() | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     @extend_schema_field(CoordinateSerializer(many=True)) |     @extend_schema_field(CoordinateSerializer(many=True)) | ||||||
|     def get_authorizations_per_1h(self, _): |     def get_authorizations_per_1h(self, _): | ||||||
|         """Get failed logins per hour for the last 24 hours""" |         """Get failed logins per hour for the last 24 hours""" | ||||||
|         user = self.context["user"] |         user = self.context["user"] | ||||||
|         return get_events_per_1h(action=EventAction.AUTHORIZE_APPLICATION, user__pk=user.pk) |         return ( | ||||||
|  |             get_objects_for_user(user, "authentik_events.view_event") | ||||||
|  |             .filter(action=EventAction.AUTHORIZE_APPLICATION, user__pk=user.pk) | ||||||
|  |             .get_events_per_hour() | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class UsersFilter(FilterSet): | class UsersFilter(FilterSet): | ||||||
| @ -233,7 +245,11 @@ class UsersFilter(FilterSet): | |||||||
|         qs = {} |         qs = {} | ||||||
|         for key, _value in value.items(): |         for key, _value in value.items(): | ||||||
|             qs[f"attributes__{key}"] = _value |             qs[f"attributes__{key}"] = _value | ||||||
|         return queryset.filter(**qs) |         try: | ||||||
|  |             _ = len(queryset.filter(**qs)) | ||||||
|  |             return queryset.filter(**qs) | ||||||
|  |         except ValueError: | ||||||
|  |             return queryset | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         model = User |         model = User | ||||||
| @ -314,7 +330,7 @@ class UserViewSet(UsedByMixin, ModelViewSet): | |||||||
|                     name=username, |                     name=username, | ||||||
|                     attributes={USER_ATTRIBUTE_SA: True, USER_ATTRIBUTE_TOKEN_EXPIRING: False}, |                     attributes={USER_ATTRIBUTE_SA: True, USER_ATTRIBUTE_TOKEN_EXPIRING: False}, | ||||||
|                 ) |                 ) | ||||||
|                 if create_group: |                 if create_group and self.request.user.has_perm("authentik_core.add_group"): | ||||||
|                     group = Group.objects.create( |                     group = Group.objects.create( | ||||||
|                         name=username, |                         name=username, | ||||||
|                     ) |                     ) | ||||||
|  | |||||||
| @ -5,6 +5,7 @@ from typing import Callable | |||||||
| from uuid import uuid4 | from uuid import uuid4 | ||||||
|  |  | ||||||
| from django.http import HttpRequest, HttpResponse | from django.http import HttpRequest, HttpResponse | ||||||
|  | from sentry_sdk.api import set_tag | ||||||
|  |  | ||||||
| SESSION_IMPERSONATE_USER = "authentik_impersonate_user" | SESSION_IMPERSONATE_USER = "authentik_impersonate_user" | ||||||
| SESSION_IMPERSONATE_ORIGINAL_USER = "authentik_impersonate_original_user" | SESSION_IMPERSONATE_ORIGINAL_USER = "authentik_impersonate_original_user" | ||||||
| @ -50,6 +51,7 @@ class RequestIDMiddleware: | |||||||
|                 "request_id": request_id, |                 "request_id": request_id, | ||||||
|                 "host": request.get_host(), |                 "host": request.get_host(), | ||||||
|             } |             } | ||||||
|  |             set_tag("authentik.request_id", request_id) | ||||||
|         response = self.get_response(request) |         response = self.get_response(request) | ||||||
|         response[RESPONSE_HEADER_ID] = request.request_id |         response[RESPONSE_HEADER_ID] = request.request_id | ||||||
|         setattr(response, "ak_context", {}) |         setattr(response, "ak_context", {}) | ||||||
| @ -65,4 +67,6 @@ def structlog_add_request_id(logger: Logger, method_name: str, event_dict: dict) | |||||||
|     """If threadlocal has authentik defined, add request_id to log""" |     """If threadlocal has authentik defined, add request_id to log""" | ||||||
|     if hasattr(LOCAL, "authentik"): |     if hasattr(LOCAL, "authentik"): | ||||||
|         event_dict.update(LOCAL.authentik) |         event_dict.update(LOCAL.authentik) | ||||||
|  |     if hasattr(LOCAL, "authentik_task"): | ||||||
|  |         event_dict.update(LOCAL.authentik_task) | ||||||
|     return event_dict |     return event_dict | ||||||
|  | |||||||
| @ -25,7 +25,6 @@ from structlog.stdlib import get_logger | |||||||
| from authentik.core.exceptions import PropertyMappingExpressionException | from authentik.core.exceptions import PropertyMappingExpressionException | ||||||
| from authentik.core.signals import password_changed | from authentik.core.signals import password_changed | ||||||
| from authentik.core.types import UILoginButton, UserSettingSerializer | from authentik.core.types import UILoginButton, UserSettingSerializer | ||||||
| from authentik.flows.models import Flow |  | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.lib.models import CreatedUpdatedModel, DomainlessURLValidator, SerializerModel | from authentik.lib.models import CreatedUpdatedModel, DomainlessURLValidator, SerializerModel | ||||||
| @ -203,7 +202,7 @@ class Provider(SerializerModel): | |||||||
|     name = models.TextField() |     name = models.TextField() | ||||||
|  |  | ||||||
|     authorization_flow = models.ForeignKey( |     authorization_flow = models.ForeignKey( | ||||||
|         Flow, |         "authentik_flows.Flow", | ||||||
|         on_delete=models.CASCADE, |         on_delete=models.CASCADE, | ||||||
|         help_text=_("Flow used when authorizing this provider."), |         help_text=_("Flow used when authorizing this provider."), | ||||||
|         related_name="provider_authorization", |         related_name="provider_authorization", | ||||||
| @ -263,7 +262,7 @@ class Application(PolicyBindingModel): | |||||||
|         it is returned as-is""" |         it is returned as-is""" | ||||||
|         if not self.meta_icon: |         if not self.meta_icon: | ||||||
|             return None |             return None | ||||||
|         if self.meta_icon.name.startswith("http") or self.meta_icon.name.startswith("/static"): |         if "://" in self.meta_icon.name or self.meta_icon.name.startswith("/static"): | ||||||
|             return self.meta_icon.name |             return self.meta_icon.name | ||||||
|         return self.meta_icon.url |         return self.meta_icon.url | ||||||
|  |  | ||||||
| @ -279,7 +278,13 @@ class Application(PolicyBindingModel): | |||||||
|         """Get casted provider instance""" |         """Get casted provider instance""" | ||||||
|         if not self.provider: |         if not self.provider: | ||||||
|             return None |             return None | ||||||
|         return Provider.objects.get_subclass(pk=self.provider.pk) |         # if the Application class has been cache, self.provider is set | ||||||
|  |         # but doing a direct query lookup will fail. | ||||||
|  |         # In that case, just return None | ||||||
|  |         try: | ||||||
|  |             return Provider.objects.get_subclass(pk=self.provider.pk) | ||||||
|  |         except Provider.DoesNotExist: | ||||||
|  |             return None | ||||||
|  |  | ||||||
|     def __str__(self): |     def __str__(self): | ||||||
|         return self.name |         return self.name | ||||||
| @ -324,7 +329,7 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel): | |||||||
|     property_mappings = models.ManyToManyField("PropertyMapping", default=None, blank=True) |     property_mappings = models.ManyToManyField("PropertyMapping", default=None, blank=True) | ||||||
|  |  | ||||||
|     authentication_flow = models.ForeignKey( |     authentication_flow = models.ForeignKey( | ||||||
|         Flow, |         "authentik_flows.Flow", | ||||||
|         blank=True, |         blank=True, | ||||||
|         null=True, |         null=True, | ||||||
|         default=None, |         default=None, | ||||||
| @ -333,7 +338,7 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel): | |||||||
|         related_name="source_authentication", |         related_name="source_authentication", | ||||||
|     ) |     ) | ||||||
|     enrollment_flow = models.ForeignKey( |     enrollment_flow = models.ForeignKey( | ||||||
|         Flow, |         "authentik_flows.Flow", | ||||||
|         blank=True, |         blank=True, | ||||||
|         null=True, |         null=True, | ||||||
|         default=None, |         default=None, | ||||||
| @ -360,13 +365,11 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel): | |||||||
|         """Return component used to edit this object""" |         """Return component used to edit this object""" | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     @property |     def ui_login_button(self, request: HttpRequest) -> Optional[UILoginButton]: | ||||||
|     def ui_login_button(self) -> Optional[UILoginButton]: |  | ||||||
|         """If source uses a http-based flow, return UI Information about the login |         """If source uses a http-based flow, return UI Information about the login | ||||||
|         button. If source doesn't use http-based flow, return None.""" |         button. If source doesn't use http-based flow, return None.""" | ||||||
|         return None |         return None | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def ui_user_settings(self) -> Optional[UserSettingSerializer]: |     def ui_user_settings(self) -> Optional[UserSettingSerializer]: | ||||||
|         """Entrypoint to integrate with User settings. Can either return None if no |         """Entrypoint to integrate with User settings. Can either return None if no | ||||||
|         user settings are available, or UserSettingSerializer.""" |         user settings are available, or UserSettingSerializer.""" | ||||||
| @ -453,6 +456,14 @@ class Token(ManagedModel, ExpiringModel): | |||||||
|         """Handler which is called when this object is expired.""" |         """Handler which is called when this object is expired.""" | ||||||
|         from authentik.events.models import Event, EventAction |         from authentik.events.models import Event, EventAction | ||||||
|  |  | ||||||
|  |         if self.intent in [ | ||||||
|  |             TokenIntents.INTENT_RECOVERY, | ||||||
|  |             TokenIntents.INTENT_VERIFICATION, | ||||||
|  |             TokenIntents.INTENT_APP_PASSWORD, | ||||||
|  |         ]: | ||||||
|  |             super().expire_action(*args, **kwargs) | ||||||
|  |             return | ||||||
|  |  | ||||||
|         self.key = default_token_key() |         self.key = default_token_key() | ||||||
|         self.expires = default_token_duration() |         self.expires = default_token_duration() | ||||||
|         self.save(*args, **kwargs) |         self.save(*args, **kwargs) | ||||||
|  | |||||||
| @ -16,15 +16,21 @@ from kubernetes.config.incluster_config import SERVICE_HOST_ENV_NAME | |||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.core.models import AuthenticatedSession, ExpiringModel | from authentik.core.models import AuthenticatedSession, ExpiringModel | ||||||
| from authentik.events.monitored_tasks import PrefilledMonitoredTask, TaskResult, TaskResultStatus | from authentik.events.monitored_tasks import ( | ||||||
|  |     MonitoredTask, | ||||||
|  |     TaskResult, | ||||||
|  |     TaskResultStatus, | ||||||
|  |     prefill_task, | ||||||
|  | ) | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
| from authentik.root.celery import CELERY_APP | from authentik.root.celery import CELERY_APP | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(bind=True, base=PrefilledMonitoredTask) | @CELERY_APP.task(bind=True, base=MonitoredTask) | ||||||
| def clean_expired_models(self: PrefilledMonitoredTask): | @prefill_task | ||||||
|  | def clean_expired_models(self: MonitoredTask): | ||||||
|     """Remove expired objects""" |     """Remove expired objects""" | ||||||
|     messages = [] |     messages = [] | ||||||
|     for cls in ExpiringModel.__subclasses__(): |     for cls in ExpiringModel.__subclasses__(): | ||||||
| @ -62,8 +68,9 @@ def should_backup() -> bool: | |||||||
|     return True |     return True | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(bind=True, base=PrefilledMonitoredTask) | @CELERY_APP.task(bind=True, base=MonitoredTask) | ||||||
| def backup_database(self: PrefilledMonitoredTask):  # pragma: no cover | @prefill_task | ||||||
|  | def backup_database(self: MonitoredTask):  # pragma: no cover | ||||||
|     """Database backup""" |     """Database backup""" | ||||||
|     self.result_timeout_hours = 25 |     self.result_timeout_hours = 25 | ||||||
|     if not should_backup(): |     if not should_backup(): | ||||||
|  | |||||||
| @ -19,6 +19,7 @@ | |||||||
|         <script src="{% static 'dist/poly.js' %}" type="module"></script> |         <script src="{% static 'dist/poly.js' %}" type="module"></script> | ||||||
|         {% block head %} |         {% block head %} | ||||||
|         {% endblock %} |         {% endblock %} | ||||||
|  |         <meta name="sentry-trace" content="{{ sentry_trace }}" /> | ||||||
|     </head> |     </head> | ||||||
|     <body> |     <body> | ||||||
|         {% block body %} |         {% block body %} | ||||||
|  | |||||||
| @ -2,7 +2,7 @@ | |||||||
| from time import sleep | from time import sleep | ||||||
| from typing import Callable, Type | from typing import Callable, Type | ||||||
|  |  | ||||||
| from django.test import TestCase | from django.test import RequestFactory, TestCase | ||||||
| from django.utils.timezone import now | from django.utils.timezone import now | ||||||
| from guardian.shortcuts import get_anonymous_user | from guardian.shortcuts import get_anonymous_user | ||||||
|  |  | ||||||
| @ -30,6 +30,9 @@ class TestModels(TestCase): | |||||||
| def source_tester_factory(test_model: Type[Stage]) -> Callable: | def source_tester_factory(test_model: Type[Stage]) -> Callable: | ||||||
|     """Test source""" |     """Test source""" | ||||||
|  |  | ||||||
|  |     factory = RequestFactory() | ||||||
|  |     request = factory.get("/") | ||||||
|  |  | ||||||
|     def tester(self: TestModels): |     def tester(self: TestModels): | ||||||
|         model_class = None |         model_class = None | ||||||
|         if test_model._meta.abstract: |         if test_model._meta.abstract: | ||||||
| @ -38,8 +41,8 @@ def source_tester_factory(test_model: Type[Stage]) -> Callable: | |||||||
|             model_class = test_model() |             model_class = test_model() | ||||||
|         model_class.slug = "test" |         model_class.slug = "test" | ||||||
|         self.assertIsNotNone(model_class.component) |         self.assertIsNotNone(model_class.component) | ||||||
|         _ = model_class.ui_login_button |         _ = model_class.ui_login_button(request) | ||||||
|         _ = model_class.ui_user_settings |         _ = model_class.ui_user_settings() | ||||||
|  |  | ||||||
|     return tester |     return tester | ||||||
|  |  | ||||||
|  | |||||||
| @ -41,7 +41,7 @@ class TestPropertyMappingAPI(APITestCase): | |||||||
|         expr = "return True" |         expr = "return True" | ||||||
|         self.assertEqual(PropertyMappingSerializer().validate_expression(expr), expr) |         self.assertEqual(PropertyMappingSerializer().validate_expression(expr), expr) | ||||||
|         with self.assertRaises(ValidationError): |         with self.assertRaises(ValidationError): | ||||||
|             print(PropertyMappingSerializer().validate_expression("/")) |             PropertyMappingSerializer().validate_expression("/") | ||||||
|  |  | ||||||
|     def test_types(self): |     def test_types(self): | ||||||
|         """Test PropertyMappigns's types endpoint""" |         """Test PropertyMappigns's types endpoint""" | ||||||
|  | |||||||
| @ -54,7 +54,9 @@ class TestTokenAPI(APITestCase): | |||||||
|  |  | ||||||
|     def test_token_expire(self): |     def test_token_expire(self): | ||||||
|         """Test Token expire task""" |         """Test Token expire task""" | ||||||
|         token: Token = Token.objects.create(expires=now(), user=get_anonymous_user()) |         token: Token = Token.objects.create( | ||||||
|  |             expires=now(), user=get_anonymous_user(), intent=TokenIntents.INTENT_API | ||||||
|  |         ) | ||||||
|         key = token.key |         key = token.key | ||||||
|         clean_expired_models.delay().get() |         clean_expired_models.delay().get() | ||||||
|         token.refresh_from_db() |         token.refresh_from_db() | ||||||
|  | |||||||
| @ -192,7 +192,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet): | |||||||
|             secret=certificate, |             secret=certificate, | ||||||
|             type="certificate", |             type="certificate", | ||||||
|         ).from_http(request) |         ).from_http(request) | ||||||
|         if "download" in request._request.GET: |         if "download" in request.query_params: | ||||||
|             # Mime type from https://pki-tutorial.readthedocs.io/en/latest/mime.html |             # Mime type from https://pki-tutorial.readthedocs.io/en/latest/mime.html | ||||||
|             response = HttpResponse( |             response = HttpResponse( | ||||||
|                 certificate.certificate_data, content_type="application/x-pem-file" |                 certificate.certificate_data, content_type="application/x-pem-file" | ||||||
| @ -223,7 +223,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet): | |||||||
|             secret=certificate, |             secret=certificate, | ||||||
|             type="private_key", |             type="private_key", | ||||||
|         ).from_http(request) |         ).from_http(request) | ||||||
|         if "download" in request._request.GET: |         if "download" in request.query_params: | ||||||
|             # Mime type from https://pki-tutorial.readthedocs.io/en/latest/mime.html |             # Mime type from https://pki-tutorial.readthedocs.io/en/latest/mime.html | ||||||
|             response = HttpResponse(certificate.key_data, content_type="application/x-pem-file") |             response = HttpResponse(certificate.key_data, content_type="application/x-pem-file") | ||||||
|             response[ |             response[ | ||||||
|  | |||||||
| @ -11,10 +11,13 @@ from cryptography.hazmat.primitives.serialization import load_pem_private_key | |||||||
| from cryptography.x509 import Certificate, load_pem_x509_certificate | from cryptography.x509 import Certificate, load_pem_x509_certificate | ||||||
| from django.db import models | from django.db import models | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
|  | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.lib.models import CreatedUpdatedModel | from authentik.lib.models import CreatedUpdatedModel | ||||||
| from authentik.managed.models import ManagedModel | from authentik.managed.models import ManagedModel | ||||||
|  |  | ||||||
|  | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
| class CertificateKeyPair(ManagedModel, CreatedUpdatedModel): | class CertificateKeyPair(ManagedModel, CreatedUpdatedModel): | ||||||
|     """CertificateKeyPair that can be used for signing or encrypting if `key_data` |     """CertificateKeyPair that can be used for signing or encrypting if `key_data` | ||||||
| @ -62,7 +65,8 @@ class CertificateKeyPair(ManagedModel, CreatedUpdatedModel): | |||||||
|                     password=None, |                     password=None, | ||||||
|                     backend=default_backend(), |                     backend=default_backend(), | ||||||
|                 ) |                 ) | ||||||
|             except ValueError: |             except ValueError as exc: | ||||||
|  |                 LOGGER.warning(exc) | ||||||
|                 return None |                 return None | ||||||
|         return self._private_key |         return self._private_key | ||||||
|  |  | ||||||
|  | |||||||
| @ -2,11 +2,19 @@ | |||||||
| from glob import glob | from glob import glob | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
|  |  | ||||||
|  | from cryptography.hazmat.backends import default_backend | ||||||
|  | from cryptography.hazmat.primitives.serialization import load_pem_private_key | ||||||
|  | from cryptography.x509.base import load_pem_x509_certificate | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.crypto.models import CertificateKeyPair | from authentik.crypto.models import CertificateKeyPair | ||||||
| from authentik.events.monitored_tasks import PrefilledMonitoredTask, TaskResult, TaskResultStatus | from authentik.events.monitored_tasks import ( | ||||||
|  |     MonitoredTask, | ||||||
|  |     TaskResult, | ||||||
|  |     TaskResultStatus, | ||||||
|  |     prefill_task, | ||||||
|  | ) | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
| from authentik.root.celery import CELERY_APP | from authentik.root.celery import CELERY_APP | ||||||
|  |  | ||||||
| @ -15,8 +23,25 @@ LOGGER = get_logger() | |||||||
| MANAGED_DISCOVERED = "goauthentik.io/crypto/discovered/%s" | MANAGED_DISCOVERED = "goauthentik.io/crypto/discovered/%s" | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(bind=True, base=PrefilledMonitoredTask) | def ensure_private_key_valid(body: str): | ||||||
| def certificate_discovery(self: PrefilledMonitoredTask): |     """Attempt loading of an RSA Private key without password""" | ||||||
|  |     load_pem_private_key( | ||||||
|  |         str.encode("\n".join([x.strip() for x in body.split("\n")])), | ||||||
|  |         password=None, | ||||||
|  |         backend=default_backend(), | ||||||
|  |     ) | ||||||
|  |     return body | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def ensure_certificate_valid(body: str): | ||||||
|  |     """Attempt loading of a PEM-encoded certificate""" | ||||||
|  |     load_pem_x509_certificate(body.encode("utf-8"), default_backend()) | ||||||
|  |     return body | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @CELERY_APP.task(bind=True, base=MonitoredTask) | ||||||
|  | @prefill_task | ||||||
|  | def certificate_discovery(self: MonitoredTask): | ||||||
|     """Discover and update certificates form the filesystem""" |     """Discover and update certificates form the filesystem""" | ||||||
|     certs = {} |     certs = {} | ||||||
|     private_keys = {} |     private_keys = {} | ||||||
| @ -36,11 +61,11 @@ def certificate_discovery(self: PrefilledMonitoredTask): | |||||||
|             with open(path, "r+", encoding="utf-8") as _file: |             with open(path, "r+", encoding="utf-8") as _file: | ||||||
|                 body = _file.read() |                 body = _file.read() | ||||||
|                 if "BEGIN RSA PRIVATE KEY" in body: |                 if "BEGIN RSA PRIVATE KEY" in body: | ||||||
|                     private_keys[cert_name] = body |                     private_keys[cert_name] = ensure_private_key_valid(body) | ||||||
|                 else: |                 else: | ||||||
|                     certs[cert_name] = body |                     certs[cert_name] = ensure_certificate_valid(body) | ||||||
|         except OSError as exc: |         except (OSError, ValueError) as exc: | ||||||
|             LOGGER.warning("Failed to open file", exc=exc, file=path) |             LOGGER.warning("Failed to open file or invalid format", exc=exc, file=path) | ||||||
|         discovered += 1 |         discovered += 1 | ||||||
|     for name, cert_data in certs.items(): |     for name, cert_data in certs.items(): | ||||||
|         cert = CertificateKeyPair.objects.filter(managed=MANAGED_DISCOVERED % name).first() |         cert = CertificateKeyPair.objects.filter(managed=MANAGED_DISCOVERED % name).first() | ||||||
| @ -54,7 +79,7 @@ def certificate_discovery(self: PrefilledMonitoredTask): | |||||||
|             cert.certificate_data = cert_data |             cert.certificate_data = cert_data | ||||||
|             dirty = True |             dirty = True | ||||||
|         if name in private_keys: |         if name in private_keys: | ||||||
|             if cert.key_data == private_keys[name]: |             if cert.key_data != private_keys[name]: | ||||||
|                 cert.key_data = private_keys[name] |                 cert.key_data = private_keys[name] | ||||||
|                 dirty = True |                 dirty = True | ||||||
|         if dirty: |         if dirty: | ||||||
|  | |||||||
| @ -191,9 +191,12 @@ class TestCrypto(APITestCase): | |||||||
|             with CONFIG.patch("cert_discovery_dir", temp_dir): |             with CONFIG.patch("cert_discovery_dir", temp_dir): | ||||||
|                 # pyright: reportGeneralTypeIssues=false |                 # pyright: reportGeneralTypeIssues=false | ||||||
|                 certificate_discovery()  # pylint: disable=no-value-for-parameter |                 certificate_discovery()  # pylint: disable=no-value-for-parameter | ||||||
|         self.assertTrue( |         keypair: CertificateKeyPair = CertificateKeyPair.objects.filter( | ||||||
|             CertificateKeyPair.objects.filter(managed=MANAGED_DISCOVERED % "foo").exists() |             managed=MANAGED_DISCOVERED % "foo" | ||||||
|         ) |         ).first() | ||||||
|  |         self.assertIsNotNone(keypair) | ||||||
|  |         self.assertIsNotNone(keypair.certificate) | ||||||
|  |         self.assertIsNotNone(keypair.private_key) | ||||||
|         self.assertTrue( |         self.assertTrue( | ||||||
|             CertificateKeyPair.objects.filter(managed=MANAGED_DISCOVERED % "foo.bar").exists() |             CertificateKeyPair.objects.filter(managed=MANAGED_DISCOVERED % "foo.bar").exists() | ||||||
|         ) |         ) | ||||||
|  | |||||||
| @ -1,4 +1,6 @@ | |||||||
| """Events API Views""" | """Events API Views""" | ||||||
|  | from json import loads | ||||||
|  |  | ||||||
| import django_filters | import django_filters | ||||||
| from django.db.models.aggregates import Count | from django.db.models.aggregates import Count | ||||||
| from django.db.models.fields.json import KeyTextTransform | from django.db.models.fields.json import KeyTextTransform | ||||||
| @ -12,6 +14,7 @@ from rest_framework.response import Response | |||||||
| from rest_framework.serializers import ModelSerializer | from rest_framework.serializers import ModelSerializer | ||||||
| from rest_framework.viewsets import ModelViewSet | from rest_framework.viewsets import ModelViewSet | ||||||
|  |  | ||||||
|  | from authentik.admin.api.metrics import CoordinateSerializer | ||||||
| from authentik.core.api.utils import PassiveSerializer, TypeCreateSerializer | from authentik.core.api.utils import PassiveSerializer, TypeCreateSerializer | ||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
|  |  | ||||||
| @ -110,13 +113,20 @@ class EventViewSet(ModelViewSet): | |||||||
|     @extend_schema( |     @extend_schema( | ||||||
|         methods=["GET"], |         methods=["GET"], | ||||||
|         responses={200: EventTopPerUserSerializer(many=True)}, |         responses={200: EventTopPerUserSerializer(many=True)}, | ||||||
|  |         filters=[], | ||||||
|         parameters=[ |         parameters=[ | ||||||
|  |             OpenApiParameter( | ||||||
|  |                 "action", | ||||||
|  |                 type=OpenApiTypes.STR, | ||||||
|  |                 location=OpenApiParameter.QUERY, | ||||||
|  |                 required=False, | ||||||
|  |             ), | ||||||
|             OpenApiParameter( |             OpenApiParameter( | ||||||
|                 "top_n", |                 "top_n", | ||||||
|                 type=OpenApiTypes.INT, |                 type=OpenApiTypes.INT, | ||||||
|                 location=OpenApiParameter.QUERY, |                 location=OpenApiParameter.QUERY, | ||||||
|                 required=False, |                 required=False, | ||||||
|             ) |             ), | ||||||
|         ], |         ], | ||||||
|     ) |     ) | ||||||
|     @action(detail=False, methods=["GET"], pagination_class=None) |     @action(detail=False, methods=["GET"], pagination_class=None) | ||||||
| @ -137,6 +147,40 @@ class EventViewSet(ModelViewSet): | |||||||
|             .order_by("-counted_events")[:top_n] |             .order_by("-counted_events")[:top_n] | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     @extend_schema( | ||||||
|  |         methods=["GET"], | ||||||
|  |         responses={200: CoordinateSerializer(many=True)}, | ||||||
|  |         filters=[], | ||||||
|  |         parameters=[ | ||||||
|  |             OpenApiParameter( | ||||||
|  |                 "action", | ||||||
|  |                 type=OpenApiTypes.STR, | ||||||
|  |                 location=OpenApiParameter.QUERY, | ||||||
|  |                 required=False, | ||||||
|  |             ), | ||||||
|  |             OpenApiParameter( | ||||||
|  |                 "query", | ||||||
|  |                 type=OpenApiTypes.STR, | ||||||
|  |                 location=OpenApiParameter.QUERY, | ||||||
|  |                 required=False, | ||||||
|  |             ), | ||||||
|  |         ], | ||||||
|  |     ) | ||||||
|  |     @action(detail=False, methods=["GET"], pagination_class=None) | ||||||
|  |     def per_month(self, request: Request): | ||||||
|  |         """Get the count of events per month""" | ||||||
|  |         filtered_action = request.query_params.get("action", EventAction.LOGIN) | ||||||
|  |         try: | ||||||
|  |             query = loads(request.query_params.get("query", "{}")) | ||||||
|  |         except ValueError: | ||||||
|  |             return Response(status=400) | ||||||
|  |         return Response( | ||||||
|  |             get_objects_for_user(request.user, "authentik_events.view_event") | ||||||
|  |             .filter(action=filtered_action) | ||||||
|  |             .filter(**query) | ||||||
|  |             .get_events_per_day() | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     @extend_schema(responses={200: TypeCreateSerializer(many=True)}) |     @extend_schema(responses={200: TypeCreateSerializer(many=True)}) | ||||||
|     @action(detail=False, pagination_class=None, filter_backends=[]) |     @action(detail=False, pagination_class=None, filter_backends=[]) | ||||||
|     def actions(self, request: Request) -> Response: |     def actions(self, request: Request) -> Response: | ||||||
|  | |||||||
| @ -7,6 +7,7 @@ from typing import Optional, TypedDict | |||||||
| from geoip2.database import Reader | from geoip2.database import Reader | ||||||
| from geoip2.errors import GeoIP2Error | from geoip2.errors import GeoIP2Error | ||||||
| from geoip2.models import City | from geoip2.models import City | ||||||
|  | from sentry_sdk.hub import Hub | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
| @ -62,13 +63,17 @@ class GeoIPReader: | |||||||
|  |  | ||||||
|     def city(self, ip_address: str) -> Optional[City]: |     def city(self, ip_address: str) -> Optional[City]: | ||||||
|         """Wrapper for Reader.city""" |         """Wrapper for Reader.city""" | ||||||
|         if not self.enabled: |         with Hub.current.start_span( | ||||||
|             return None |             op="authentik.events.geo.city", | ||||||
|         self.__check_expired() |             description=ip_address, | ||||||
|         try: |         ): | ||||||
|             return self.__reader.city(ip_address) |             if not self.enabled: | ||||||
|         except (GeoIP2Error, ValueError): |                 return None | ||||||
|             return None |             self.__check_expired() | ||||||
|  |             try: | ||||||
|  |                 return self.__reader.city(ip_address) | ||||||
|  |             except (GeoIP2Error, ValueError): | ||||||
|  |                 return None | ||||||
|  |  | ||||||
|     def city_dict(self, ip_address: str) -> Optional[GeoIPDict]: |     def city_dict(self, ip_address: str) -> Optional[GeoIPDict]: | ||||||
|         """Wrapper for self.city that returns a dict""" |         """Wrapper for self.city that returns a dict""" | ||||||
|  | |||||||
| @ -314,169 +314,10 @@ class Migration(migrations.Migration): | |||||||
|             old_name="user_json", |             old_name="user_json", | ||||||
|             new_name="user", |             new_name="user", | ||||||
|         ), |         ), | ||||||
|         migrations.AlterField( |  | ||||||
|             model_name="event", |  | ||||||
|             name="action", |  | ||||||
|             field=models.TextField( |  | ||||||
|                 choices=[ |  | ||||||
|                     ("login", "Login"), |  | ||||||
|                     ("login_failed", "Login Failed"), |  | ||||||
|                     ("logout", "Logout"), |  | ||||||
|                     ("sign_up", "Sign Up"), |  | ||||||
|                     ("authorize_application", "Authorize Application"), |  | ||||||
|                     ("suspicious_request", "Suspicious Request"), |  | ||||||
|                     ("password_set", "Password Set"), |  | ||||||
|                     ("invitation_created", "Invite Created"), |  | ||||||
|                     ("invitation_used", "Invite Used"), |  | ||||||
|                     ("source_linked", "Source Linked"), |  | ||||||
|                     ("impersonation_started", "Impersonation Started"), |  | ||||||
|                     ("impersonation_ended", "Impersonation Ended"), |  | ||||||
|                     ("model_created", "Model Created"), |  | ||||||
|                     ("model_updated", "Model Updated"), |  | ||||||
|                     ("model_deleted", "Model Deleted"), |  | ||||||
|                     ("custom_", "Custom Prefix"), |  | ||||||
|                 ] |  | ||||||
|             ), |  | ||||||
|         ), |  | ||||||
|         migrations.AlterField( |  | ||||||
|             model_name="event", |  | ||||||
|             name="action", |  | ||||||
|             field=models.TextField( |  | ||||||
|                 choices=[ |  | ||||||
|                     ("login", "Login"), |  | ||||||
|                     ("login_failed", "Login Failed"), |  | ||||||
|                     ("logout", "Logout"), |  | ||||||
|                     ("user_write", "User Write"), |  | ||||||
|                     ("suspicious_request", "Suspicious Request"), |  | ||||||
|                     ("password_set", "Password Set"), |  | ||||||
|                     ("invitation_created", "Invite Created"), |  | ||||||
|                     ("invitation_used", "Invite Used"), |  | ||||||
|                     ("authorize_application", "Authorize Application"), |  | ||||||
|                     ("source_linked", "Source Linked"), |  | ||||||
|                     ("impersonation_started", "Impersonation Started"), |  | ||||||
|                     ("impersonation_ended", "Impersonation Ended"), |  | ||||||
|                     ("model_created", "Model Created"), |  | ||||||
|                     ("model_updated", "Model Updated"), |  | ||||||
|                     ("model_deleted", "Model Deleted"), |  | ||||||
|                     ("custom_", "Custom Prefix"), |  | ||||||
|                 ] |  | ||||||
|             ), |  | ||||||
|         ), |  | ||||||
|         migrations.RemoveField( |         migrations.RemoveField( | ||||||
|             model_name="event", |             model_name="event", | ||||||
|             name="date", |             name="date", | ||||||
|         ), |         ), | ||||||
|         migrations.AlterField( |  | ||||||
|             model_name="event", |  | ||||||
|             name="action", |  | ||||||
|             field=models.TextField( |  | ||||||
|                 choices=[ |  | ||||||
|                     ("login", "Login"), |  | ||||||
|                     ("login_failed", "Login Failed"), |  | ||||||
|                     ("logout", "Logout"), |  | ||||||
|                     ("user_write", "User Write"), |  | ||||||
|                     ("suspicious_request", "Suspicious Request"), |  | ||||||
|                     ("password_set", "Password Set"), |  | ||||||
|                     ("token_view", "Token View"), |  | ||||||
|                     ("invitation_created", "Invite Created"), |  | ||||||
|                     ("invitation_used", "Invite Used"), |  | ||||||
|                     ("authorize_application", "Authorize Application"), |  | ||||||
|                     ("source_linked", "Source Linked"), |  | ||||||
|                     ("impersonation_started", "Impersonation Started"), |  | ||||||
|                     ("impersonation_ended", "Impersonation Ended"), |  | ||||||
|                     ("model_created", "Model Created"), |  | ||||||
|                     ("model_updated", "Model Updated"), |  | ||||||
|                     ("model_deleted", "Model Deleted"), |  | ||||||
|                     ("custom_", "Custom Prefix"), |  | ||||||
|                 ] |  | ||||||
|             ), |  | ||||||
|         ), |  | ||||||
|         migrations.AlterField( |  | ||||||
|             model_name="event", |  | ||||||
|             name="action", |  | ||||||
|             field=models.TextField( |  | ||||||
|                 choices=[ |  | ||||||
|                     ("login", "Login"), |  | ||||||
|                     ("login_failed", "Login Failed"), |  | ||||||
|                     ("logout", "Logout"), |  | ||||||
|                     ("user_write", "User Write"), |  | ||||||
|                     ("suspicious_request", "Suspicious Request"), |  | ||||||
|                     ("password_set", "Password Set"), |  | ||||||
|                     ("token_view", "Token View"), |  | ||||||
|                     ("invitation_created", "Invite Created"), |  | ||||||
|                     ("invitation_used", "Invite Used"), |  | ||||||
|                     ("authorize_application", "Authorize Application"), |  | ||||||
|                     ("source_linked", "Source Linked"), |  | ||||||
|                     ("impersonation_started", "Impersonation Started"), |  | ||||||
|                     ("impersonation_ended", "Impersonation Ended"), |  | ||||||
|                     ("policy_execution", "Policy Execution"), |  | ||||||
|                     ("policy_exception", "Policy Exception"), |  | ||||||
|                     ("property_mapping_exception", "Property Mapping Exception"), |  | ||||||
|                     ("model_created", "Model Created"), |  | ||||||
|                     ("model_updated", "Model Updated"), |  | ||||||
|                     ("model_deleted", "Model Deleted"), |  | ||||||
|                     ("custom_", "Custom Prefix"), |  | ||||||
|                 ] |  | ||||||
|             ), |  | ||||||
|         ), |  | ||||||
|         migrations.AlterField( |  | ||||||
|             model_name="event", |  | ||||||
|             name="action", |  | ||||||
|             field=models.TextField( |  | ||||||
|                 choices=[ |  | ||||||
|                     ("login", "Login"), |  | ||||||
|                     ("login_failed", "Login Failed"), |  | ||||||
|                     ("logout", "Logout"), |  | ||||||
|                     ("user_write", "User Write"), |  | ||||||
|                     ("suspicious_request", "Suspicious Request"), |  | ||||||
|                     ("password_set", "Password Set"), |  | ||||||
|                     ("token_view", "Token View"), |  | ||||||
|                     ("invitation_created", "Invite Created"), |  | ||||||
|                     ("invitation_used", "Invite Used"), |  | ||||||
|                     ("authorize_application", "Authorize Application"), |  | ||||||
|                     ("source_linked", "Source Linked"), |  | ||||||
|                     ("impersonation_started", "Impersonation Started"), |  | ||||||
|                     ("impersonation_ended", "Impersonation Ended"), |  | ||||||
|                     ("policy_execution", "Policy Execution"), |  | ||||||
|                     ("policy_exception", "Policy Exception"), |  | ||||||
|                     ("property_mapping_exception", "Property Mapping Exception"), |  | ||||||
|                     ("model_created", "Model Created"), |  | ||||||
|                     ("model_updated", "Model Updated"), |  | ||||||
|                     ("model_deleted", "Model Deleted"), |  | ||||||
|                     ("update_available", "Update Available"), |  | ||||||
|                     ("custom_", "Custom Prefix"), |  | ||||||
|                 ] |  | ||||||
|             ), |  | ||||||
|         ), |  | ||||||
|         migrations.AlterField( |  | ||||||
|             model_name="event", |  | ||||||
|             name="action", |  | ||||||
|             field=models.TextField( |  | ||||||
|                 choices=[ |  | ||||||
|                     ("login", "Login"), |  | ||||||
|                     ("login_failed", "Login Failed"), |  | ||||||
|                     ("logout", "Logout"), |  | ||||||
|                     ("user_write", "User Write"), |  | ||||||
|                     ("suspicious_request", "Suspicious Request"), |  | ||||||
|                     ("password_set", "Password Set"), |  | ||||||
|                     ("token_view", "Token View"), |  | ||||||
|                     ("invitation_used", "Invite Used"), |  | ||||||
|                     ("authorize_application", "Authorize Application"), |  | ||||||
|                     ("source_linked", "Source Linked"), |  | ||||||
|                     ("impersonation_started", "Impersonation Started"), |  | ||||||
|                     ("impersonation_ended", "Impersonation Ended"), |  | ||||||
|                     ("policy_execution", "Policy Execution"), |  | ||||||
|                     ("policy_exception", "Policy Exception"), |  | ||||||
|                     ("property_mapping_exception", "Property Mapping Exception"), |  | ||||||
|                     ("configuration_error", "Configuration Error"), |  | ||||||
|                     ("model_created", "Model Created"), |  | ||||||
|                     ("model_updated", "Model Updated"), |  | ||||||
|                     ("model_deleted", "Model Deleted"), |  | ||||||
|                     ("update_available", "Update Available"), |  | ||||||
|                     ("custom_", "Custom Prefix"), |  | ||||||
|                 ] |  | ||||||
|             ), |  | ||||||
|         ), |  | ||||||
|         migrations.CreateModel( |         migrations.CreateModel( | ||||||
|             name="NotificationTransport", |             name="NotificationTransport", | ||||||
|             fields=[ |             fields=[ | ||||||
| @ -610,68 +451,6 @@ class Migration(migrations.Migration): | |||||||
|                 help_text="Only send notification once, for example when sending a webhook into a chat channel.", |                 help_text="Only send notification once, for example when sending a webhook into a chat channel.", | ||||||
|             ), |             ), | ||||||
|         ), |         ), | ||||||
|         migrations.AlterField( |  | ||||||
|             model_name="event", |  | ||||||
|             name="action", |  | ||||||
|             field=models.TextField( |  | ||||||
|                 choices=[ |  | ||||||
|                     ("login", "Login"), |  | ||||||
|                     ("login_failed", "Login Failed"), |  | ||||||
|                     ("logout", "Logout"), |  | ||||||
|                     ("user_write", "User Write"), |  | ||||||
|                     ("suspicious_request", "Suspicious Request"), |  | ||||||
|                     ("password_set", "Password Set"), |  | ||||||
|                     ("token_view", "Token View"), |  | ||||||
|                     ("invitation_used", "Invite Used"), |  | ||||||
|                     ("authorize_application", "Authorize Application"), |  | ||||||
|                     ("source_linked", "Source Linked"), |  | ||||||
|                     ("impersonation_started", "Impersonation Started"), |  | ||||||
|                     ("impersonation_ended", "Impersonation Ended"), |  | ||||||
|                     ("policy_execution", "Policy Execution"), |  | ||||||
|                     ("policy_exception", "Policy Exception"), |  | ||||||
|                     ("property_mapping_exception", "Property Mapping Exception"), |  | ||||||
|                     ("system_task_execution", "System Task Execution"), |  | ||||||
|                     ("system_task_exception", "System Task Exception"), |  | ||||||
|                     ("configuration_error", "Configuration Error"), |  | ||||||
|                     ("model_created", "Model Created"), |  | ||||||
|                     ("model_updated", "Model Updated"), |  | ||||||
|                     ("model_deleted", "Model Deleted"), |  | ||||||
|                     ("update_available", "Update Available"), |  | ||||||
|                     ("custom_", "Custom Prefix"), |  | ||||||
|                 ] |  | ||||||
|             ), |  | ||||||
|         ), |  | ||||||
|         migrations.AlterField( |  | ||||||
|             model_name="event", |  | ||||||
|             name="action", |  | ||||||
|             field=models.TextField( |  | ||||||
|                 choices=[ |  | ||||||
|                     ("login", "Login"), |  | ||||||
|                     ("login_failed", "Login Failed"), |  | ||||||
|                     ("logout", "Logout"), |  | ||||||
|                     ("user_write", "User Write"), |  | ||||||
|                     ("suspicious_request", "Suspicious Request"), |  | ||||||
|                     ("password_set", "Password Set"), |  | ||||||
|                     ("secret_view", "Secret View"), |  | ||||||
|                     ("invitation_used", "Invite Used"), |  | ||||||
|                     ("authorize_application", "Authorize Application"), |  | ||||||
|                     ("source_linked", "Source Linked"), |  | ||||||
|                     ("impersonation_started", "Impersonation Started"), |  | ||||||
|                     ("impersonation_ended", "Impersonation Ended"), |  | ||||||
|                     ("policy_execution", "Policy Execution"), |  | ||||||
|                     ("policy_exception", "Policy Exception"), |  | ||||||
|                     ("property_mapping_exception", "Property Mapping Exception"), |  | ||||||
|                     ("system_task_execution", "System Task Execution"), |  | ||||||
|                     ("system_task_exception", "System Task Exception"), |  | ||||||
|                     ("configuration_error", "Configuration Error"), |  | ||||||
|                     ("model_created", "Model Created"), |  | ||||||
|                     ("model_updated", "Model Updated"), |  | ||||||
|                     ("model_deleted", "Model Deleted"), |  | ||||||
|                     ("update_available", "Update Available"), |  | ||||||
|                     ("custom_", "Custom Prefix"), |  | ||||||
|                 ] |  | ||||||
|             ), |  | ||||||
|         ), |  | ||||||
|         migrations.RunPython( |         migrations.RunPython( | ||||||
|             code=token_view_to_secret_view, |             code=token_view_to_secret_view, | ||||||
|         ), |         ), | ||||||
| @ -688,76 +467,11 @@ class Migration(migrations.Migration): | |||||||
|         migrations.RunPython( |         migrations.RunPython( | ||||||
|             code=update_expires, |             code=update_expires, | ||||||
|         ), |         ), | ||||||
|         migrations.AlterField( |  | ||||||
|             model_name="event", |  | ||||||
|             name="action", |  | ||||||
|             field=models.TextField( |  | ||||||
|                 choices=[ |  | ||||||
|                     ("login", "Login"), |  | ||||||
|                     ("login_failed", "Login Failed"), |  | ||||||
|                     ("logout", "Logout"), |  | ||||||
|                     ("user_write", "User Write"), |  | ||||||
|                     ("suspicious_request", "Suspicious Request"), |  | ||||||
|                     ("password_set", "Password Set"), |  | ||||||
|                     ("secret_view", "Secret View"), |  | ||||||
|                     ("invitation_used", "Invite Used"), |  | ||||||
|                     ("authorize_application", "Authorize Application"), |  | ||||||
|                     ("source_linked", "Source Linked"), |  | ||||||
|                     ("impersonation_started", "Impersonation Started"), |  | ||||||
|                     ("impersonation_ended", "Impersonation Ended"), |  | ||||||
|                     ("policy_execution", "Policy Execution"), |  | ||||||
|                     ("policy_exception", "Policy Exception"), |  | ||||||
|                     ("property_mapping_exception", "Property Mapping Exception"), |  | ||||||
|                     ("system_task_execution", "System Task Execution"), |  | ||||||
|                     ("system_task_exception", "System Task Exception"), |  | ||||||
|                     ("configuration_error", "Configuration Error"), |  | ||||||
|                     ("model_created", "Model Created"), |  | ||||||
|                     ("model_updated", "Model Updated"), |  | ||||||
|                     ("model_deleted", "Model Deleted"), |  | ||||||
|                     ("email_sent", "Email Sent"), |  | ||||||
|                     ("update_available", "Update Available"), |  | ||||||
|                     ("custom_", "Custom Prefix"), |  | ||||||
|                 ] |  | ||||||
|             ), |  | ||||||
|         ), |  | ||||||
|         migrations.AddField( |         migrations.AddField( | ||||||
|             model_name="event", |             model_name="event", | ||||||
|             name="tenant", |             name="tenant", | ||||||
|             field=models.JSONField(blank=True, default=authentik.events.models.default_tenant), |             field=models.JSONField(blank=True, default=authentik.events.models.default_tenant), | ||||||
|         ), |         ), | ||||||
|         migrations.AlterField( |  | ||||||
|             model_name="event", |  | ||||||
|             name="action", |  | ||||||
|             field=models.TextField( |  | ||||||
|                 choices=[ |  | ||||||
|                     ("login", "Login"), |  | ||||||
|                     ("login_failed", "Login Failed"), |  | ||||||
|                     ("logout", "Logout"), |  | ||||||
|                     ("user_write", "User Write"), |  | ||||||
|                     ("suspicious_request", "Suspicious Request"), |  | ||||||
|                     ("password_set", "Password Set"), |  | ||||||
|                     ("secret_view", "Secret View"), |  | ||||||
|                     ("invitation_used", "Invite Used"), |  | ||||||
|                     ("authorize_application", "Authorize Application"), |  | ||||||
|                     ("source_linked", "Source Linked"), |  | ||||||
|                     ("impersonation_started", "Impersonation Started"), |  | ||||||
|                     ("impersonation_ended", "Impersonation Ended"), |  | ||||||
|                     ("policy_execution", "Policy Execution"), |  | ||||||
|                     ("policy_exception", "Policy Exception"), |  | ||||||
|                     ("property_mapping_exception", "Property Mapping Exception"), |  | ||||||
|                     ("system_task_execution", "System Task Execution"), |  | ||||||
|                     ("system_task_exception", "System Task Exception"), |  | ||||||
|                     ("system_exception", "System Exception"), |  | ||||||
|                     ("configuration_error", "Configuration Error"), |  | ||||||
|                     ("model_created", "Model Created"), |  | ||||||
|                     ("model_updated", "Model Updated"), |  | ||||||
|                     ("model_deleted", "Model Deleted"), |  | ||||||
|                     ("email_sent", "Email Sent"), |  | ||||||
|                     ("update_available", "Update Available"), |  | ||||||
|                     ("custom_", "Custom Prefix"), |  | ||||||
|                 ] |  | ||||||
|             ), |  | ||||||
|         ), |  | ||||||
|         migrations.AlterField( |         migrations.AlterField( | ||||||
|             model_name="event", |             model_name="event", | ||||||
|             name="action", |             name="action", | ||||||
| @ -776,6 +490,7 @@ class Migration(migrations.Migration): | |||||||
|                     ("source_linked", "Source Linked"), |                     ("source_linked", "Source Linked"), | ||||||
|                     ("impersonation_started", "Impersonation Started"), |                     ("impersonation_started", "Impersonation Started"), | ||||||
|                     ("impersonation_ended", "Impersonation Ended"), |                     ("impersonation_ended", "Impersonation Ended"), | ||||||
|  |                     ("flow_execution", "Flow Execution"), | ||||||
|                     ("policy_execution", "Policy Execution"), |                     ("policy_execution", "Policy Execution"), | ||||||
|                     ("policy_exception", "Policy Exception"), |                     ("policy_exception", "Policy Exception"), | ||||||
|                     ("property_mapping_exception", "Property Mapping Exception"), |                     ("property_mapping_exception", "Property Mapping Exception"), | ||||||
|  | |||||||
| @ -1,12 +1,20 @@ | |||||||
| """authentik events models""" | """authentik events models""" | ||||||
|  | import time | ||||||
|  | from collections import Counter | ||||||
| from datetime import timedelta | from datetime import timedelta | ||||||
| from inspect import getmodule, stack | from inspect import currentframe | ||||||
| from smtplib import SMTPException | from smtplib import SMTPException | ||||||
| from typing import TYPE_CHECKING, Optional, Type, Union | from typing import TYPE_CHECKING, Optional, Type, Union | ||||||
| from uuid import uuid4 | from uuid import uuid4 | ||||||
|  |  | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
| from django.db import models | from django.db import models | ||||||
|  | from django.db.models import Count, ExpressionWrapper, F | ||||||
|  | from django.db.models.fields import DurationField | ||||||
|  | from django.db.models.functions import ExtractHour | ||||||
|  | from django.db.models.functions.datetime import ExtractDay | ||||||
|  | from django.db.models.manager import Manager | ||||||
|  | from django.db.models.query import QuerySet | ||||||
| from django.http import HttpRequest | from django.http import HttpRequest | ||||||
| from django.http.request import QueryDict | from django.http.request import QueryDict | ||||||
| from django.utils.timezone import now | from django.utils.timezone import now | ||||||
| @ -70,6 +78,7 @@ class EventAction(models.TextChoices): | |||||||
|     IMPERSONATION_STARTED = "impersonation_started" |     IMPERSONATION_STARTED = "impersonation_started" | ||||||
|     IMPERSONATION_ENDED = "impersonation_ended" |     IMPERSONATION_ENDED = "impersonation_ended" | ||||||
|  |  | ||||||
|  |     FLOW_EXECUTION = "flow_execution" | ||||||
|     POLICY_EXECUTION = "policy_execution" |     POLICY_EXECUTION = "policy_execution" | ||||||
|     POLICY_EXCEPTION = "policy_exception" |     POLICY_EXCEPTION = "policy_exception" | ||||||
|     PROPERTY_MAPPING_EXCEPTION = "property_mapping_exception" |     PROPERTY_MAPPING_EXCEPTION = "property_mapping_exception" | ||||||
| @ -90,6 +99,72 @@ class EventAction(models.TextChoices): | |||||||
|     CUSTOM_PREFIX = "custom_" |     CUSTOM_PREFIX = "custom_" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class EventQuerySet(QuerySet): | ||||||
|  |     """Custom events query set with helper functions""" | ||||||
|  |  | ||||||
|  |     def get_events_per_hour(self) -> list[dict[str, int]]: | ||||||
|  |         """Get event count by hour in the last day, fill with zeros""" | ||||||
|  |         date_from = now() - timedelta(days=1) | ||||||
|  |         result = ( | ||||||
|  |             self.filter(created__gte=date_from) | ||||||
|  |             .annotate(age=ExpressionWrapper(now() - F("created"), output_field=DurationField())) | ||||||
|  |             .annotate(age_hours=ExtractHour("age")) | ||||||
|  |             .values("age_hours") | ||||||
|  |             .annotate(count=Count("pk")) | ||||||
|  |             .order_by("age_hours") | ||||||
|  |         ) | ||||||
|  |         data = Counter({int(d["age_hours"]): d["count"] for d in result}) | ||||||
|  |         results = [] | ||||||
|  |         _now = now() | ||||||
|  |         for hour in range(0, -24, -1): | ||||||
|  |             results.append( | ||||||
|  |                 { | ||||||
|  |                     "x_cord": time.mktime((_now + timedelta(hours=hour)).timetuple()) * 1000, | ||||||
|  |                     "y_cord": data[hour * -1], | ||||||
|  |                 } | ||||||
|  |             ) | ||||||
|  |         return results | ||||||
|  |  | ||||||
|  |     def get_events_per_day(self) -> list[dict[str, int]]: | ||||||
|  |         """Get event count by hour in the last day, fill with zeros""" | ||||||
|  |         date_from = now() - timedelta(weeks=4) | ||||||
|  |         result = ( | ||||||
|  |             self.filter(created__gte=date_from) | ||||||
|  |             .annotate(age=ExpressionWrapper(now() - F("created"), output_field=DurationField())) | ||||||
|  |             .annotate(age_days=ExtractDay("age")) | ||||||
|  |             .values("age_days") | ||||||
|  |             .annotate(count=Count("pk")) | ||||||
|  |             .order_by("age_days") | ||||||
|  |         ) | ||||||
|  |         data = Counter({int(d["age_days"]): d["count"] for d in result}) | ||||||
|  |         results = [] | ||||||
|  |         _now = now() | ||||||
|  |         for day in range(0, -30, -1): | ||||||
|  |             results.append( | ||||||
|  |                 { | ||||||
|  |                     "x_cord": time.mktime((_now + timedelta(days=day)).timetuple()) * 1000, | ||||||
|  |                     "y_cord": data[day * -1], | ||||||
|  |                 } | ||||||
|  |             ) | ||||||
|  |         return results | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class EventManager(Manager): | ||||||
|  |     """Custom helper methods for Events""" | ||||||
|  |  | ||||||
|  |     def get_queryset(self) -> QuerySet: | ||||||
|  |         """use custom queryset""" | ||||||
|  |         return EventQuerySet(self.model, using=self._db) | ||||||
|  |  | ||||||
|  |     def get_events_per_hour(self) -> list[dict[str, int]]: | ||||||
|  |         """Wrap method from queryset""" | ||||||
|  |         return self.get_queryset().get_events_per_hour() | ||||||
|  |  | ||||||
|  |     def get_events_per_day(self) -> list[dict[str, int]]: | ||||||
|  |         """Wrap method from queryset""" | ||||||
|  |         return self.get_queryset().get_events_per_day() | ||||||
|  |  | ||||||
|  |  | ||||||
| class Event(ExpiringModel): | class Event(ExpiringModel): | ||||||
|     """An individual Audit/Metrics/Notification/Error Event""" |     """An individual Audit/Metrics/Notification/Error Event""" | ||||||
|  |  | ||||||
| @ -105,6 +180,8 @@ class Event(ExpiringModel): | |||||||
|     # Shadow the expires attribute from ExpiringModel to override the default duration |     # Shadow the expires attribute from ExpiringModel to override the default duration | ||||||
|     expires = models.DateTimeField(default=default_event_duration) |     expires = models.DateTimeField(default=default_event_duration) | ||||||
|  |  | ||||||
|  |     objects = EventManager() | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def _get_app_from_request(request: HttpRequest) -> str: |     def _get_app_from_request(request: HttpRequest) -> str: | ||||||
|         if not isinstance(request, HttpRequest): |         if not isinstance(request, HttpRequest): | ||||||
| @ -115,14 +192,15 @@ class Event(ExpiringModel): | |||||||
|     def new( |     def new( | ||||||
|         action: Union[str, EventAction], |         action: Union[str, EventAction], | ||||||
|         app: Optional[str] = None, |         app: Optional[str] = None, | ||||||
|         _inspect_offset: int = 1, |  | ||||||
|         **kwargs, |         **kwargs, | ||||||
|     ) -> "Event": |     ) -> "Event": | ||||||
|         """Create new Event instance from arguments. Instance is NOT saved.""" |         """Create new Event instance from arguments. Instance is NOT saved.""" | ||||||
|         if not isinstance(action, EventAction): |         if not isinstance(action, EventAction): | ||||||
|             action = EventAction.CUSTOM_PREFIX + action |             action = EventAction.CUSTOM_PREFIX + action | ||||||
|         if not app: |         if not app: | ||||||
|             app = getmodule(stack()[_inspect_offset][0]).__name__ |             current = currentframe() | ||||||
|  |             parent = current.f_back | ||||||
|  |             app = parent.f_globals["__name__"] | ||||||
|         cleaned_kwargs = cleanse_dict(sanitize_dict(kwargs)) |         cleaned_kwargs = cleanse_dict(sanitize_dict(kwargs)) | ||||||
|         event = Event(action=action, app=app, context=cleaned_kwargs) |         event = Event(action=action, app=app, context=cleaned_kwargs) | ||||||
|         return event |         return event | ||||||
|  | |||||||
| @ -46,7 +46,7 @@ class TaskResult: | |||||||
|  |  | ||||||
|     def with_error(self, exc: Exception) -> "TaskResult": |     def with_error(self, exc: Exception) -> "TaskResult": | ||||||
|         """Since errors might not always be pickle-able, set the traceback""" |         """Since errors might not always be pickle-able, set the traceback""" | ||||||
|         self.messages.extend(exception_to_string(exc).splitlines()) |         self.messages.append(str(exc)) | ||||||
|         return self |         return self | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -186,31 +186,21 @@ class MonitoredTask(Task): | |||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |  | ||||||
| class PrefilledMonitoredTask(MonitoredTask): | def prefill_task(func): | ||||||
|     """Subclass of MonitoredTask, but create entry in cache if task hasn't been run |     """Ensure a task's details are always in cache, so it can always be triggered via API""" | ||||||
|     Does not support UID""" |     status = TaskInfo.by_name(func.__name__) | ||||||
|  |     if status: | ||||||
|     def __init__(self, *args, **kwargs) -> None: |         return func | ||||||
|         super().__init__(*args, **kwargs) |     TaskInfo( | ||||||
|         status = TaskInfo.by_name(self.__name__) |         task_name=func.__name__, | ||||||
|         if status: |         task_description=func.__doc__, | ||||||
|             return |         result=TaskResult(TaskResultStatus.UNKNOWN, messages=[_("Task has not been run yet.")]), | ||||||
|         TaskInfo( |         task_call_module=func.__module__, | ||||||
|             task_name=self.__name__, |         task_call_func=func.__name__, | ||||||
|             task_description=self.__doc__, |         # We don't have real values for these attributes but they cannot be null | ||||||
|             result=TaskResult(TaskResultStatus.UNKNOWN, messages=[_("Task has not been run yet.")]), |         start_timestamp=default_timer(), | ||||||
|             task_call_module=self.__module__, |         finish_timestamp=default_timer(), | ||||||
|             task_call_func=self.__name__, |         finish_time=datetime.now(), | ||||||
|             # We don't have real values for these attributes but they cannot be null |     ).save(86400) | ||||||
|             start_timestamp=default_timer(), |     LOGGER.debug("prefilled task", task_name=func.__name__) | ||||||
|             finish_timestamp=default_timer(), |     return func | ||||||
|             finish_time=datetime.now(), |  | ||||||
|         ).save(86400) |  | ||||||
|         LOGGER.debug("prefilled task", task_name=self.__name__) |  | ||||||
|  |  | ||||||
|     def run(self, *args, **kwargs): |  | ||||||
|         raise NotImplementedError |  | ||||||
|  |  | ||||||
|  |  | ||||||
| for task in TaskInfo.all().values(): |  | ||||||
|     task.set_prom_metrics() |  | ||||||
|  | |||||||
| @ -90,7 +90,7 @@ class StageViewSet( | |||||||
|             stages += list(configurable_stage.objects.all().order_by("name")) |             stages += list(configurable_stage.objects.all().order_by("name")) | ||||||
|         matching_stages: list[dict] = [] |         matching_stages: list[dict] = [] | ||||||
|         for stage in stages: |         for stage in stages: | ||||||
|             user_settings = stage.ui_user_settings |             user_settings = stage.ui_user_settings() | ||||||
|             if not user_settings: |             if not user_settings: | ||||||
|                 continue |                 continue | ||||||
|             user_settings.initial_data["object_uid"] = str(stage.pk) |             user_settings.initial_data["object_uid"] = str(stage.pk) | ||||||
|  | |||||||
							
								
								
									
										46
									
								
								authentik/flows/migrations/0020_flowtoken.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								authentik/flows/migrations/0020_flowtoken.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,46 @@ | |||||||
|  | # Generated by Django 3.2.9 on 2021-12-05 13:50 | ||||||
|  |  | ||||||
|  | import django.db.models.deletion | ||||||
|  | from django.db import migrations, models | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Migration(migrations.Migration): | ||||||
|  |  | ||||||
|  |     dependencies = [ | ||||||
|  |         ("authentik_core", "0018_auto_20210330_1345_squashed_0028_alter_token_intent"), | ||||||
|  |         ( | ||||||
|  |             "authentik_flows", | ||||||
|  |             "0019_alter_flow_background_squashed_0024_alter_flow_compatibility_mode", | ||||||
|  |         ), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     operations = [ | ||||||
|  |         migrations.CreateModel( | ||||||
|  |             name="FlowToken", | ||||||
|  |             fields=[ | ||||||
|  |                 ( | ||||||
|  |                     "token_ptr", | ||||||
|  |                     models.OneToOneField( | ||||||
|  |                         auto_created=True, | ||||||
|  |                         on_delete=django.db.models.deletion.CASCADE, | ||||||
|  |                         parent_link=True, | ||||||
|  |                         primary_key=True, | ||||||
|  |                         serialize=False, | ||||||
|  |                         to="authentik_core.token", | ||||||
|  |                     ), | ||||||
|  |                 ), | ||||||
|  |                 ("_plan", models.TextField()), | ||||||
|  |                 ( | ||||||
|  |                     "flow", | ||||||
|  |                     models.ForeignKey( | ||||||
|  |                         on_delete=django.db.models.deletion.CASCADE, to="authentik_flows.flow" | ||||||
|  |                     ), | ||||||
|  |                 ), | ||||||
|  |             ], | ||||||
|  |             options={ | ||||||
|  |                 "verbose_name": "Flow Token", | ||||||
|  |                 "verbose_name_plural": "Flow Tokens", | ||||||
|  |             }, | ||||||
|  |             bases=("authentik_core.token",), | ||||||
|  |         ), | ||||||
|  |     ] | ||||||
| @ -1,4 +1,6 @@ | |||||||
| """Flow models""" | """Flow models""" | ||||||
|  | from base64 import b64decode, b64encode | ||||||
|  | from pickle import dumps, loads  # nosec | ||||||
| from typing import TYPE_CHECKING, Optional, Type | from typing import TYPE_CHECKING, Optional, Type | ||||||
| from uuid import uuid4 | from uuid import uuid4 | ||||||
|  |  | ||||||
| @ -9,11 +11,13 @@ from model_utils.managers import InheritanceManager | |||||||
| from rest_framework.serializers import BaseSerializer | from rest_framework.serializers import BaseSerializer | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
|  | from authentik.core.models import Token | ||||||
| from authentik.core.types import UserSettingSerializer | from authentik.core.types import UserSettingSerializer | ||||||
| from authentik.lib.models import InheritanceForeignKey, SerializerModel | from authentik.lib.models import InheritanceForeignKey, SerializerModel | ||||||
| from authentik.policies.models import PolicyBindingModel | from authentik.policies.models import PolicyBindingModel | ||||||
|  |  | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
|  |     from authentik.flows.planner import FlowPlan | ||||||
|     from authentik.flows.stage import StageView |     from authentik.flows.stage import StageView | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| @ -71,7 +75,6 @@ class Stage(SerializerModel): | |||||||
|         """Return component used to edit this object""" |         """Return component used to edit this object""" | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def ui_user_settings(self) -> Optional[UserSettingSerializer]: |     def ui_user_settings(self) -> Optional[UserSettingSerializer]: | ||||||
|         """Entrypoint to integrate with User settings. Can either return None if no |         """Entrypoint to integrate with User settings. Can either return None if no | ||||||
|         user settings are available, or a challenge.""" |         user settings are available, or a challenge.""" | ||||||
| @ -260,3 +263,30 @@ class ConfigurableStage(models.Model): | |||||||
|     class Meta: |     class Meta: | ||||||
|  |  | ||||||
|         abstract = True |         abstract = True | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class FlowToken(Token): | ||||||
|  |     """Subclass of a standard Token, stores the currently active flow plan upon creation. | ||||||
|  |     Can be used to later resume a flow.""" | ||||||
|  |  | ||||||
|  |     flow = models.ForeignKey(Flow, on_delete=models.CASCADE) | ||||||
|  |     _plan = models.TextField() | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def pickle(plan) -> str: | ||||||
|  |         """Pickle into string""" | ||||||
|  |         data = dumps(plan) | ||||||
|  |         return b64encode(data).decode() | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def plan(self) -> "FlowPlan": | ||||||
|  |         """Load Flow plan from pickled version""" | ||||||
|  |         return loads(b64decode(self._plan.encode()))  # nosec | ||||||
|  |  | ||||||
|  |     def __str__(self) -> str: | ||||||
|  |         return f"Flow Token {super().__str__()}" | ||||||
|  |  | ||||||
|  |     class Meta: | ||||||
|  |  | ||||||
|  |         verbose_name = _("Flow Token") | ||||||
|  |         verbose_name_plural = _("Flow Tokens") | ||||||
|  | |||||||
| @ -24,6 +24,9 @@ PLAN_CONTEXT_SSO = "is_sso" | |||||||
| PLAN_CONTEXT_REDIRECT = "redirect" | PLAN_CONTEXT_REDIRECT = "redirect" | ||||||
| PLAN_CONTEXT_APPLICATION = "application" | PLAN_CONTEXT_APPLICATION = "application" | ||||||
| PLAN_CONTEXT_SOURCE = "source" | PLAN_CONTEXT_SOURCE = "source" | ||||||
|  | # Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan | ||||||
|  | # was restored. | ||||||
|  | PLAN_CONTEXT_IS_RESTORED = "is_restored" | ||||||
| GAUGE_FLOWS_CACHED = UpdatingGauge( | GAUGE_FLOWS_CACHED = UpdatingGauge( | ||||||
|     "authentik_flows_cached", |     "authentik_flows_cached", | ||||||
|     "Cached flows", |     "Cached flows", | ||||||
| @ -123,7 +126,9 @@ class FlowPlanner: | |||||||
|     ) -> FlowPlan: |     ) -> FlowPlan: | ||||||
|         """Check each of the flows' policies, check policies for each stage with PolicyBinding |         """Check each of the flows' policies, check policies for each stage with PolicyBinding | ||||||
|         and return ordered list""" |         and return ordered list""" | ||||||
|         with Hub.current.start_span(op="flow.planner.plan") as span: |         with Hub.current.start_span( | ||||||
|  |             op="authentik.flow.planner.plan", description=self.flow.slug | ||||||
|  |         ) as span: | ||||||
|             span: Span |             span: Span | ||||||
|             span.set_data("flow", self.flow) |             span.set_data("flow", self.flow) | ||||||
|             span.set_data("request", request) |             span.set_data("request", request) | ||||||
| @ -178,7 +183,8 @@ class FlowPlanner: | |||||||
|         """Build flow plan by checking each stage in their respective |         """Build flow plan by checking each stage in their respective | ||||||
|         order and checking the applied policies""" |         order and checking the applied policies""" | ||||||
|         with Hub.current.start_span( |         with Hub.current.start_span( | ||||||
|             op="flow.planner.build_plan" |             op="authentik.flow.planner.build_plan", | ||||||
|  |             description=self.flow.slug, | ||||||
|         ) as span, HIST_FLOWS_PLAN_TIME.labels(flow_slug=self.flow.slug).time(): |         ) as span, HIST_FLOWS_PLAN_TIME.labels(flow_slug=self.flow.slug).time(): | ||||||
|             span: Span |             span: Span | ||||||
|             span.set_data("flow", self.flow) |             span.set_data("flow", self.flow) | ||||||
|  | |||||||
| @ -6,6 +6,7 @@ from django.http.response import HttpResponse | |||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
| from django.views.generic.base import View | from django.views.generic.base import View | ||||||
| from rest_framework.request import Request | from rest_framework.request import Request | ||||||
|  | from sentry_sdk.hub import Hub | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.core.models import DEFAULT_AVATAR, User | from authentik.core.models import DEFAULT_AVATAR, User | ||||||
| @ -94,8 +95,16 @@ class ChallengeStageView(StageView): | |||||||
|                     keep_context=keep_context, |                     keep_context=keep_context, | ||||||
|                 ) |                 ) | ||||||
|                 return self.executor.restart_flow(keep_context) |                 return self.executor.restart_flow(keep_context) | ||||||
|             return self.challenge_invalid(challenge) |             with Hub.current.start_span( | ||||||
|         return self.challenge_valid(challenge) |                 op="authentik.flow.stage.challenge_invalid", | ||||||
|  |                 description=self.__class__.__name__, | ||||||
|  |             ): | ||||||
|  |                 return self.challenge_invalid(challenge) | ||||||
|  |         with Hub.current.start_span( | ||||||
|  |             op="authentik.flow.stage.challenge_valid", | ||||||
|  |             description=self.__class__.__name__, | ||||||
|  |         ): | ||||||
|  |             return self.challenge_valid(challenge) | ||||||
|  |  | ||||||
|     def format_title(self) -> str: |     def format_title(self) -> str: | ||||||
|         """Allow usage of placeholder in flow title.""" |         """Allow usage of placeholder in flow title.""" | ||||||
| @ -104,7 +113,11 @@ class ChallengeStageView(StageView): | |||||||
|         } |         } | ||||||
|  |  | ||||||
|     def _get_challenge(self, *args, **kwargs) -> Challenge: |     def _get_challenge(self, *args, **kwargs) -> Challenge: | ||||||
|         challenge = self.get_challenge(*args, **kwargs) |         with Hub.current.start_span( | ||||||
|  |             op="authentik.flow.stage.get_challenge", | ||||||
|  |             description=self.__class__.__name__, | ||||||
|  |         ): | ||||||
|  |             challenge = self.get_challenge(*args, **kwargs) | ||||||
|         if "flow_info" not in challenge.initial_data: |         if "flow_info" not in challenge.initial_data: | ||||||
|             flow_info = ContextualFlowInfo( |             flow_info = ContextualFlowInfo( | ||||||
|                 data={ |                 data={ | ||||||
|  | |||||||
| @ -32,7 +32,7 @@ class TestFlowsAPI(APITestCase): | |||||||
|  |  | ||||||
|     def test_models(self): |     def test_models(self): | ||||||
|         """Test that ui_user_settings returns none""" |         """Test that ui_user_settings returns none""" | ||||||
|         self.assertIsNone(Stage().ui_user_settings) |         self.assertIsNone(Stage().ui_user_settings()) | ||||||
|  |  | ||||||
|     def test_api_serializer(self): |     def test_api_serializer(self): | ||||||
|         """Test that stage serializer returns the correct type""" |         """Test that stage serializer returns the correct type""" | ||||||
|  | |||||||
| @ -23,7 +23,7 @@ def model_tester_factory(test_model: Type[Stage]) -> Callable: | |||||||
|             model_class = test_model() |             model_class = test_model() | ||||||
|         self.assertTrue(issubclass(model_class.type, StageView)) |         self.assertTrue(issubclass(model_class.type, StageView)) | ||||||
|         self.assertIsNotNone(test_model.component) |         self.assertIsNotNone(test_model.component) | ||||||
|         _ = model_class.ui_user_settings |         _ = model_class.ui_user_settings() | ||||||
|  |  | ||||||
|     return tester |     return tester | ||||||
|  |  | ||||||
|  | |||||||
| @ -19,6 +19,8 @@ from drf_spectacular.utils import OpenApiParameter, PolymorphicProxySerializer, | |||||||
| from rest_framework.permissions import AllowAny | from rest_framework.permissions import AllowAny | ||||||
| from rest_framework.views import APIView | from rest_framework.views import APIView | ||||||
| from sentry_sdk import capture_exception | from sentry_sdk import capture_exception | ||||||
|  | from sentry_sdk.api import set_tag | ||||||
|  | from sentry_sdk.hub import Hub | ||||||
| from structlog.stdlib import BoundLogger, get_logger | from structlog.stdlib import BoundLogger, get_logger | ||||||
|  |  | ||||||
| from authentik.core.models import USER_ATTRIBUTE_DEBUG | from authentik.core.models import USER_ATTRIBUTE_DEBUG | ||||||
| @ -34,8 +36,16 @@ from authentik.flows.challenge import ( | |||||||
|     WithUserInfoChallenge, |     WithUserInfoChallenge, | ||||||
| ) | ) | ||||||
| from authentik.flows.exceptions import EmptyFlowException, FlowNonApplicableException | from authentik.flows.exceptions import EmptyFlowException, FlowNonApplicableException | ||||||
| from authentik.flows.models import ConfigurableStage, Flow, FlowDesignation, FlowStageBinding, Stage | from authentik.flows.models import ( | ||||||
|  |     ConfigurableStage, | ||||||
|  |     Flow, | ||||||
|  |     FlowDesignation, | ||||||
|  |     FlowStageBinding, | ||||||
|  |     FlowToken, | ||||||
|  |     Stage, | ||||||
|  | ) | ||||||
| from authentik.flows.planner import ( | from authentik.flows.planner import ( | ||||||
|  |     PLAN_CONTEXT_IS_RESTORED, | ||||||
|     PLAN_CONTEXT_PENDING_USER, |     PLAN_CONTEXT_PENDING_USER, | ||||||
|     PLAN_CONTEXT_REDIRECT, |     PLAN_CONTEXT_REDIRECT, | ||||||
|     FlowPlan, |     FlowPlan, | ||||||
| @ -55,6 +65,7 @@ SESSION_KEY_APPLICATION_PRE = "authentik_flows_application_pre" | |||||||
| SESSION_KEY_GET = "authentik_flows_get" | SESSION_KEY_GET = "authentik_flows_get" | ||||||
| SESSION_KEY_POST = "authentik_flows_post" | SESSION_KEY_POST = "authentik_flows_post" | ||||||
| SESSION_KEY_HISTORY = "authentik_flows_history" | SESSION_KEY_HISTORY = "authentik_flows_history" | ||||||
|  | QS_KEY_TOKEN = "flow_token"  # nosec | ||||||
|  |  | ||||||
|  |  | ||||||
| def challenge_types(): | def challenge_types(): | ||||||
| @ -117,6 +128,7 @@ class FlowExecutorView(APIView): | |||||||
|         super().setup(request, flow_slug=flow_slug) |         super().setup(request, flow_slug=flow_slug) | ||||||
|         self.flow = get_object_or_404(Flow.objects.select_related(), slug=flow_slug) |         self.flow = get_object_or_404(Flow.objects.select_related(), slug=flow_slug) | ||||||
|         self._logger = get_logger().bind(flow_slug=flow_slug) |         self._logger = get_logger().bind(flow_slug=flow_slug) | ||||||
|  |         set_tag("authentik.flow", self.flow.slug) | ||||||
|  |  | ||||||
|     def handle_invalid_flow(self, exc: BaseException) -> HttpResponse: |     def handle_invalid_flow(self, exc: BaseException) -> HttpResponse: | ||||||
|         """When a flow is non-applicable check if user is on the correct domain""" |         """When a flow is non-applicable check if user is on the correct domain""" | ||||||
| @ -127,71 +139,100 @@ class FlowExecutorView(APIView): | |||||||
|         message = exc.__doc__ if exc.__doc__ else str(exc) |         message = exc.__doc__ if exc.__doc__ else str(exc) | ||||||
|         return self.stage_invalid(error_message=message) |         return self.stage_invalid(error_message=message) | ||||||
|  |  | ||||||
|  |     def _check_flow_token(self, get_params: QueryDict): | ||||||
|  |         """Check if the user is using a flow token to restore a plan""" | ||||||
|  |         tokens = FlowToken.filter_not_expired(key=get_params[QS_KEY_TOKEN]) | ||||||
|  |         if not tokens.exists(): | ||||||
|  |             return False | ||||||
|  |         token: FlowToken = tokens.first() | ||||||
|  |         try: | ||||||
|  |             plan = token.plan | ||||||
|  |         except (AttributeError, EOFError, ImportError, IndexError) as exc: | ||||||
|  |             LOGGER.warning("f(exec): Failed to restore token plan", exc=exc) | ||||||
|  |         finally: | ||||||
|  |             token.delete() | ||||||
|  |         if not isinstance(plan, FlowPlan): | ||||||
|  |             return None | ||||||
|  |         plan.context[PLAN_CONTEXT_IS_RESTORED] = True | ||||||
|  |         self._logger.debug("f(exec): restored flow plan from token", plan=plan) | ||||||
|  |         return plan | ||||||
|  |  | ||||||
|     # pylint: disable=unused-argument, too-many-return-statements |     # pylint: disable=unused-argument, too-many-return-statements | ||||||
|     def dispatch(self, request: HttpRequest, flow_slug: str) -> HttpResponse: |     def dispatch(self, request: HttpRequest, flow_slug: str) -> HttpResponse: | ||||||
|         # Early check if there's an active Plan for the current session |         with Hub.current.start_span( | ||||||
|         if SESSION_KEY_PLAN in self.request.session: |             op="authentik.flow.executor.dispatch", description=self.flow.slug | ||||||
|             self.plan = self.request.session[SESSION_KEY_PLAN] |         ) as span: | ||||||
|             if self.plan.flow_pk != self.flow.pk.hex: |             span.set_data("authentik Flow", self.flow.slug) | ||||||
|                 self._logger.warning( |             get_params = QueryDict(request.GET.get("query", "")) | ||||||
|                     "f(exec): Found existing plan for other flow, deleting plan", |             if QS_KEY_TOKEN in get_params: | ||||||
|                 ) |                 plan = self._check_flow_token(get_params) | ||||||
|                 # Existing plan is deleted from session and instance |                 if plan: | ||||||
|                 self.plan = None |                     self.request.session[SESSION_KEY_PLAN] = plan | ||||||
|                 self.cancel() |             # Early check if there's an active Plan for the current session | ||||||
|             self._logger.debug("f(exec): Continuing existing plan") |             if SESSION_KEY_PLAN in self.request.session: | ||||||
|  |                 self.plan = self.request.session[SESSION_KEY_PLAN] | ||||||
|  |                 if self.plan.flow_pk != self.flow.pk.hex: | ||||||
|  |                     self._logger.warning( | ||||||
|  |                         "f(exec): Found existing plan for other flow, deleting plan", | ||||||
|  |                     ) | ||||||
|  |                     # Existing plan is deleted from session and instance | ||||||
|  |                     self.plan = None | ||||||
|  |                     self.cancel() | ||||||
|  |                 self._logger.debug("f(exec): Continuing existing plan") | ||||||
|  |  | ||||||
|         # Don't check session again as we've either already loaded the plan or we need to plan |             # Don't check session again as we've either already loaded the plan or we need to plan | ||||||
|         if not self.plan: |             if not self.plan: | ||||||
|             request.session[SESSION_KEY_HISTORY] = [] |                 request.session[SESSION_KEY_HISTORY] = [] | ||||||
|             self._logger.debug("f(exec): No active Plan found, initiating planner") |                 self._logger.debug("f(exec): No active Plan found, initiating planner") | ||||||
|  |                 try: | ||||||
|  |                     self.plan = self._initiate_plan() | ||||||
|  |                 except FlowNonApplicableException as exc: | ||||||
|  |                     self._logger.warning("f(exec): Flow not applicable to current user", exc=exc) | ||||||
|  |                     return to_stage_response(self.request, self.handle_invalid_flow(exc)) | ||||||
|  |                 except EmptyFlowException as exc: | ||||||
|  |                     self._logger.warning("f(exec): Flow is empty", exc=exc) | ||||||
|  |                     # To match behaviour with loading an empty flow plan from cache, | ||||||
|  |                     # we don't show an error message here, but rather call _flow_done() | ||||||
|  |                     return self._flow_done() | ||||||
|  |             # Initial flow request, check if we have an upstream query string passed in | ||||||
|  |             request.session[SESSION_KEY_GET] = get_params | ||||||
|  |             # We don't save the Plan after getting the next stage | ||||||
|  |             # as it hasn't been successfully passed yet | ||||||
|             try: |             try: | ||||||
|                 self.plan = self._initiate_plan() |                 # This is the first time we actually access any attribute on the selected plan | ||||||
|             except FlowNonApplicableException as exc: |                 # if the cached plan is from an older version, it might have different attributes | ||||||
|                 self._logger.warning("f(exec): Flow not applicable to current user", exc=exc) |                 # in which case we just delete the plan and invalidate everything | ||||||
|                 return to_stage_response(self.request, self.handle_invalid_flow(exc)) |                 next_binding = self.plan.next(self.request) | ||||||
|             except EmptyFlowException as exc: |             except Exception as exc:  # pylint: disable=broad-except | ||||||
|                 self._logger.warning("f(exec): Flow is empty", exc=exc) |                 self._logger.warning( | ||||||
|                 # To match behaviour with loading an empty flow plan from cache, |                     "f(exec): found incompatible flow plan, invalidating run", exc=exc | ||||||
|                 # we don't show an error message here, but rather call _flow_done() |                 ) | ||||||
|  |                 keys = cache.keys("flow_*") | ||||||
|  |                 cache.delete_many(keys) | ||||||
|  |                 return self.stage_invalid() | ||||||
|  |             if not next_binding: | ||||||
|  |                 self._logger.debug("f(exec): no more stages, flow is done.") | ||||||
|                 return self._flow_done() |                 return self._flow_done() | ||||||
|         # Initial flow request, check if we have an upstream query string passed in |             self.current_binding = next_binding | ||||||
|         request.session[SESSION_KEY_GET] = QueryDict(request.GET.get("query", "")) |             self.current_stage = next_binding.stage | ||||||
|         # We don't save the Plan after getting the next stage |             self._logger.debug( | ||||||
|         # as it hasn't been successfully passed yet |                 "f(exec): Current stage", | ||||||
|         try: |                 current_stage=self.current_stage, | ||||||
|             # This is the first time we actually access any attribute on the selected plan |                 flow_slug=self.flow.slug, | ||||||
|             # if the cached plan is from an older version, it might have different attributes |             ) | ||||||
|             # in which case we just delete the plan and invalidate everything |             try: | ||||||
|             next_binding = self.plan.next(self.request) |                 stage_cls = self.current_stage.type | ||||||
|         except Exception as exc:  # pylint: disable=broad-except |             except NotImplementedError as exc: | ||||||
|             self._logger.warning("f(exec): found incompatible flow plan, invalidating run", exc=exc) |                 self._logger.debug("Error getting stage type", exc=exc) | ||||||
|             keys = cache.keys("flow_*") |                 return self.stage_invalid() | ||||||
|             cache.delete_many(keys) |             self.current_stage_view = stage_cls(self) | ||||||
|             return self.stage_invalid() |             self.current_stage_view.args = self.args | ||||||
|         if not next_binding: |             self.current_stage_view.kwargs = self.kwargs | ||||||
|             self._logger.debug("f(exec): no more stages, flow is done.") |             self.current_stage_view.request = request | ||||||
|             return self._flow_done() |             try: | ||||||
|         self.current_binding = next_binding |                 return super().dispatch(request) | ||||||
|         self.current_stage = next_binding.stage |             except InvalidStageError as exc: | ||||||
|         self._logger.debug( |                 return self.stage_invalid(str(exc)) | ||||||
|             "f(exec): Current stage", |  | ||||||
|             current_stage=self.current_stage, |  | ||||||
|             flow_slug=self.flow.slug, |  | ||||||
|         ) |  | ||||||
|         try: |  | ||||||
|             stage_cls = self.current_stage.type |  | ||||||
|         except NotImplementedError as exc: |  | ||||||
|             self._logger.debug("Error getting stage type", exc=exc) |  | ||||||
|             return self.stage_invalid() |  | ||||||
|         self.current_stage_view = stage_cls(self) |  | ||||||
|         self.current_stage_view.args = self.args |  | ||||||
|         self.current_stage_view.kwargs = self.kwargs |  | ||||||
|         self.current_stage_view.request = request |  | ||||||
|         try: |  | ||||||
|             return super().dispatch(request) |  | ||||||
|         except InvalidStageError as exc: |  | ||||||
|             return self.stage_invalid(str(exc)) |  | ||||||
|  |  | ||||||
|     def handle_exception(self, exc: Exception) -> HttpResponse: |     def handle_exception(self, exc: Exception) -> HttpResponse: | ||||||
|         """Handle exception in stage execution""" |         """Handle exception in stage execution""" | ||||||
| @ -233,8 +274,15 @@ class FlowExecutorView(APIView): | |||||||
|             stage=self.current_stage, |             stage=self.current_stage, | ||||||
|         ) |         ) | ||||||
|         try: |         try: | ||||||
|             stage_response = self.current_stage_view.get(request, *args, **kwargs) |             with Hub.current.start_span( | ||||||
|             return to_stage_response(request, stage_response) |                 op="authentik.flow.executor.stage", | ||||||
|  |                 description=class_to_path(self.current_stage_view.__class__), | ||||||
|  |             ) as span: | ||||||
|  |                 span.set_data("Method", "GET") | ||||||
|  |                 span.set_data("authentik Stage", self.current_stage_view) | ||||||
|  |                 span.set_data("authentik Flow", self.flow.slug) | ||||||
|  |                 stage_response = self.current_stage_view.get(request, *args, **kwargs) | ||||||
|  |                 return to_stage_response(request, stage_response) | ||||||
|         except Exception as exc:  # pylint: disable=broad-except |         except Exception as exc:  # pylint: disable=broad-except | ||||||
|             return self.handle_exception(exc) |             return self.handle_exception(exc) | ||||||
|  |  | ||||||
| @ -270,8 +318,15 @@ class FlowExecutorView(APIView): | |||||||
|             stage=self.current_stage, |             stage=self.current_stage, | ||||||
|         ) |         ) | ||||||
|         try: |         try: | ||||||
|             stage_response = self.current_stage_view.post(request, *args, **kwargs) |             with Hub.current.start_span( | ||||||
|             return to_stage_response(request, stage_response) |                 op="authentik.flow.executor.stage", | ||||||
|  |                 description=class_to_path(self.current_stage_view.__class__), | ||||||
|  |             ) as span: | ||||||
|  |                 span.set_data("Method", "POST") | ||||||
|  |                 span.set_data("authentik Stage", self.current_stage_view) | ||||||
|  |                 span.set_data("authentik Flow", self.flow.slug) | ||||||
|  |                 stage_response = self.current_stage_view.post(request, *args, **kwargs) | ||||||
|  |                 return to_stage_response(request, stage_response) | ||||||
|         except Exception as exc:  # pylint: disable=broad-except |         except Exception as exc:  # pylint: disable=broad-except | ||||||
|             return self.handle_exception(exc) |             return self.handle_exception(exc) | ||||||
|  |  | ||||||
| @ -316,6 +371,12 @@ class FlowExecutorView(APIView): | |||||||
|             NEXT_ARG_NAME, "authentik_core:root-redirect" |             NEXT_ARG_NAME, "authentik_core:root-redirect" | ||||||
|         ) |         ) | ||||||
|         self.cancel() |         self.cancel() | ||||||
|  |         Event.new( | ||||||
|  |             action=EventAction.FLOW_EXECUTION, | ||||||
|  |             flow=self.flow, | ||||||
|  |             designation=self.flow.designation, | ||||||
|  |             successful=True, | ||||||
|  |         ).from_http(self.request) | ||||||
|         return to_stage_response(self.request, redirect_with_qs(next_param)) |         return to_stage_response(self.request, redirect_with_qs(next_param)) | ||||||
|  |  | ||||||
|     def stage_ok(self) -> HttpResponse: |     def stage_ok(self) -> HttpResponse: | ||||||
|  | |||||||
| @ -87,9 +87,7 @@ class FlowInspectorView(APIView): | |||||||
|     @extend_schema( |     @extend_schema( | ||||||
|         responses={ |         responses={ | ||||||
|             200: FlowInspectionSerializer(), |             200: FlowInspectionSerializer(), | ||||||
|             400: OpenApiResponse( |             400: OpenApiResponse(description="No flow plan in session."), | ||||||
|                 description="No flow plan in session." |  | ||||||
|             ),  # This error can be raised by the email stage |  | ||||||
|         }, |         }, | ||||||
|         request=OpenApiTypes.NONE, |         request=OpenApiTypes.NONE, | ||||||
|         operation_id="flows_inspector_get", |         operation_id="flows_inspector_get", | ||||||
| @ -106,7 +104,10 @@ class FlowInspectorView(APIView): | |||||||
|         if SESSION_KEY_PLAN in request.session: |         if SESSION_KEY_PLAN in request.session: | ||||||
|             current_plan: FlowPlan = request.session[SESSION_KEY_PLAN] |             current_plan: FlowPlan = request.session[SESSION_KEY_PLAN] | ||||||
|         else: |         else: | ||||||
|             current_plan = request.session[SESSION_KEY_HISTORY][-1] |             try: | ||||||
|  |                 current_plan = request.session[SESSION_KEY_HISTORY][-1] | ||||||
|  |             except IndexError: | ||||||
|  |                 return Response(status=400) | ||||||
|             is_completed = True |             is_completed = True | ||||||
|         current_serializer = FlowInspectorPlanSerializer( |         current_serializer = FlowInspectorPlanSerializer( | ||||||
|             instance=current_plan, context={"request": request} |             instance=current_plan, context={"request": request} | ||||||
|  | |||||||
| @ -20,7 +20,6 @@ web: | |||||||
|   listen: 0.0.0.0:9000 |   listen: 0.0.0.0:9000 | ||||||
|   listen_tls: 0.0.0.0:9443 |   listen_tls: 0.0.0.0:9443 | ||||||
|   listen_metrics: 0.0.0.0:9300 |   listen_metrics: 0.0.0.0:9300 | ||||||
|   load_local_files: false |  | ||||||
|   outpost_port_offset: 0 |   outpost_port_offset: 0 | ||||||
|  |  | ||||||
| redis: | redis: | ||||||
|  | |||||||
| @ -80,8 +80,9 @@ class BaseEvaluator: | |||||||
|         """Parse and evaluate expression. If the syntax is incorrect, a SyntaxError is raised. |         """Parse and evaluate expression. If the syntax is incorrect, a SyntaxError is raised. | ||||||
|         If any exception is raised during execution, it is raised. |         If any exception is raised during execution, it is raised. | ||||||
|         The result is returned without any type-checking.""" |         The result is returned without any type-checking.""" | ||||||
|         with Hub.current.start_span(op="lib.evaluator.evaluate") as span: |         with Hub.current.start_span(op="authentik.lib.evaluator.evaluate") as span: | ||||||
|             span: Span |             span: Span | ||||||
|  |             span.description = self._filename | ||||||
|             span.set_data("expression", expression_source) |             span.set_data("expression", expression_source) | ||||||
|             param_keys = self._context.keys() |             param_keys = self._context.keys() | ||||||
|             try: |             try: | ||||||
|  | |||||||
| @ -4,6 +4,7 @@ from typing import Any, Optional | |||||||
|  |  | ||||||
| from django.http import HttpRequest | from django.http import HttpRequest | ||||||
| from requests.sessions import Session | from requests.sessions import Session | ||||||
|  | from sentry_sdk.hub import Hub | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik import ENV_GIT_HASH_KEY, __version__ | from authentik import ENV_GIT_HASH_KEY, __version__ | ||||||
| @ -52,6 +53,12 @@ def _get_outpost_override_ip(request: HttpRequest) -> Optional[str]: | |||||||
|             fake_ip=fake_ip, |             fake_ip=fake_ip, | ||||||
|         ) |         ) | ||||||
|         return None |         return None | ||||||
|  |     # Update sentry scope to include correct IP | ||||||
|  |     user = Hub.current.scope._user | ||||||
|  |     if not user: | ||||||
|  |         user = {} | ||||||
|  |     user["ip_address"] = fake_ip | ||||||
|  |     Hub.current.scope.set_user(user) | ||||||
|     return fake_ip |     return fake_ip | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -2,12 +2,18 @@ | |||||||
| from django.db import DatabaseError | from django.db import DatabaseError | ||||||
|  |  | ||||||
| from authentik.core.tasks import CELERY_APP | from authentik.core.tasks import CELERY_APP | ||||||
| from authentik.events.monitored_tasks import PrefilledMonitoredTask, TaskResult, TaskResultStatus | from authentik.events.monitored_tasks import ( | ||||||
|  |     MonitoredTask, | ||||||
|  |     TaskResult, | ||||||
|  |     TaskResultStatus, | ||||||
|  |     prefill_task, | ||||||
|  | ) | ||||||
| from authentik.managed.manager import ObjectManager | from authentik.managed.manager import ObjectManager | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(bind=True, base=PrefilledMonitoredTask) | @CELERY_APP.task(bind=True, base=MonitoredTask) | ||||||
| def managed_reconcile(self: PrefilledMonitoredTask): | @prefill_task | ||||||
|  | def managed_reconcile(self: MonitoredTask): | ||||||
|     """Run ObjectManager to ensure objects are up-to-date""" |     """Run ObjectManager to ensure objects are up-to-date""" | ||||||
|     try: |     try: | ||||||
|         ObjectManager().run() |         ObjectManager().run() | ||||||
|  | |||||||
| @ -1,6 +1,8 @@ | |||||||
| """Outpost API Views""" | """Outpost API Views""" | ||||||
| from dacite.core import from_dict | from dacite.core import from_dict | ||||||
| from dacite.exceptions import DaciteError | from dacite.exceptions import DaciteError | ||||||
|  | from django_filters.filters import ModelMultipleChoiceFilter | ||||||
|  | from django_filters.filterset import FilterSet | ||||||
| from drf_spectacular.utils import extend_schema | from drf_spectacular.utils import extend_schema | ||||||
| from rest_framework.decorators import action | from rest_framework.decorators import action | ||||||
| from rest_framework.fields import BooleanField, CharField, DateTimeField | from rest_framework.fields import BooleanField, CharField, DateTimeField | ||||||
| @ -99,16 +101,30 @@ class OutpostHealthSerializer(PassiveSerializer): | |||||||
|     version_outdated = BooleanField(read_only=True) |     version_outdated = BooleanField(read_only=True) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class OutpostFilter(FilterSet): | ||||||
|  |     """Filter for Outposts""" | ||||||
|  |  | ||||||
|  |     providers_by_pk = ModelMultipleChoiceFilter( | ||||||
|  |         field_name="providers", | ||||||
|  |         queryset=Provider.objects.all(), | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     class Meta: | ||||||
|  |  | ||||||
|  |         model = Outpost | ||||||
|  |         fields = { | ||||||
|  |             "providers": ["isnull"], | ||||||
|  |             "name": ["iexact", "icontains"], | ||||||
|  |             "service_connection__name": ["iexact", "icontains"], | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |  | ||||||
| class OutpostViewSet(UsedByMixin, ModelViewSet): | class OutpostViewSet(UsedByMixin, ModelViewSet): | ||||||
|     """Outpost Viewset""" |     """Outpost Viewset""" | ||||||
|  |  | ||||||
|     queryset = Outpost.objects.all() |     queryset = Outpost.objects.all() | ||||||
|     serializer_class = OutpostSerializer |     serializer_class = OutpostSerializer | ||||||
|     filterset_fields = { |     filterset_class = OutpostFilter | ||||||
|         "providers": ["isnull"], |  | ||||||
|         "name": ["iexact", "icontains"], |  | ||||||
|         "service_connection__name": ["iexact", "icontains"], |  | ||||||
|     } |  | ||||||
|     search_fields = [ |     search_fields = [ | ||||||
|         "name", |         "name", | ||||||
|         "providers__name", |         "providers__name", | ||||||
|  | |||||||
| @ -9,7 +9,7 @@ from dacite import from_dict | |||||||
| from dacite.data import Data | from dacite.data import Data | ||||||
| from guardian.shortcuts import get_objects_for_user | from guardian.shortcuts import get_objects_for_user | ||||||
| from prometheus_client import Gauge | from prometheus_client import Gauge | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import BoundLogger, get_logger | ||||||
|  |  | ||||||
| from authentik.core.channels import AuthJsonConsumer | from authentik.core.channels import AuthJsonConsumer | ||||||
| from authentik.outposts.models import OUTPOST_HELLO_INTERVAL, Outpost, OutpostState | from authentik.outposts.models import OUTPOST_HELLO_INTERVAL, Outpost, OutpostState | ||||||
| @ -23,8 +23,6 @@ GAUGE_OUTPOSTS_LAST_UPDATE = Gauge( | |||||||
|     ["outpost", "uid", "version"], |     ["outpost", "uid", "version"], | ||||||
| ) | ) | ||||||
|  |  | ||||||
| LOGGER = get_logger() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class WebsocketMessageInstruction(IntEnum): | class WebsocketMessageInstruction(IntEnum): | ||||||
|     """Commands which can be triggered over Websocket""" |     """Commands which can be triggered over Websocket""" | ||||||
| @ -51,6 +49,7 @@ class OutpostConsumer(AuthJsonConsumer): | |||||||
|     """Handler for Outposts that connect over websockets for health checks and live updates""" |     """Handler for Outposts that connect over websockets for health checks and live updates""" | ||||||
|  |  | ||||||
|     outpost: Optional[Outpost] = None |     outpost: Optional[Outpost] = None | ||||||
|  |     logger: BoundLogger | ||||||
|  |  | ||||||
|     last_uid: Optional[str] = None |     last_uid: Optional[str] = None | ||||||
|  |  | ||||||
| @ -59,11 +58,20 @@ class OutpostConsumer(AuthJsonConsumer): | |||||||
|     def connect(self): |     def connect(self): | ||||||
|         super().connect() |         super().connect() | ||||||
|         uuid = self.scope["url_route"]["kwargs"]["pk"] |         uuid = self.scope["url_route"]["kwargs"]["pk"] | ||||||
|         outpost = get_objects_for_user(self.user, "authentik_outposts.view_outpost").filter(pk=uuid) |         outpost = ( | ||||||
|         if not outpost.exists(): |             get_objects_for_user(self.user, "authentik_outposts.view_outpost") | ||||||
|  |             .filter(pk=uuid) | ||||||
|  |             .first() | ||||||
|  |         ) | ||||||
|  |         if not outpost: | ||||||
|             raise DenyConnection() |             raise DenyConnection() | ||||||
|         self.accept() |         self.logger = get_logger().bind(outpost=outpost) | ||||||
|         self.outpost = outpost.first() |         try: | ||||||
|  |             self.accept() | ||||||
|  |         except RuntimeError as exc: | ||||||
|  |             self.logger.warning("runtime error during accept", exc=exc) | ||||||
|  |             raise DenyConnection() | ||||||
|  |         self.outpost = outpost | ||||||
|         self.last_uid = self.channel_name |         self.last_uid = self.channel_name | ||||||
|  |  | ||||||
|     # pylint: disable=unused-argument |     # pylint: disable=unused-argument | ||||||
| @ -78,9 +86,8 @@ class OutpostConsumer(AuthJsonConsumer): | |||||||
|                 uid=self.last_uid, |                 uid=self.last_uid, | ||||||
|                 expected=self.outpost.config.kubernetes_replicas, |                 expected=self.outpost.config.kubernetes_replicas, | ||||||
|             ).dec() |             ).dec() | ||||||
|         LOGGER.debug( |         self.logger.debug( | ||||||
|             "removed outpost instance from cache", |             "removed outpost instance from cache", | ||||||
|             outpost=self.outpost, |  | ||||||
|             instance_uuid=self.last_uid, |             instance_uuid=self.last_uid, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
| @ -103,9 +110,8 @@ class OutpostConsumer(AuthJsonConsumer): | |||||||
|                 uid=self.last_uid, |                 uid=self.last_uid, | ||||||
|                 expected=self.outpost.config.kubernetes_replicas, |                 expected=self.outpost.config.kubernetes_replicas, | ||||||
|             ).inc() |             ).inc() | ||||||
|             LOGGER.debug( |             self.logger.debug( | ||||||
|                 "added outpost instance to cache", |                 "added outpost instance to cache", | ||||||
|                 outpost=self.outpost, |  | ||||||
|                 instance_uuid=self.last_uid, |                 instance_uuid=self.last_uid, | ||||||
|             ) |             ) | ||||||
|             self.first_msg = True |             self.first_msg = True | ||||||
|  | |||||||
| @ -24,6 +24,8 @@ class DockerController(BaseController): | |||||||
|  |  | ||||||
|     def __init__(self, outpost: Outpost, connection: DockerServiceConnection) -> None: |     def __init__(self, outpost: Outpost, connection: DockerServiceConnection) -> None: | ||||||
|         super().__init__(outpost, connection) |         super().__init__(outpost, connection) | ||||||
|  |         if outpost.managed == MANAGED_OUTPOST: | ||||||
|  |             return | ||||||
|         try: |         try: | ||||||
|             self.client = connection.client() |             self.client = connection.client() | ||||||
|         except ServiceConnectionInvalid as exc: |         except ServiceConnectionInvalid as exc: | ||||||
| @ -108,7 +110,7 @@ class DockerController(BaseController): | |||||||
|         image = self.get_container_image() |         image = self.get_container_image() | ||||||
|         try: |         try: | ||||||
|             self.client.images.pull(image) |             self.client.images.pull(image) | ||||||
|         except DockerException: |         except DockerException:  # pragma: no cover | ||||||
|             image = f"goauthentik.io/{self.outpost.type}:latest" |             image = f"goauthentik.io/{self.outpost.type}:latest" | ||||||
|             self.client.images.pull(image) |             self.client.images.pull(image) | ||||||
|         return image |         return image | ||||||
| @ -142,7 +144,7 @@ class DockerController(BaseController): | |||||||
|                 True, |                 True, | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|     def _migrate_container_name(self): |     def _migrate_container_name(self):  # pragma: no cover | ||||||
|         """Migrate 2021.9 to 2021.10+""" |         """Migrate 2021.9 to 2021.10+""" | ||||||
|         old_name = f"authentik-proxy-{self.outpost.uuid.hex}" |         old_name = f"authentik-proxy-{self.outpost.uuid.hex}" | ||||||
|         try: |         try: | ||||||
| @ -225,12 +227,14 @@ class DockerController(BaseController): | |||||||
|             raise ControllerException(str(exc)) from exc |             raise ControllerException(str(exc)) from exc | ||||||
|  |  | ||||||
|     def down(self): |     def down(self): | ||||||
|         if self.outpost.managed != MANAGED_OUTPOST: |         if self.outpost.managed == MANAGED_OUTPOST: | ||||||
|             return |             return | ||||||
|         try: |         try: | ||||||
|             container, _ = self._get_container() |             container, _ = self._get_container() | ||||||
|             if container.status == "running": |             if container.status == "running": | ||||||
|  |                 self.logger.info("Stopping container.") | ||||||
|                 container.kill() |                 container.kill() | ||||||
|  |             self.logger.info("Removing container.") | ||||||
|             container.remove(force=True) |             container.remove(force=True) | ||||||
|         except DockerException as exc: |         except DockerException as exc: | ||||||
|             raise ControllerException(str(exc)) from exc |             raise ControllerException(str(exc)) from exc | ||||||
|  | |||||||
| @ -20,6 +20,11 @@ if TYPE_CHECKING: | |||||||
| T = TypeVar("T", V1Pod, V1Deployment) | T = TypeVar("T", V1Pod, V1Deployment) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_version() -> str: | ||||||
|  |     """Wrapper for __version__ to make testing easier""" | ||||||
|  |     return __version__ | ||||||
|  |  | ||||||
|  |  | ||||||
| class KubernetesObjectReconciler(Generic[T]): | class KubernetesObjectReconciler(Generic[T]): | ||||||
|     """Base Kubernetes Reconciler, handles the basic logic.""" |     """Base Kubernetes Reconciler, handles the basic logic.""" | ||||||
|  |  | ||||||
| @ -146,13 +151,13 @@ class KubernetesObjectReconciler(Generic[T]): | |||||||
|         return V1ObjectMeta( |         return V1ObjectMeta( | ||||||
|             namespace=self.namespace, |             namespace=self.namespace, | ||||||
|             labels={ |             labels={ | ||||||
|                 "app.kubernetes.io/name": f"authentik-{self.controller.outpost.type.lower()}", |  | ||||||
|                 "app.kubernetes.io/instance": slugify(self.controller.outpost.name), |                 "app.kubernetes.io/instance": slugify(self.controller.outpost.name), | ||||||
|                 "app.kubernetes.io/version": __version__, |  | ||||||
|                 "app.kubernetes.io/managed-by": "goauthentik.io", |                 "app.kubernetes.io/managed-by": "goauthentik.io", | ||||||
|                 "goauthentik.io/outpost-uuid": self.controller.outpost.uuid.hex, |                 "app.kubernetes.io/name": f"authentik-{self.controller.outpost.type.lower()}", | ||||||
|                 "goauthentik.io/outpost-type": str(self.controller.outpost.type), |                 "app.kubernetes.io/version": get_version(), | ||||||
|                 "goauthentik.io/outpost-name": slugify(self.controller.outpost.name), |                 "goauthentik.io/outpost-name": slugify(self.controller.outpost.name), | ||||||
|  |                 "goauthentik.io/outpost-type": str(self.controller.outpost.type), | ||||||
|  |                 "goauthentik.io/outpost-uuid": self.controller.outpost.uuid.hex, | ||||||
|             }, |             }, | ||||||
|             **kwargs, |             **kwargs, | ||||||
|         ) |         ) | ||||||
|  | |||||||
| @ -401,6 +401,7 @@ class Outpost(ManagedModel): | |||||||
|             user = users.first() |             user = users.first() | ||||||
|         user.attributes[USER_ATTRIBUTE_SA] = True |         user.attributes[USER_ATTRIBUTE_SA] = True | ||||||
|         user.attributes[USER_ATTRIBUTE_CAN_OVERRIDE_IP] = True |         user.attributes[USER_ATTRIBUTE_CAN_OVERRIDE_IP] = True | ||||||
|  |         user.name = f"Outpost {self.name} Service-Account" | ||||||
|         user.save() |         user.save() | ||||||
|         if should_create_user: |         if should_create_user: | ||||||
|             self.build_user_permissions(user) |             self.build_user_permissions(user) | ||||||
|  | |||||||
| @ -19,9 +19,9 @@ from structlog.stdlib import get_logger | |||||||
|  |  | ||||||
| from authentik.events.monitored_tasks import ( | from authentik.events.monitored_tasks import ( | ||||||
|     MonitoredTask, |     MonitoredTask, | ||||||
|     PrefilledMonitoredTask, |  | ||||||
|     TaskResult, |     TaskResult, | ||||||
|     TaskResultStatus, |     TaskResultStatus, | ||||||
|  |     prefill_task, | ||||||
| ) | ) | ||||||
| from authentik.lib.utils.reflection import path_to_class | from authentik.lib.utils.reflection import path_to_class | ||||||
| from authentik.outposts.controllers.base import BaseController, ControllerException | from authentik.outposts.controllers.base import BaseController, ControllerException | ||||||
| @ -75,8 +75,9 @@ def outpost_service_connection_state(connection_pk: Any): | |||||||
|     cache.set(connection.state_key, state, timeout=None) |     cache.set(connection.state_key, state, timeout=None) | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(bind=True, base=PrefilledMonitoredTask) | @CELERY_APP.task(bind=True, base=MonitoredTask) | ||||||
| def outpost_service_connection_monitor(self: PrefilledMonitoredTask): | @prefill_task | ||||||
|  | def outpost_service_connection_monitor(self: MonitoredTask): | ||||||
|     """Regularly check the state of Outpost Service Connections""" |     """Regularly check the state of Outpost Service Connections""" | ||||||
|     connections = OutpostServiceConnection.objects.all() |     connections = OutpostServiceConnection.objects.all() | ||||||
|     for connection in connections.iterator(): |     for connection in connections.iterator(): | ||||||
| @ -104,9 +105,12 @@ def outpost_controller( | |||||||
|     logs = [] |     logs = [] | ||||||
|     if from_cache: |     if from_cache: | ||||||
|         outpost: Outpost = cache.get(CACHE_KEY_OUTPOST_DOWN % outpost_pk) |         outpost: Outpost = cache.get(CACHE_KEY_OUTPOST_DOWN % outpost_pk) | ||||||
|  |         LOGGER.debug("Getting outpost from cache to delete") | ||||||
|     else: |     else: | ||||||
|         outpost: Outpost = Outpost.objects.filter(pk=outpost_pk).first() |         outpost: Outpost = Outpost.objects.filter(pk=outpost_pk).first() | ||||||
|  |         LOGGER.debug("Getting outpost from DB") | ||||||
|     if not outpost: |     if not outpost: | ||||||
|  |         LOGGER.warning("No outpost") | ||||||
|         return |         return | ||||||
|     self.set_uid(slugify(outpost.name)) |     self.set_uid(slugify(outpost.name)) | ||||||
|     try: |     try: | ||||||
| @ -124,8 +128,9 @@ def outpost_controller( | |||||||
|         self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, logs)) |         self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, logs)) | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(bind=True, base=PrefilledMonitoredTask) | @CELERY_APP.task(bind=True, base=MonitoredTask) | ||||||
| def outpost_token_ensurer(self: PrefilledMonitoredTask): | @prefill_task | ||||||
|  | def outpost_token_ensurer(self: MonitoredTask): | ||||||
|     """Periodically ensure that all Outposts have valid Service Accounts |     """Periodically ensure that all Outposts have valid Service Accounts | ||||||
|     and Tokens""" |     and Tokens""" | ||||||
|     all_outposts = Outpost.objects.all() |     all_outposts = Outpost.objects.all() | ||||||
|  | |||||||
							
								
								
									
										124
									
								
								authentik/outposts/tests/test_controller_docker.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										124
									
								
								authentik/outposts/tests/test_controller_docker.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,124 @@ | |||||||
|  | """Docker controller tests""" | ||||||
|  | from django.test import TestCase | ||||||
|  | from docker.models.containers import Container | ||||||
|  |  | ||||||
|  | from authentik.managed.manager import ObjectManager | ||||||
|  | from authentik.outposts.controllers.base import ControllerException | ||||||
|  | from authentik.outposts.controllers.docker import DockerController | ||||||
|  | from authentik.outposts.managed import MANAGED_OUTPOST | ||||||
|  | from authentik.outposts.models import DockerServiceConnection, Outpost, OutpostType | ||||||
|  | from authentik.providers.proxy.controllers.docker import ProxyDockerController | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class DockerControllerTests(TestCase): | ||||||
|  |     """Docker controller tests""" | ||||||
|  |  | ||||||
|  |     def setUp(self) -> None: | ||||||
|  |         self.outpost = Outpost.objects.create( | ||||||
|  |             name="test", | ||||||
|  |             type=OutpostType.PROXY, | ||||||
|  |         ) | ||||||
|  |         self.integration = DockerServiceConnection(name="test") | ||||||
|  |         ObjectManager().run() | ||||||
|  |  | ||||||
|  |     def test_init_managed(self): | ||||||
|  |         """Docker controller shouldn't do anything for managed outpost""" | ||||||
|  |         controller = DockerController( | ||||||
|  |             Outpost.objects.filter(managed=MANAGED_OUTPOST).first(), self.integration | ||||||
|  |         ) | ||||||
|  |         self.assertIsNone(controller.up()) | ||||||
|  |         self.assertIsNone(controller.down()) | ||||||
|  |  | ||||||
|  |     def test_init_invalid(self): | ||||||
|  |         """Ensure init fails with invalid client""" | ||||||
|  |         with self.assertRaises(ControllerException): | ||||||
|  |             DockerController(self.outpost, self.integration) | ||||||
|  |  | ||||||
|  |     def test_env_valid(self): | ||||||
|  |         """Test environment check""" | ||||||
|  |         controller = DockerController( | ||||||
|  |             Outpost.objects.filter(managed=MANAGED_OUTPOST).first(), self.integration | ||||||
|  |         ) | ||||||
|  |         env = [f"{key}={value}" for key, value in controller._get_env().items()] | ||||||
|  |         container = Container(attrs={"Config": {"Env": env}}) | ||||||
|  |         self.assertFalse(controller._comp_env(container)) | ||||||
|  |  | ||||||
|  |     def test_env_invalid(self): | ||||||
|  |         """Test environment check""" | ||||||
|  |         controller = DockerController( | ||||||
|  |             Outpost.objects.filter(managed=MANAGED_OUTPOST).first(), self.integration | ||||||
|  |         ) | ||||||
|  |         container = Container(attrs={"Config": {"Env": []}}) | ||||||
|  |         self.assertTrue(controller._comp_env(container)) | ||||||
|  |  | ||||||
|  |     def test_label_valid(self): | ||||||
|  |         """Test label check""" | ||||||
|  |         controller = DockerController( | ||||||
|  |             Outpost.objects.filter(managed=MANAGED_OUTPOST).first(), self.integration | ||||||
|  |         ) | ||||||
|  |         container = Container(attrs={"Config": {"Labels": controller._get_labels()}}) | ||||||
|  |         self.assertFalse(controller._comp_labels(container)) | ||||||
|  |  | ||||||
|  |     def test_label_invalid(self): | ||||||
|  |         """Test label check""" | ||||||
|  |         controller = DockerController( | ||||||
|  |             Outpost.objects.filter(managed=MANAGED_OUTPOST).first(), self.integration | ||||||
|  |         ) | ||||||
|  |         container = Container(attrs={"Config": {"Labels": {}}}) | ||||||
|  |         self.assertTrue(controller._comp_labels(container)) | ||||||
|  |         container = Container(attrs={"Config": {"Labels": {"io.goauthentik.outpost-uuid": "foo"}}}) | ||||||
|  |         self.assertTrue(controller._comp_labels(container)) | ||||||
|  |  | ||||||
|  |     def test_port_valid(self): | ||||||
|  |         """Test port check""" | ||||||
|  |         controller = ProxyDockerController( | ||||||
|  |             Outpost.objects.filter(managed=MANAGED_OUTPOST).first(), self.integration | ||||||
|  |         ) | ||||||
|  |         container = Container( | ||||||
|  |             attrs={ | ||||||
|  |                 "NetworkSettings": { | ||||||
|  |                     "Ports": { | ||||||
|  |                         "9000/tcp": [{"HostIp": "", "HostPort": "9000"}], | ||||||
|  |                         "9443/tcp": [{"HostIp": "", "HostPort": "9443"}], | ||||||
|  |                     } | ||||||
|  |                 }, | ||||||
|  |                 "State": "", | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  |         with self.settings(TEST=False): | ||||||
|  |             self.assertFalse(controller._comp_ports(container)) | ||||||
|  |             container.attrs["State"] = "running" | ||||||
|  |             self.assertFalse(controller._comp_ports(container)) | ||||||
|  |  | ||||||
|  |     def test_port_invalid(self): | ||||||
|  |         """Test port check""" | ||||||
|  |         controller = ProxyDockerController( | ||||||
|  |             Outpost.objects.filter(managed=MANAGED_OUTPOST).first(), self.integration | ||||||
|  |         ) | ||||||
|  |         container_no_ports = Container( | ||||||
|  |             attrs={"NetworkSettings": {"Ports": None}, "State": "running"} | ||||||
|  |         ) | ||||||
|  |         container_missing_port = Container( | ||||||
|  |             attrs={ | ||||||
|  |                 "NetworkSettings": { | ||||||
|  |                     "Ports": { | ||||||
|  |                         "9443/tcp": [{"HostIp": "", "HostPort": "9443"}], | ||||||
|  |                     } | ||||||
|  |                 }, | ||||||
|  |                 "State": "running", | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  |         container_mismatched_host = Container( | ||||||
|  |             attrs={ | ||||||
|  |                 "NetworkSettings": { | ||||||
|  |                     "Ports": { | ||||||
|  |                         "9443/tcp": [{"HostIp": "", "HostPort": "123"}], | ||||||
|  |                     } | ||||||
|  |                 }, | ||||||
|  |                 "State": "running", | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  |         with self.settings(TEST=False): | ||||||
|  |             self.assertFalse(controller._comp_ports(container_no_ports)) | ||||||
|  |             self.assertTrue(controller._comp_ports(container_missing_port)) | ||||||
|  |             self.assertTrue(controller._comp_ports(container_mismatched_host)) | ||||||
| @ -90,7 +90,8 @@ class PolicyEngine: | |||||||
|     def build(self) -> "PolicyEngine": |     def build(self) -> "PolicyEngine": | ||||||
|         """Build wrapper which monitors performance""" |         """Build wrapper which monitors performance""" | ||||||
|         with Hub.current.start_span( |         with Hub.current.start_span( | ||||||
|             op="policy.engine.build" |             op="authentik.policy.engine.build", | ||||||
|  |             description=self.__pbm, | ||||||
|         ) as span, HIST_POLICIES_BUILD_TIME.labels( |         ) as span, HIST_POLICIES_BUILD_TIME.labels( | ||||||
|             object_name=self.__pbm, |             object_name=self.__pbm, | ||||||
|             object_type=f"{self.__pbm._meta.app_label}.{self.__pbm._meta.model_name}", |             object_type=f"{self.__pbm._meta.app_label}.{self.__pbm._meta.model_name}", | ||||||
|  | |||||||
| @ -66,6 +66,7 @@ class Migration(migrations.Migration): | |||||||
|                             ("source_linked", "Source Linked"), |                             ("source_linked", "Source Linked"), | ||||||
|                             ("impersonation_started", "Impersonation Started"), |                             ("impersonation_started", "Impersonation Started"), | ||||||
|                             ("impersonation_ended", "Impersonation Ended"), |                             ("impersonation_ended", "Impersonation Ended"), | ||||||
|  |                             ("flow_execution", "Flow Execution"), | ||||||
|                             ("policy_execution", "Policy Execution"), |                             ("policy_execution", "Policy Execution"), | ||||||
|                             ("policy_exception", "Policy Exception"), |                             ("policy_exception", "Policy Exception"), | ||||||
|                             ("property_mapping_exception", "Property Mapping Exception"), |                             ("property_mapping_exception", "Property Mapping Exception"), | ||||||
|  | |||||||
| @ -11,6 +11,8 @@ from authentik.flows.planner import PLAN_CONTEXT_SSO | |||||||
| from authentik.lib.expression.evaluator import BaseEvaluator | from authentik.lib.expression.evaluator import BaseEvaluator | ||||||
| from authentik.lib.utils.http import get_client_ip | from authentik.lib.utils.http import get_client_ip | ||||||
| from authentik.policies.exceptions import PolicyException | from authentik.policies.exceptions import PolicyException | ||||||
|  | from authentik.policies.models import Policy, PolicyBinding | ||||||
|  | from authentik.policies.process import PolicyProcess | ||||||
| from authentik.policies.types import PolicyRequest, PolicyResult | from authentik.policies.types import PolicyRequest, PolicyResult | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| @ -31,6 +33,7 @@ class PolicyEvaluator(BaseEvaluator): | |||||||
|         self._context["ak_logger"] = get_logger(policy_name) |         self._context["ak_logger"] = get_logger(policy_name) | ||||||
|         self._context["ak_message"] = self.expr_func_message |         self._context["ak_message"] = self.expr_func_message | ||||||
|         self._context["ak_user_has_authenticator"] = self.expr_func_user_has_authenticator |         self._context["ak_user_has_authenticator"] = self.expr_func_user_has_authenticator | ||||||
|  |         self._context["ak_call_policy"] = self.expr_func_call_policy | ||||||
|         self._context["ip_address"] = ip_address |         self._context["ip_address"] = ip_address | ||||||
|         self._context["ip_network"] = ip_network |         self._context["ip_network"] = ip_network | ||||||
|         self._filename = policy_name or "PolicyEvaluator" |         self._filename = policy_name or "PolicyEvaluator" | ||||||
| @ -39,6 +42,16 @@ class PolicyEvaluator(BaseEvaluator): | |||||||
|         """Wrapper to append to messages list, which is returned with PolicyResult""" |         """Wrapper to append to messages list, which is returned with PolicyResult""" | ||||||
|         self._messages.append(message) |         self._messages.append(message) | ||||||
|  |  | ||||||
|  |     def expr_func_call_policy(self, name: str, **kwargs) -> PolicyResult: | ||||||
|  |         """Call policy by name, with current request""" | ||||||
|  |         policy = Policy.objects.filter(name=name).select_subclasses().first() | ||||||
|  |         if not policy: | ||||||
|  |             raise ValueError(f"Policy '{name}' not found.") | ||||||
|  |         req: PolicyRequest = self._context["request"] | ||||||
|  |         req.context.update(kwargs) | ||||||
|  |         proc = PolicyProcess(PolicyBinding(policy=policy), request=req, connection=None) | ||||||
|  |         return proc.profiling_wrapper() | ||||||
|  |  | ||||||
|     def expr_func_user_has_authenticator( |     def expr_func_user_has_authenticator( | ||||||
|         self, user: User, device_type: Optional[str] = None |         self, user: User, device_type: Optional[str] = None | ||||||
|     ) -> bool: |     ) -> bool: | ||||||
|  | |||||||
| @ -74,4 +74,4 @@ class TestExpressionPolicyAPI(APITestCase): | |||||||
|         expr = "return True" |         expr = "return True" | ||||||
|         self.assertEqual(ExpressionPolicySerializer().validate_expression(expr), expr) |         self.assertEqual(ExpressionPolicySerializer().validate_expression(expr), expr) | ||||||
|         with self.assertRaises(ValidationError): |         with self.assertRaises(ValidationError): | ||||||
|             print(ExpressionPolicySerializer().validate_expression("/")) |             ExpressionPolicySerializer().validate_expression("/") | ||||||
|  | |||||||
| @ -13,6 +13,7 @@ class PasswordPolicySerializer(PolicySerializer): | |||||||
|         model = PasswordPolicy |         model = PasswordPolicy | ||||||
|         fields = PolicySerializer.Meta.fields + [ |         fields = PolicySerializer.Meta.fields + [ | ||||||
|             "password_field", |             "password_field", | ||||||
|  |             "amount_digits", | ||||||
|             "amount_uppercase", |             "amount_uppercase", | ||||||
|             "amount_lowercase", |             "amount_lowercase", | ||||||
|             "amount_symbols", |             "amount_symbols", | ||||||
|  | |||||||
| @ -0,0 +1,38 @@ | |||||||
|  | # Generated by Django 4.0 on 2021-12-18 14:54 | ||||||
|  |  | ||||||
|  | from django.db import migrations, models | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Migration(migrations.Migration): | ||||||
|  |  | ||||||
|  |     dependencies = [ | ||||||
|  |         ("authentik_policies_password", "0002_passwordpolicy_password_field"), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     operations = [ | ||||||
|  |         migrations.AddField( | ||||||
|  |             model_name="passwordpolicy", | ||||||
|  |             name="amount_digits", | ||||||
|  |             field=models.PositiveIntegerField(default=0), | ||||||
|  |         ), | ||||||
|  |         migrations.AlterField( | ||||||
|  |             model_name="passwordpolicy", | ||||||
|  |             name="amount_lowercase", | ||||||
|  |             field=models.PositiveIntegerField(default=0), | ||||||
|  |         ), | ||||||
|  |         migrations.AlterField( | ||||||
|  |             model_name="passwordpolicy", | ||||||
|  |             name="amount_symbols", | ||||||
|  |             field=models.PositiveIntegerField(default=0), | ||||||
|  |         ), | ||||||
|  |         migrations.AlterField( | ||||||
|  |             model_name="passwordpolicy", | ||||||
|  |             name="amount_uppercase", | ||||||
|  |             field=models.PositiveIntegerField(default=0), | ||||||
|  |         ), | ||||||
|  |         migrations.AlterField( | ||||||
|  |             model_name="passwordpolicy", | ||||||
|  |             name="length_min", | ||||||
|  |             field=models.PositiveIntegerField(default=0), | ||||||
|  |         ), | ||||||
|  |     ] | ||||||
| @ -13,6 +13,7 @@ from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT | |||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| RE_LOWER = re.compile("[a-z]") | RE_LOWER = re.compile("[a-z]") | ||||||
| RE_UPPER = re.compile("[A-Z]") | RE_UPPER = re.compile("[A-Z]") | ||||||
|  | RE_DIGITS = re.compile("[0-9]") | ||||||
|  |  | ||||||
|  |  | ||||||
| class PasswordPolicy(Policy): | class PasswordPolicy(Policy): | ||||||
| @ -23,10 +24,11 @@ class PasswordPolicy(Policy): | |||||||
|         help_text=_("Field key to check, field keys defined in Prompt stages are available."), |         help_text=_("Field key to check, field keys defined in Prompt stages are available."), | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     amount_uppercase = models.IntegerField(default=0) |     amount_digits = models.PositiveIntegerField(default=0) | ||||||
|     amount_lowercase = models.IntegerField(default=0) |     amount_uppercase = models.PositiveIntegerField(default=0) | ||||||
|     amount_symbols = models.IntegerField(default=0) |     amount_lowercase = models.PositiveIntegerField(default=0) | ||||||
|     length_min = models.IntegerField(default=0) |     amount_symbols = models.PositiveIntegerField(default=0) | ||||||
|  |     length_min = models.PositiveIntegerField(default=0) | ||||||
|     symbol_charset = models.TextField(default=r"!\"#$%&'()*+,-./:;<=>?@[\]^_`{|}~ ") |     symbol_charset = models.TextField(default=r"!\"#$%&'()*+,-./:;<=>?@[\]^_`{|}~ ") | ||||||
|     error_message = models.TextField() |     error_message = models.TextField() | ||||||
|  |  | ||||||
| @ -40,6 +42,7 @@ class PasswordPolicy(Policy): | |||||||
|     def component(self) -> str: |     def component(self) -> str: | ||||||
|         return "ak-policy-password-form" |         return "ak-policy-password-form" | ||||||
|  |  | ||||||
|  |     # pylint: disable=too-many-return-statements | ||||||
|     def passes(self, request: PolicyRequest) -> PolicyResult: |     def passes(self, request: PolicyRequest) -> PolicyResult: | ||||||
|         if ( |         if ( | ||||||
|             self.password_field not in request.context |             self.password_field not in request.context | ||||||
| @ -62,6 +65,9 @@ class PasswordPolicy(Policy): | |||||||
|             LOGGER.debug("password failed", reason="length") |             LOGGER.debug("password failed", reason="length") | ||||||
|             return PolicyResult(False, self.error_message) |             return PolicyResult(False, self.error_message) | ||||||
|  |  | ||||||
|  |         if self.amount_digits > 0 and len(RE_DIGITS.findall(password)) < self.amount_digits: | ||||||
|  |             LOGGER.debug("password failed", reason="amount_digits") | ||||||
|  |             return PolicyResult(False, self.error_message) | ||||||
|         if self.amount_lowercase > 0 and len(RE_LOWER.findall(password)) < self.amount_lowercase: |         if self.amount_lowercase > 0 and len(RE_LOWER.findall(password)) < self.amount_lowercase: | ||||||
|             LOGGER.debug("password failed", reason="amount_lowercase") |             LOGGER.debug("password failed", reason="amount_lowercase") | ||||||
|             return PolicyResult(False, self.error_message) |             return PolicyResult(False, self.error_message) | ||||||
|  | |||||||
| @ -13,6 +13,7 @@ class TestPasswordPolicy(TestCase): | |||||||
|     def setUp(self) -> None: |     def setUp(self) -> None: | ||||||
|         self.policy = PasswordPolicy.objects.create( |         self.policy = PasswordPolicy.objects.create( | ||||||
|             name="test_false", |             name="test_false", | ||||||
|  |             amount_digits=1, | ||||||
|             amount_uppercase=1, |             amount_uppercase=1, | ||||||
|             amount_lowercase=2, |             amount_lowercase=2, | ||||||
|             amount_symbols=3, |             amount_symbols=3, | ||||||
| @ -38,7 +39,7 @@ class TestPasswordPolicy(TestCase): | |||||||
|     def test_failed_lowercase(self): |     def test_failed_lowercase(self): | ||||||
|         """not enough lowercase""" |         """not enough lowercase""" | ||||||
|         request = PolicyRequest(get_anonymous_user()) |         request = PolicyRequest(get_anonymous_user()) | ||||||
|         request.context["password"] = "TTTTTTTTTTTTTTTTTTTTTTTe"  # nosec |         request.context["password"] = "1TTTTTTTTTTTTTTTTTTTTTTe"  # nosec | ||||||
|         result: PolicyResult = self.policy.passes(request) |         result: PolicyResult = self.policy.passes(request) | ||||||
|         self.assertFalse(result.passing) |         self.assertFalse(result.passing) | ||||||
|         self.assertEqual(result.messages, ("test message",)) |         self.assertEqual(result.messages, ("test message",)) | ||||||
| @ -46,15 +47,23 @@ class TestPasswordPolicy(TestCase): | |||||||
|     def test_failed_uppercase(self): |     def test_failed_uppercase(self): | ||||||
|         """not enough uppercase""" |         """not enough uppercase""" | ||||||
|         request = PolicyRequest(get_anonymous_user()) |         request = PolicyRequest(get_anonymous_user()) | ||||||
|         request.context["password"] = "tttttttttttttttttttttttE"  # nosec |         request.context["password"] = "1tttttttttttttttttttttE"  # nosec | ||||||
|         result: PolicyResult = self.policy.passes(request) |         result: PolicyResult = self.policy.passes(request) | ||||||
|         self.assertFalse(result.passing) |         self.assertFalse(result.passing) | ||||||
|         self.assertEqual(result.messages, ("test message",)) |         self.assertEqual(result.messages, ("test message",)) | ||||||
|  |  | ||||||
|     def test_failed_symbols(self): |     def test_failed_symbols(self): | ||||||
|         """not enough uppercase""" |         """not enough symbols""" | ||||||
|         request = PolicyRequest(get_anonymous_user()) |         request = PolicyRequest(get_anonymous_user()) | ||||||
|         request.context["password"] = "TETETETETETETETETETETETETe!!!"  # nosec |         request.context["password"] = "1ETETETETETETETETETETETETe!!!"  # nosec | ||||||
|  |         result: PolicyResult = self.policy.passes(request) | ||||||
|  |         self.assertFalse(result.passing) | ||||||
|  |         self.assertEqual(result.messages, ("test message",)) | ||||||
|  |  | ||||||
|  |     def test_failed_digits(self): | ||||||
|  |         """not enough digits""" | ||||||
|  |         request = PolicyRequest(get_anonymous_user()) | ||||||
|  |         request.context["password"] = "TETETETETETETETETETETE1e!!!"  # nosec | ||||||
|         result: PolicyResult = self.policy.passes(request) |         result: PolicyResult = self.policy.passes(request) | ||||||
|         self.assertFalse(result.passing) |         self.assertFalse(result.passing) | ||||||
|         self.assertEqual(result.messages, ("test message",)) |         self.assertEqual(result.messages, ("test message",)) | ||||||
| @ -62,7 +71,7 @@ class TestPasswordPolicy(TestCase): | |||||||
|     def test_true(self): |     def test_true(self): | ||||||
|         """Positive password case""" |         """Positive password case""" | ||||||
|         request = PolicyRequest(get_anonymous_user()) |         request = PolicyRequest(get_anonymous_user()) | ||||||
|         request.context["password"] = generate_key() + "ee!!!"  # nosec |         request.context["password"] = generate_key() + "1ee!!!"  # nosec | ||||||
|         result: PolicyResult = self.policy.passes(request) |         result: PolicyResult = self.policy.passes(request) | ||||||
|         self.assertTrue(result.passing) |         self.assertTrue(result.passing) | ||||||
|         self.assertEqual(result.messages, tuple()) |         self.assertEqual(result.messages, tuple()) | ||||||
|  | |||||||
| @ -127,10 +127,10 @@ class PolicyProcess(PROCESS_CLASS): | |||||||
|         ) |         ) | ||||||
|         return policy_result |         return policy_result | ||||||
|  |  | ||||||
|     def run(self):  # pragma: no cover |     def profiling_wrapper(self): | ||||||
|         """Task wrapper to run policy checking""" |         """Run with profiling enabled""" | ||||||
|         with Hub.current.start_span( |         with Hub.current.start_span( | ||||||
|             op="policy.process.execute", |             op="authentik.policy.process.execute", | ||||||
|         ) as span, HIST_POLICIES_EXECUTION_TIME.labels( |         ) as span, HIST_POLICIES_EXECUTION_TIME.labels( | ||||||
|             binding_order=self.binding.order, |             binding_order=self.binding.order, | ||||||
|             binding_target_type=self.binding.target_type, |             binding_target_type=self.binding.target_type, | ||||||
| @ -142,8 +142,12 @@ class PolicyProcess(PROCESS_CLASS): | |||||||
|             span: Span |             span: Span | ||||||
|             span.set_data("policy", self.binding.policy) |             span.set_data("policy", self.binding.policy) | ||||||
|             span.set_data("request", self.request) |             span.set_data("request", self.request) | ||||||
|             try: |             return self.execute() | ||||||
|                 self.connection.send(self.execute()) |  | ||||||
|             except Exception as exc:  # pylint: disable=broad-except |     def run(self):  # pragma: no cover | ||||||
|                 LOGGER.warning(str(exc)) |         """Task wrapper to run policy checking""" | ||||||
|                 self.connection.send(PolicyResult(False, str(exc))) |         try: | ||||||
|  |             self.connection.send(self.profiling_wrapper()) | ||||||
|  |         except Exception as exc:  # pylint: disable=broad-except | ||||||
|  |             LOGGER.warning(str(exc)) | ||||||
|  |             self.connection.send(PolicyResult(False, str(exc))) | ||||||
|  | |||||||
| @ -2,7 +2,12 @@ | |||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.events.monitored_tasks import PrefilledMonitoredTask, TaskResult, TaskResultStatus | from authentik.events.monitored_tasks import ( | ||||||
|  |     MonitoredTask, | ||||||
|  |     TaskResult, | ||||||
|  |     TaskResultStatus, | ||||||
|  |     prefill_task, | ||||||
|  | ) | ||||||
| from authentik.policies.reputation.models import IPReputation, UserReputation | from authentik.policies.reputation.models import IPReputation, UserReputation | ||||||
| from authentik.policies.reputation.signals import CACHE_KEY_IP_PREFIX, CACHE_KEY_USER_PREFIX | from authentik.policies.reputation.signals import CACHE_KEY_IP_PREFIX, CACHE_KEY_USER_PREFIX | ||||||
| from authentik.root.celery import CELERY_APP | from authentik.root.celery import CELERY_APP | ||||||
| @ -10,8 +15,9 @@ from authentik.root.celery import CELERY_APP | |||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(bind=True, base=PrefilledMonitoredTask) | @CELERY_APP.task(bind=True, base=MonitoredTask) | ||||||
| def save_ip_reputation(self: PrefilledMonitoredTask): | @prefill_task | ||||||
|  | def save_ip_reputation(self: MonitoredTask): | ||||||
|     """Save currently cached reputation to database""" |     """Save currently cached reputation to database""" | ||||||
|     objects_to_update = [] |     objects_to_update = [] | ||||||
|     for key, score in cache.get_many(cache.keys(CACHE_KEY_IP_PREFIX + "*")).items(): |     for key, score in cache.get_many(cache.keys(CACHE_KEY_IP_PREFIX + "*")).items(): | ||||||
| @ -23,8 +29,9 @@ def save_ip_reputation(self: PrefilledMonitoredTask): | |||||||
|     self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, ["Successfully updated IP Reputation"])) |     self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, ["Successfully updated IP Reputation"])) | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(bind=True, base=PrefilledMonitoredTask) | @CELERY_APP.task(bind=True, base=MonitoredTask) | ||||||
| def save_user_reputation(self: PrefilledMonitoredTask): | @prefill_task | ||||||
|  | def save_user_reputation(self: MonitoredTask): | ||||||
|     """Save currently cached reputation to database""" |     """Save currently cached reputation to database""" | ||||||
|     objects_to_update = [] |     objects_to_update = [] | ||||||
|     for key, score in cache.get_many(cache.keys(CACHE_KEY_USER_PREFIX + "*")).items(): |     for key, score in cache.get_many(cache.keys(CACHE_KEY_USER_PREFIX + "*")).items(): | ||||||
|  | |||||||
| @ -23,6 +23,6 @@ def invalidate_policy_cache(sender, instance, **_): | |||||||
|             total += len(keys) |             total += len(keys) | ||||||
|             cache.delete_many(keys) |             cache.delete_many(keys) | ||||||
|         LOGGER.debug("Invalidating policy cache", policy=instance, keys=total) |         LOGGER.debug("Invalidating policy cache", policy=instance, keys=total) | ||||||
|     # Also delete user application cache |         # Also delete user application cache | ||||||
|     keys = cache.keys(user_app_cache_key("*")) or [] |         keys = cache.keys(user_app_cache_key("*")) or [] | ||||||
|     cache.delete_many(keys) |         cache.delete_many(keys) | ||||||
|  | |||||||
| @ -8,7 +8,6 @@ from datetime import datetime | |||||||
| from hashlib import sha256 | from hashlib import sha256 | ||||||
| from typing import Any, Optional, Type | from typing import Any, Optional, Type | ||||||
| from urllib.parse import urlparse | from urllib.parse import urlparse | ||||||
| from uuid import uuid4 |  | ||||||
|  |  | ||||||
| from dacite import from_dict | from dacite import from_dict | ||||||
| from django.db import models | from django.db import models | ||||||
| @ -225,7 +224,7 @@ class OAuth2Provider(Provider): | |||||||
|         token = RefreshToken( |         token = RefreshToken( | ||||||
|             user=user, |             user=user, | ||||||
|             provider=self, |             provider=self, | ||||||
|             refresh_token=uuid4().hex, |             refresh_token=generate_key(), | ||||||
|             expires=timezone.now() + timedelta_from_string(self.token_validity), |             expires=timezone.now() + timedelta_from_string(self.token_validity), | ||||||
|             scope=scope, |             scope=scope, | ||||||
|         ) |         ) | ||||||
| @ -434,7 +433,7 @@ class RefreshToken(ExpiringModel, BaseGrantModel): | |||||||
|         """Create access token with a similar format as Okta, Keycloak, ADFS""" |         """Create access token with a similar format as Okta, Keycloak, ADFS""" | ||||||
|         token = self.create_id_token(user, request).to_dict() |         token = self.create_id_token(user, request).to_dict() | ||||||
|         token["cid"] = self.provider.client_id |         token["cid"] = self.provider.client_id | ||||||
|         token["uid"] = uuid4().hex |         token["uid"] = generate_key() | ||||||
|         return self.provider.encode(token) |         return self.provider.encode(token) | ||||||
|  |  | ||||||
|     def create_id_token(self, user: User, request: HttpRequest) -> IDToken: |     def create_id_token(self, user: User, request: HttpRequest) -> IDToken: | ||||||
|  | |||||||
| @ -95,9 +95,15 @@ class TokenParams: | |||||||
|                 self.refresh_token = RefreshToken.objects.get( |                 self.refresh_token = RefreshToken.objects.get( | ||||||
|                     refresh_token=raw_token, provider=self.provider |                     refresh_token=raw_token, provider=self.provider | ||||||
|                 ) |                 ) | ||||||
|  |                 if self.refresh_token.is_expired: | ||||||
|  |                     LOGGER.warning( | ||||||
|  |                         "Refresh token is expired", | ||||||
|  |                         token=raw_token, | ||||||
|  |                     ) | ||||||
|  |                     raise TokenError("invalid_grant") | ||||||
|                 # https://tools.ietf.org/html/rfc6749#section-6 |                 # https://tools.ietf.org/html/rfc6749#section-6 | ||||||
|                 # Fallback to original token's scopes when none are given |                 # Fallback to original token's scopes when none are given | ||||||
|                 if self.scope == []: |                 if not self.scope: | ||||||
|                     self.scope = self.refresh_token.scope |                     self.scope = self.refresh_token.scope | ||||||
|             except RefreshToken.DoesNotExist: |             except RefreshToken.DoesNotExist: | ||||||
|                 LOGGER.warning( |                 LOGGER.warning( | ||||||
| @ -138,6 +144,12 @@ class TokenParams: | |||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             self.authorization_code = AuthorizationCode.objects.get(code=raw_code) |             self.authorization_code = AuthorizationCode.objects.get(code=raw_code) | ||||||
|  |             if self.authorization_code.is_expired: | ||||||
|  |                 LOGGER.warning( | ||||||
|  |                     "Code is expired", | ||||||
|  |                     token=raw_code, | ||||||
|  |                 ) | ||||||
|  |                 raise TokenError("invalid_grant") | ||||||
|         except AuthorizationCode.DoesNotExist: |         except AuthorizationCode.DoesNotExist: | ||||||
|             LOGGER.warning("Code does not exist", code=raw_code) |             LOGGER.warning("Code does not exist", code=raw_code) | ||||||
|             raise TokenError("invalid_grant") |             raise TokenError("invalid_grant") | ||||||
| @ -194,8 +206,10 @@ class TokenView(View): | |||||||
|             self.params = TokenParams.parse(request, self.provider, client_id, client_secret) |             self.params = TokenParams.parse(request, self.provider, client_id, client_secret) | ||||||
|  |  | ||||||
|             if self.params.grant_type == GRANT_TYPE_AUTHORIZATION_CODE: |             if self.params.grant_type == GRANT_TYPE_AUTHORIZATION_CODE: | ||||||
|  |                 LOGGER.info("Converting authorization code to refresh token") | ||||||
|                 return TokenResponse(self.create_code_response()) |                 return TokenResponse(self.create_code_response()) | ||||||
|             if self.params.grant_type == GRANT_TYPE_REFRESH_TOKEN: |             if self.params.grant_type == GRANT_TYPE_REFRESH_TOKEN: | ||||||
|  |                 LOGGER.info("Refreshing refresh token") | ||||||
|                 return TokenResponse(self.create_refresh_response()) |                 return TokenResponse(self.create_refresh_response()) | ||||||
|             raise ValueError(f"Invalid grant_type: {self.params.grant_type}") |             raise ValueError(f"Invalid grant_type: {self.params.grant_type}") | ||||||
|         except TokenError as error: |         except TokenError as error: | ||||||
|  | |||||||
| @ -36,6 +36,7 @@ from authentik.flows.models import Flow, FlowDesignation | |||||||
| from authentik.providers.saml.models import SAMLPropertyMapping, SAMLProvider | from authentik.providers.saml.models import SAMLPropertyMapping, SAMLProvider | ||||||
| from authentik.providers.saml.processors.metadata import MetadataProcessor | from authentik.providers.saml.processors.metadata import MetadataProcessor | ||||||
| from authentik.providers.saml.processors.metadata_parser import ServiceProviderMetadataParser | from authentik.providers.saml.processors.metadata_parser import ServiceProviderMetadataParser | ||||||
|  | from authentik.sources.saml.processors.constants import SAML_BINDING_POST, SAML_BINDING_REDIRECT | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
| @ -109,7 +110,17 @@ class SAMLProviderViewSet(UsedByMixin, ModelViewSet): | |||||||
|                 name="download", |                 name="download", | ||||||
|                 location=OpenApiParameter.QUERY, |                 location=OpenApiParameter.QUERY, | ||||||
|                 type=OpenApiTypes.BOOL, |                 type=OpenApiTypes.BOOL, | ||||||
|             ) |             ), | ||||||
|  |             OpenApiParameter( | ||||||
|  |                 name="force_binding", | ||||||
|  |                 location=OpenApiParameter.QUERY, | ||||||
|  |                 type=OpenApiTypes.STR, | ||||||
|  |                 enum=[ | ||||||
|  |                     SAML_BINDING_REDIRECT, | ||||||
|  |                     SAML_BINDING_POST, | ||||||
|  |                 ], | ||||||
|  |                 description=("Optionally force the metadata to only include one binding."), | ||||||
|  |             ), | ||||||
|         ], |         ], | ||||||
|     ) |     ) | ||||||
|     @action(methods=["GET"], detail=True, permission_classes=[AllowAny]) |     @action(methods=["GET"], detail=True, permission_classes=[AllowAny]) | ||||||
| @ -122,8 +133,10 @@ class SAMLProviderViewSet(UsedByMixin, ModelViewSet): | |||||||
|         except ValueError: |         except ValueError: | ||||||
|             raise Http404 |             raise Http404 | ||||||
|         try: |         try: | ||||||
|             metadata = MetadataProcessor(provider, request).build_entity_descriptor() |             proc = MetadataProcessor(provider, request) | ||||||
|             if "download" in request._request.GET: |             proc.force_binding = request.query_params.get("force_binding", None) | ||||||
|  |             metadata = proc.build_entity_descriptor() | ||||||
|  |             if "download" in request.query_params: | ||||||
|                 response = HttpResponse(metadata, content_type="application/xml") |                 response = HttpResponse(metadata, content_type="application/xml") | ||||||
|                 response[ |                 response[ | ||||||
|                     "Content-Disposition" |                     "Content-Disposition" | ||||||
|  | |||||||
| @ -21,7 +21,7 @@ class Migration(migrations.Migration): | |||||||
|             name="audience", |             name="audience", | ||||||
|             field=models.TextField( |             field=models.TextField( | ||||||
|                 default="", |                 default="", | ||||||
|                 help_text="Value of the audience restriction field of the asseration.", |                 help_text="Value of the audience restriction field of the assertion.", | ||||||
|             ), |             ), | ||||||
|         ), |         ), | ||||||
|         migrations.AlterField( |         migrations.AlterField( | ||||||
|  | |||||||
| @ -16,7 +16,7 @@ class Migration(migrations.Migration): | |||||||
|             field=models.TextField( |             field=models.TextField( | ||||||
|                 blank=True, |                 blank=True, | ||||||
|                 default="", |                 default="", | ||||||
|                 help_text="Value of the audience restriction field of the asseration. When left empty, no audience restriction will be added.", |                 help_text="Value of the audience restriction field of the assertion. When left empty, no audience restriction will be added.", | ||||||
|             ), |             ), | ||||||
|         ), |         ), | ||||||
|     ] |     ] | ||||||
|  | |||||||
| @ -41,7 +41,7 @@ class SAMLProvider(Provider): | |||||||
|         blank=True, |         blank=True, | ||||||
|         help_text=_( |         help_text=_( | ||||||
|             ( |             ( | ||||||
|                 "Value of the audience restriction field of the asseration. When left empty, " |                 "Value of the audience restriction field of the assertion. When left empty, " | ||||||
|                 "no audience restriction will be added." |                 "no audience restriction will be added." | ||||||
|             ) |             ) | ||||||
|         ), |         ), | ||||||
|  | |||||||
| @ -70,13 +70,14 @@ class AssertionProcessor: | |||||||
|         """Get AttributeStatement Element with Attributes from Property Mappings.""" |         """Get AttributeStatement Element with Attributes from Property Mappings.""" | ||||||
|         # https://commons.lbl.gov/display/IDMgmt/Attribute+Definitions |         # https://commons.lbl.gov/display/IDMgmt/Attribute+Definitions | ||||||
|         attribute_statement = Element(f"{{{NS_SAML_ASSERTION}}}AttributeStatement") |         attribute_statement = Element(f"{{{NS_SAML_ASSERTION}}}AttributeStatement") | ||||||
|  |         user = self.http_request.user | ||||||
|         for mapping in self.provider.property_mappings.all().select_subclasses(): |         for mapping in self.provider.property_mappings.all().select_subclasses(): | ||||||
|             if not isinstance(mapping, SAMLPropertyMapping): |             if not isinstance(mapping, SAMLPropertyMapping): | ||||||
|                 continue |                 continue | ||||||
|             try: |             try: | ||||||
|                 mapping: SAMLPropertyMapping |                 mapping: SAMLPropertyMapping | ||||||
|                 value = mapping.evaluate( |                 value = mapping.evaluate( | ||||||
|                     user=self.http_request.user, |                     user=user, | ||||||
|                     request=self.http_request, |                     request=self.http_request, | ||||||
|                     provider=self.provider, |                     provider=self.provider, | ||||||
|                 ) |                 ) | ||||||
| @ -101,7 +102,8 @@ class AssertionProcessor: | |||||||
|  |  | ||||||
|                 attribute_statement.append(attribute) |                 attribute_statement.append(attribute) | ||||||
|  |  | ||||||
|             except PropertyMappingExpressionException as exc: |             except (PropertyMappingExpressionException, ValueError) as exc: | ||||||
|  |                 # Value error can be raised when assigning invalid data to an attribute | ||||||
|                 Event.new( |                 Event.new( | ||||||
|                     EventAction.CONFIGURATION_ERROR, |                     EventAction.CONFIGURATION_ERROR, | ||||||
|                     message=f"Failed to evaluate property-mapping: {str(exc)}", |                     message=f"Failed to evaluate property-mapping: {str(exc)}", | ||||||
|  | |||||||
| @ -29,10 +29,12 @@ class MetadataProcessor: | |||||||
|  |  | ||||||
|     provider: SAMLProvider |     provider: SAMLProvider | ||||||
|     http_request: HttpRequest |     http_request: HttpRequest | ||||||
|  |     force_binding: Optional[str] | ||||||
|  |  | ||||||
|     def __init__(self, provider: SAMLProvider, request: HttpRequest): |     def __init__(self, provider: SAMLProvider, request: HttpRequest): | ||||||
|         self.provider = provider |         self.provider = provider | ||||||
|         self.http_request = request |         self.http_request = request | ||||||
|  |         self.force_binding = None | ||||||
|         self.xml_id = get_random_id() |         self.xml_id = get_random_id() | ||||||
|  |  | ||||||
|     def get_signing_key_descriptor(self) -> Optional[Element]: |     def get_signing_key_descriptor(self) -> Optional[Element]: | ||||||
| @ -79,6 +81,8 @@ class MetadataProcessor: | |||||||
|             ), |             ), | ||||||
|         } |         } | ||||||
|         for binding, url in binding_url_map.items(): |         for binding, url in binding_url_map.items(): | ||||||
|  |             if self.force_binding and self.force_binding != binding: | ||||||
|  |                 continue | ||||||
|             element = Element(f"{{{NS_SAML_METADATA}}}SingleSignOnService") |             element = Element(f"{{{NS_SAML_METADATA}}}SingleSignOnService") | ||||||
|             element.attrib["Binding"] = binding |             element.attrib["Binding"] = binding | ||||||
|             element.attrib["Location"] = url |             element.attrib["Location"] = url | ||||||
|  | |||||||
| @ -125,7 +125,7 @@ class SAMLSSOBindingPOSTView(SAMLSSOView): | |||||||
|         # This happens when using POST bindings but the user isn't logged in |         # This happens when using POST bindings but the user isn't logged in | ||||||
|         # (user gets redirected and POST body is 'lost') |         # (user gets redirected and POST body is 'lost') | ||||||
|         if SESSION_KEY_POST in self.request.session: |         if SESSION_KEY_POST in self.request.session: | ||||||
|             payload = self.request.session[SESSION_KEY_POST] |             payload = self.request.session.pop(SESSION_KEY_POST) | ||||||
|         if REQUEST_KEY_SAML_REQUEST not in payload: |         if REQUEST_KEY_SAML_REQUEST not in payload: | ||||||
|             LOGGER.info("check_saml_request: SAML payload missing") |             LOGGER.info("check_saml_request: SAML payload missing") | ||||||
|             return bad_request_message(self.request, "The SAML request payload is missing.") |             return bad_request_message(self.request, "The SAML request payload is missing.") | ||||||
|  | |||||||
| @ -14,6 +14,7 @@ from celery.signals import ( | |||||||
| from django.conf import settings | from django.conf import settings | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
|  | from authentik.core.middleware import LOCAL | ||||||
| from authentik.lib.sentry import before_send | from authentik.lib.sentry import before_send | ||||||
| from authentik.lib.utils.errors import exception_to_string | from authentik.lib.utils.errors import exception_to_string | ||||||
|  |  | ||||||
| @ -26,7 +27,7 @@ CELERY_APP = Celery("authentik") | |||||||
|  |  | ||||||
| # pylint: disable=unused-argument | # pylint: disable=unused-argument | ||||||
| @setup_logging.connect | @setup_logging.connect | ||||||
| def config_loggers(*args, **kwags): | def config_loggers(*args, **kwargs): | ||||||
|     """Apply logging settings from settings.py to celery""" |     """Apply logging settings from settings.py to celery""" | ||||||
|     dictConfig(settings.LOGGING) |     dictConfig(settings.LOGGING) | ||||||
|  |  | ||||||
| @ -36,21 +37,29 @@ def config_loggers(*args, **kwags): | |||||||
| def after_task_publish_hook(sender=None, headers=None, body=None, **kwargs): | def after_task_publish_hook(sender=None, headers=None, body=None, **kwargs): | ||||||
|     """Log task_id after it was published""" |     """Log task_id after it was published""" | ||||||
|     info = headers if "task" in headers else body |     info = headers if "task" in headers else body | ||||||
|     LOGGER.debug("Task published", task_id=info.get("id", ""), task_name=info.get("task", "")) |     LOGGER.info("Task published", task_id=info.get("id", ""), task_name=info.get("task", "")) | ||||||
|  |  | ||||||
|  |  | ||||||
| # pylint: disable=unused-argument | # pylint: disable=unused-argument | ||||||
| @task_prerun.connect | @task_prerun.connect | ||||||
| def task_prerun_hook(task_id, task, *args, **kwargs): | def task_prerun_hook(task_id: str, task, *args, **kwargs): | ||||||
|     """Log task_id on worker""" |     """Log task_id on worker""" | ||||||
|     LOGGER.debug("Task started", task_id=task_id, task_name=task.__name__) |     request_id = "task-" + task_id.replace("-", "") | ||||||
|  |     LOCAL.authentik_task = { | ||||||
|  |         "request_id": request_id, | ||||||
|  |     } | ||||||
|  |     LOGGER.info("Task started", task_id=task_id, task_name=task.__name__) | ||||||
|  |  | ||||||
|  |  | ||||||
| # pylint: disable=unused-argument | # pylint: disable=unused-argument | ||||||
| @task_postrun.connect | @task_postrun.connect | ||||||
| def task_postrun_hook(task_id, task, *args, retval=None, state=None, **kwargs): | def task_postrun_hook(task_id, task, *args, retval=None, state=None, **kwargs): | ||||||
|     """Log task_id on worker""" |     """Log task_id on worker""" | ||||||
|     LOGGER.debug("Task finished", task_id=task_id, task_name=task.__name__, state=state) |     LOGGER.info("Task finished", task_id=task_id, task_name=task.__name__, state=state) | ||||||
|  |     if not hasattr(LOCAL, "authentik_task"): | ||||||
|  |         return | ||||||
|  |     for key in list(LOCAL.authentik_task.keys()): | ||||||
|  |         del LOCAL.authentik_task[key] | ||||||
|  |  | ||||||
|  |  | ||||||
| # pylint: disable=unused-argument | # pylint: disable=unused-argument | ||||||
|  | |||||||
| @ -24,9 +24,11 @@ import structlog | |||||||
| from celery.schedules import crontab | from celery.schedules import crontab | ||||||
| from sentry_sdk import init as sentry_init | from sentry_sdk import init as sentry_init | ||||||
| from sentry_sdk.api import set_tag | from sentry_sdk.api import set_tag | ||||||
|  | from sentry_sdk.integrations.boto3 import Boto3Integration | ||||||
| from sentry_sdk.integrations.celery import CeleryIntegration | from sentry_sdk.integrations.celery import CeleryIntegration | ||||||
| from sentry_sdk.integrations.django import DjangoIntegration | from sentry_sdk.integrations.django import DjangoIntegration | ||||||
| from sentry_sdk.integrations.redis import RedisIntegration | from sentry_sdk.integrations.redis import RedisIntegration | ||||||
|  | from sentry_sdk.integrations.threading import ThreadingIntegration | ||||||
|  |  | ||||||
| from authentik import ENV_GIT_HASH_KEY, __version__ | from authentik import ENV_GIT_HASH_KEY, __version__ | ||||||
| from authentik.core.middleware import structlog_add_request_id | from authentik.core.middleware import structlog_add_request_id | ||||||
| @ -231,6 +233,7 @@ CACHES = { | |||||||
|         "OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"}, |         "OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"}, | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | DJANGO_REDIS_SCAN_ITERSIZE = 1000 | ||||||
| DJANGO_REDIS_IGNORE_EXCEPTIONS = True | DJANGO_REDIS_IGNORE_EXCEPTIONS = True | ||||||
| DJANGO_REDIS_LOG_IGNORED_EXCEPTIONS = True | DJANGO_REDIS_LOG_IGNORED_EXCEPTIONS = True | ||||||
| SESSION_ENGINE = "django.contrib.sessions.backends.cache" | SESSION_ENGINE = "django.contrib.sessions.backends.cache" | ||||||
| @ -342,8 +345,6 @@ TIME_ZONE = "UTC" | |||||||
|  |  | ||||||
| USE_I18N = True | USE_I18N = True | ||||||
|  |  | ||||||
| USE_L10N = True |  | ||||||
|  |  | ||||||
| USE_TZ = True | USE_TZ = True | ||||||
|  |  | ||||||
| LOCALE_PATHS = ["./locale"] | LOCALE_PATHS = ["./locale"] | ||||||
| @ -421,6 +422,8 @@ if _ERROR_REPORTING: | |||||||
|             DjangoIntegration(transaction_style="function_name"), |             DjangoIntegration(transaction_style="function_name"), | ||||||
|             CeleryIntegration(), |             CeleryIntegration(), | ||||||
|             RedisIntegration(), |             RedisIntegration(), | ||||||
|  |             Boto3Integration(), | ||||||
|  |             ThreadingIntegration(propagate_hub=True), | ||||||
|         ], |         ], | ||||||
|         before_send=before_send, |         before_send=before_send, | ||||||
|         release=f"authentik@{__version__}", |         release=f"authentik@{__version__}", | ||||||
|  | |||||||
| @ -1,4 +1,6 @@ | |||||||
| """Integrate ./manage.py test with pytest""" | """Integrate ./manage.py test with pytest""" | ||||||
|  | from argparse import ArgumentParser | ||||||
|  |  | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
|  |  | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
| @ -8,34 +10,43 @@ from tests.e2e.utils import get_docker_tag | |||||||
| class PytestTestRunner:  # pragma: no cover | class PytestTestRunner:  # pragma: no cover | ||||||
|     """Runs pytest to discover and run tests.""" |     """Runs pytest to discover and run tests.""" | ||||||
|  |  | ||||||
|     def __init__(self, verbosity=1, failfast=False, keepdb=False, **_): |     def __init__(self, verbosity=1, failfast=False, keepdb=False, **kwargs): | ||||||
|         self.verbosity = verbosity |         self.verbosity = verbosity | ||||||
|         self.failfast = failfast |         self.failfast = failfast | ||||||
|         self.keepdb = keepdb |         self.keepdb = keepdb | ||||||
|  |  | ||||||
|  |         self.args = ["-vv"] | ||||||
|  |         if self.failfast: | ||||||
|  |             self.args.append("--exitfirst") | ||||||
|  |         if self.keepdb: | ||||||
|  |             self.args.append("--reuse-db") | ||||||
|  |  | ||||||
|  |         if kwargs.get("randomly_seed", None): | ||||||
|  |             self.args.append(f"--randomly-seed={kwargs['randomly_seed']}") | ||||||
|  |  | ||||||
|         settings.TEST = True |         settings.TEST = True | ||||||
|         settings.CELERY_TASK_ALWAYS_EAGER = True |         settings.CELERY_TASK_ALWAYS_EAGER = True | ||||||
|         CONFIG.y_set("authentik.avatars", "none") |         CONFIG.y_set("authentik.avatars", "none") | ||||||
|         CONFIG.y_set("authentik.geoip", "tests/GeoLite2-City-Test.mmdb") |         CONFIG.y_set("authentik.geoip", "tests/GeoLite2-City-Test.mmdb") | ||||||
|         CONFIG.y_set( |         CONFIG.y_set( | ||||||
|             "outposts.container_image_base", |             "outposts.container_image_base", | ||||||
|             f"goauthentik.io/dev-%(type)s:{get_docker_tag()}", |             f"ghcr.io/goauthentik/dev-%(type)s:{get_docker_tag()}", | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def add_arguments(cls, parser: ArgumentParser): | ||||||
|  |         """Add more pytest-specific arguments""" | ||||||
|  |         parser.add_argument("--randomly-seed", type=int) | ||||||
|  |  | ||||||
|     def run_tests(self, test_labels): |     def run_tests(self, test_labels): | ||||||
|         """Run pytest and return the exitcode. |         """Run pytest and return the exitcode. | ||||||
|  |  | ||||||
|         It translates some of Django's test command option to pytest's. |         It translates some of Django's test command option to pytest's. | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
|         import pytest |         import pytest | ||||||
|  |  | ||||||
|         argv = ["-vv"] |  | ||||||
|         if self.failfast: |  | ||||||
|             argv.append("--exitfirst") |  | ||||||
|         if self.keepdb: |  | ||||||
|             argv.append("--reuse-db") |  | ||||||
|  |  | ||||||
|         if any("tests/e2e" in label for label in test_labels): |         if any("tests/e2e" in label for label in test_labels): | ||||||
|             argv.append("-pno:randomly") |             self.args.append("-pno:randomly") | ||||||
|  |         self.args.extend(test_labels) | ||||||
|         argv.extend(test_labels) |         return pytest.main(self.args) | ||||||
|         return pytest.main(argv) |  | ||||||
|  | |||||||
| @ -1,7 +1,6 @@ | |||||||
| """Source API Views""" | """Source API Views""" | ||||||
| from typing import Any | from typing import Any | ||||||
|  |  | ||||||
| from django.utils.text import slugify |  | ||||||
| from django_filters.filters import AllValuesMultipleFilter | from django_filters.filters import AllValuesMultipleFilter | ||||||
| from django_filters.filterset import FilterSet | from django_filters.filterset import FilterSet | ||||||
| from drf_spectacular.types import OpenApiTypes | from drf_spectacular.types import OpenApiTypes | ||||||
| @ -110,7 +109,8 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet): | |||||||
|             GroupLDAPSynchronizer, |             GroupLDAPSynchronizer, | ||||||
|             MembershipLDAPSynchronizer, |             MembershipLDAPSynchronizer, | ||||||
|         ]: |         ]: | ||||||
|             task = TaskInfo.by_name(f"ldap_sync_{slugify(source.name)}-{sync_class.__name__}") |             sync_name = sync_class.__name__.replace("LDAPSynchronizer", "").lower() | ||||||
|  |             task = TaskInfo.by_name(f"ldap_sync_{source.slug}_{sync_name}") | ||||||
|             if task: |             if task: | ||||||
|                 results.append(task) |                 results.append(task) | ||||||
|         return Response(TaskSerializer(results, many=True).data) |         return Response(TaskSerializer(results, many=True).data) | ||||||
|  | |||||||
| @ -29,7 +29,7 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer): | |||||||
|             group_dn = self._flatten(self._flatten(group.get("entryDN", group.get("dn")))) |             group_dn = self._flatten(self._flatten(group.get("entryDN", group.get("dn")))) | ||||||
|             if self._source.object_uniqueness_field not in attributes: |             if self._source.object_uniqueness_field not in attributes: | ||||||
|                 self.message( |                 self.message( | ||||||
|                     f"Cannot find uniqueness field in attributes: '{group_dn}", |                     f"Cannot find uniqueness field in attributes: '{group_dn}'", | ||||||
|                     attributes=attributes.keys(), |                     attributes=attributes.keys(), | ||||||
|                     dn=group_dn, |                     dn=group_dn, | ||||||
|                 ) |                 ) | ||||||
|  | |||||||
| @ -31,7 +31,7 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer): | |||||||
|             user_dn = self._flatten(user.get("entryDN", user.get("dn"))) |             user_dn = self._flatten(user.get("entryDN", user.get("dn"))) | ||||||
|             if self._source.object_uniqueness_field not in attributes: |             if self._source.object_uniqueness_field not in attributes: | ||||||
|                 self.message( |                 self.message( | ||||||
|                     f"Cannot find uniqueness field in attributes: '{user_dn}", |                     f"Cannot find uniqueness field in attributes: '{user_dn}'", | ||||||
|                     attributes=attributes.keys(), |                     attributes=attributes.keys(), | ||||||
|                     dn=user_dn, |                     dn=user_dn, | ||||||
|                 ) |                 ) | ||||||
|  | |||||||
| @ -1,5 +1,4 @@ | |||||||
| """LDAP Sync tasks""" | """LDAP Sync tasks""" | ||||||
| from django.utils.text import slugify |  | ||||||
| from ldap3.core.exceptions import LDAPException | from ldap3.core.exceptions import LDAPException | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| @ -39,7 +38,7 @@ def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str): | |||||||
|         # to set the state with |         # to set the state with | ||||||
|         return |         return | ||||||
|     sync = path_to_class(sync_class) |     sync = path_to_class(sync_class) | ||||||
|     self.set_uid(f"{slugify(source.name)}_{sync.__name__.replace('LDAPSynchronizer', '').lower()}") |     self.set_uid(f"{source.slug}_{sync.__name__.replace('LDAPSynchronizer', '').lower()}") | ||||||
|     try: |     try: | ||||||
|         sync_inst = sync(source) |         sync_inst = sync(source) | ||||||
|         count = sync_inst.sync() |         count = sync_inst.sync() | ||||||
|  | |||||||
| @ -1,10 +1,9 @@ | |||||||
| """OAuth Source Serializer""" | """OAuth Source Serializer""" | ||||||
| from django_filters.rest_framework import DjangoFilterBackend | from django_filters.rest_framework import DjangoFilterBackend | ||||||
| from rest_framework import mixins |  | ||||||
| from rest_framework.filters import OrderingFilter, SearchFilter | from rest_framework.filters import OrderingFilter, SearchFilter | ||||||
| from rest_framework.viewsets import GenericViewSet | from rest_framework.viewsets import ModelViewSet | ||||||
|  |  | ||||||
| from authentik.api.authorization import OwnerFilter, OwnerPermissions | from authentik.api.authorization import OwnerFilter, OwnerSuperuserPermissions | ||||||
| from authentik.core.api.sources import SourceSerializer | from authentik.core.api.sources import SourceSerializer | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.sources.oauth.models import UserOAuthSourceConnection | from authentik.sources.oauth.models import UserOAuthSourceConnection | ||||||
| @ -15,30 +14,18 @@ class UserOAuthSourceConnectionSerializer(SourceSerializer): | |||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         model = UserOAuthSourceConnection |         model = UserOAuthSourceConnection | ||||||
|         fields = [ |         fields = ["pk", "user", "source", "identifier", "access_token"] | ||||||
|             "pk", |  | ||||||
|             "user", |  | ||||||
|             "source", |  | ||||||
|             "identifier", |  | ||||||
|         ] |  | ||||||
|         extra_kwargs = { |         extra_kwargs = { | ||||||
|             "user": {"read_only": True}, |             "access_token": {"write_only": True}, | ||||||
|         } |         } | ||||||
|  |  | ||||||
|  |  | ||||||
| class UserOAuthSourceConnectionViewSet( | class UserOAuthSourceConnectionViewSet(UsedByMixin, ModelViewSet): | ||||||
|     mixins.RetrieveModelMixin, |  | ||||||
|     mixins.UpdateModelMixin, |  | ||||||
|     mixins.DestroyModelMixin, |  | ||||||
|     UsedByMixin, |  | ||||||
|     mixins.ListModelMixin, |  | ||||||
|     GenericViewSet, |  | ||||||
| ): |  | ||||||
|     """Source Viewset""" |     """Source Viewset""" | ||||||
|  |  | ||||||
|     queryset = UserOAuthSourceConnection.objects.all() |     queryset = UserOAuthSourceConnection.objects.all() | ||||||
|     serializer_class = UserOAuthSourceConnectionSerializer |     serializer_class = UserOAuthSourceConnectionSerializer | ||||||
|     filterset_fields = ["source__slug"] |     filterset_fields = ["source__slug"] | ||||||
|     permission_classes = [OwnerPermissions] |     permission_classes = [OwnerSuperuserPermissions] | ||||||
|     filter_backends = [OwnerFilter, DjangoFilterBackend, OrderingFilter, SearchFilter] |     filter_backends = [OwnerFilter, DjangoFilterBackend, OrderingFilter, SearchFilter] | ||||||
|     ordering = ["source__slug"] |     ordering = ["source__slug"] | ||||||
|  | |||||||
| @ -14,6 +14,7 @@ AUTHENTIK_SOURCES_OAUTH_TYPES = [ | |||||||
|     "authentik.sources.oauth.types.github", |     "authentik.sources.oauth.types.github", | ||||||
|     "authentik.sources.oauth.types.google", |     "authentik.sources.oauth.types.google", | ||||||
|     "authentik.sources.oauth.types.oidc", |     "authentik.sources.oauth.types.oidc", | ||||||
|  |     "authentik.sources.oauth.types.okta", | ||||||
|     "authentik.sources.oauth.types.reddit", |     "authentik.sources.oauth.types.reddit", | ||||||
|     "authentik.sources.oauth.types.twitter", |     "authentik.sources.oauth.types.twitter", | ||||||
| ] | ] | ||||||
|  | |||||||
| @ -2,13 +2,13 @@ | |||||||
| from typing import TYPE_CHECKING, Optional, Type | from typing import TYPE_CHECKING, Optional, Type | ||||||
|  |  | ||||||
| from django.db import models | from django.db import models | ||||||
|  | from django.http.request import HttpRequest | ||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
| from rest_framework.serializers import Serializer | from rest_framework.serializers import Serializer | ||||||
|  |  | ||||||
| from authentik.core.models import Source, UserSourceConnection | from authentik.core.models import Source, UserSourceConnection | ||||||
| from authentik.core.types import UILoginButton, UserSettingSerializer | from authentik.core.types import UILoginButton, UserSettingSerializer | ||||||
| from authentik.flows.challenge import ChallengeTypes, RedirectChallenge |  | ||||||
|  |  | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
|     from authentik.sources.oauth.types.manager import SourceType |     from authentik.sources.oauth.types.manager import SourceType | ||||||
| @ -64,24 +64,15 @@ class OAuthSource(Source): | |||||||
|  |  | ||||||
|         return OAuthSourceSerializer |         return OAuthSourceSerializer | ||||||
|  |  | ||||||
|     @property |     def ui_login_button(self, request: HttpRequest) -> UILoginButton: | ||||||
|     def ui_login_button(self) -> UILoginButton: |  | ||||||
|         provider_type = self.type |         provider_type = self.type | ||||||
|  |         provider = provider_type() | ||||||
|         return UILoginButton( |         return UILoginButton( | ||||||
|             challenge=RedirectChallenge( |  | ||||||
|                 instance={ |  | ||||||
|                     "type": ChallengeTypes.REDIRECT.value, |  | ||||||
|                     "to": reverse( |  | ||||||
|                         "authentik_sources_oauth:oauth-client-login", |  | ||||||
|                         kwargs={"source_slug": self.slug}, |  | ||||||
|                     ), |  | ||||||
|                 } |  | ||||||
|             ), |  | ||||||
|             icon_url=provider_type().icon_url(), |  | ||||||
|             name=self.name, |             name=self.name, | ||||||
|  |             icon_url=provider.icon_url(), | ||||||
|  |             challenge=provider.login_challenge(self, request), | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def ui_user_settings(self) -> Optional[UserSettingSerializer]: |     def ui_user_settings(self) -> Optional[UserSettingSerializer]: | ||||||
|         return UserSettingSerializer( |         return UserSettingSerializer( | ||||||
|             data={ |             data={ | ||||||
| @ -183,6 +174,16 @@ class AppleOAuthSource(OAuthSource): | |||||||
|         verbose_name_plural = _("Apple OAuth Sources") |         verbose_name_plural = _("Apple OAuth Sources") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class OktaOAuthSource(OAuthSource): | ||||||
|  |     """Login using a okta.com.""" | ||||||
|  |  | ||||||
|  |     class Meta: | ||||||
|  |  | ||||||
|  |         abstract = True | ||||||
|  |         verbose_name = _("Okta OAuth Source") | ||||||
|  |         verbose_name_plural = _("Okta OAuth Sources") | ||||||
|  |  | ||||||
|  |  | ||||||
| class UserOAuthSourceConnection(UserSourceConnection): | class UserOAuthSourceConnection(UserSourceConnection): | ||||||
|     """Authorized remote OAuth provider.""" |     """Authorized remote OAuth provider.""" | ||||||
|  |  | ||||||
|  | |||||||
| @ -2,10 +2,15 @@ | |||||||
| from time import time | from time import time | ||||||
| from typing import Any, Optional | from typing import Any, Optional | ||||||
|  |  | ||||||
|  | from django.http.request import HttpRequest | ||||||
|  | from django.urls.base import reverse | ||||||
| from jwt import decode, encode | from jwt import decode, encode | ||||||
|  | from rest_framework.fields import CharField | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
|  | from authentik.flows.challenge import Challenge, ChallengeResponse, ChallengeTypes | ||||||
| from authentik.sources.oauth.clients.oauth2 import OAuth2Client | from authentik.sources.oauth.clients.oauth2 import OAuth2Client | ||||||
|  | from authentik.sources.oauth.models import OAuthSource | ||||||
| from authentik.sources.oauth.types.manager import MANAGER, SourceType | from authentik.sources.oauth.types.manager import MANAGER, SourceType | ||||||
| from authentik.sources.oauth.views.callback import OAuthCallback | from authentik.sources.oauth.views.callback import OAuthCallback | ||||||
| from authentik.sources.oauth.views.redirect import OAuthRedirect | from authentik.sources.oauth.views.redirect import OAuthRedirect | ||||||
| @ -13,18 +18,34 @@ from authentik.sources.oauth.views.redirect import OAuthRedirect | |||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class AppleLoginChallenge(Challenge): | ||||||
|  |     """Special challenge for apple-native authentication flow, which happens on the client.""" | ||||||
|  |  | ||||||
|  |     client_id = CharField() | ||||||
|  |     component = CharField(default="ak-flow-sources-oauth-apple") | ||||||
|  |     scope = CharField() | ||||||
|  |     redirect_uri = CharField() | ||||||
|  |     state = CharField() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class AppleChallengeResponse(ChallengeResponse): | ||||||
|  |     """Pseudo class for plex response""" | ||||||
|  |  | ||||||
|  |     component = CharField(default="ak-flow-sources-oauth-apple") | ||||||
|  |  | ||||||
|  |  | ||||||
| class AppleOAuthClient(OAuth2Client): | class AppleOAuthClient(OAuth2Client): | ||||||
|     """Apple OAuth2 client""" |     """Apple OAuth2 client""" | ||||||
|  |  | ||||||
|     def get_client_id(self) -> str: |     def get_client_id(self) -> str: | ||||||
|         parts = self.source.consumer_key.split(";") |         parts: list[str] = self.source.consumer_key.split(";") | ||||||
|         if len(parts) < 3: |         if len(parts) < 3: | ||||||
|             return self.source.consumer_key |             return self.source.consumer_key | ||||||
|         return parts[0] |         return parts[0].strip() | ||||||
|  |  | ||||||
|     def get_client_secret(self) -> str: |     def get_client_secret(self) -> str: | ||||||
|         now = time() |         now = time() | ||||||
|         parts = self.source.consumer_key.split(";") |         parts: list[str] = self.source.consumer_key.split(";") | ||||||
|         if len(parts) < 3: |         if len(parts) < 3: | ||||||
|             raise ValueError( |             raise ValueError( | ||||||
|                 ( |                 ( | ||||||
| @ -34,14 +55,14 @@ class AppleOAuthClient(OAuth2Client): | |||||||
|             ) |             ) | ||||||
|         LOGGER.debug("got values from client_id", team=parts[1], kid=parts[2]) |         LOGGER.debug("got values from client_id", team=parts[1], kid=parts[2]) | ||||||
|         payload = { |         payload = { | ||||||
|             "iss": parts[1], |             "iss": parts[1].strip(), | ||||||
|             "iat": now, |             "iat": now, | ||||||
|             "exp": now + 86400 * 180, |             "exp": now + 86400 * 180, | ||||||
|             "aud": "https://appleid.apple.com", |             "aud": "https://appleid.apple.com", | ||||||
|             "sub": parts[0], |             "sub": parts[0].strip(), | ||||||
|         } |         } | ||||||
|         # pyright: reportGeneralTypeIssues=false |         # pyright: reportGeneralTypeIssues=false | ||||||
|         jwt = encode(payload, self.source.consumer_secret, "ES256", {"kid": parts[2]}) |         jwt = encode(payload, self.source.consumer_secret, "ES256", {"kid": parts[2].strip()}) | ||||||
|         LOGGER.debug("signing payload as secret key", payload=payload, jwt=jwt) |         LOGGER.debug("signing payload as secret key", payload=payload, jwt=jwt) | ||||||
|         return jwt |         return jwt | ||||||
|  |  | ||||||
| @ -55,7 +76,7 @@ class AppleOAuthRedirect(OAuthRedirect): | |||||||
|  |  | ||||||
|     client_class = AppleOAuthClient |     client_class = AppleOAuthClient | ||||||
|  |  | ||||||
|     def get_additional_parameters(self, source):  # pragma: no cover |     def get_additional_parameters(self, source: OAuthSource):  # pragma: no cover | ||||||
|         return { |         return { | ||||||
|             "scope": "name email", |             "scope": "name email", | ||||||
|             "response_mode": "form_post", |             "response_mode": "form_post", | ||||||
| @ -74,7 +95,6 @@ class AppleOAuth2Callback(OAuthCallback): | |||||||
|         self, |         self, | ||||||
|         info: dict[str, Any], |         info: dict[str, Any], | ||||||
|     ) -> dict[str, Any]: |     ) -> dict[str, Any]: | ||||||
|         print(info) |  | ||||||
|         return { |         return { | ||||||
|             "email": info.get("email"), |             "email": info.get("email"), | ||||||
|             "name": info.get("name"), |             "name": info.get("name"), | ||||||
| @ -96,3 +116,24 @@ class AppleType(SourceType): | |||||||
|  |  | ||||||
|     def icon_url(self) -> str: |     def icon_url(self) -> str: | ||||||
|         return "https://appleid.cdn-apple.com/appleid/button/logo" |         return "https://appleid.cdn-apple.com/appleid/button/logo" | ||||||
|  |  | ||||||
|  |     def login_challenge(self, source: OAuthSource, request: HttpRequest) -> Challenge: | ||||||
|  |         """Pre-general all the things required for the JS SDK""" | ||||||
|  |         apple_client = AppleOAuthClient( | ||||||
|  |             source, | ||||||
|  |             request, | ||||||
|  |             callback=reverse( | ||||||
|  |                 "authentik_sources_oauth:oauth-client-callback", | ||||||
|  |                 kwargs={"source_slug": source.slug}, | ||||||
|  |             ), | ||||||
|  |         ) | ||||||
|  |         args = apple_client.get_redirect_args() | ||||||
|  |         return AppleLoginChallenge( | ||||||
|  |             instance={ | ||||||
|  |                 "client_id": apple_client.get_client_id(), | ||||||
|  |                 "scope": "name email", | ||||||
|  |                 "redirect_uri": args["redirect_uri"], | ||||||
|  |                 "state": args["state"], | ||||||
|  |                 "type": ChallengeTypes.NATIVE.value, | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  | |||||||
| @ -2,9 +2,13 @@ | |||||||
| from enum import Enum | from enum import Enum | ||||||
| from typing import Callable, Optional, Type | from typing import Callable, Optional, Type | ||||||
|  |  | ||||||
|  | from django.http.request import HttpRequest | ||||||
| from django.templatetags.static import static | from django.templatetags.static import static | ||||||
|  | from django.urls.base import reverse | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
|  | from authentik.flows.challenge import Challenge, ChallengeTypes, RedirectChallenge | ||||||
|  | from authentik.sources.oauth.models import OAuthSource | ||||||
| from authentik.sources.oauth.views.callback import OAuthCallback | from authentik.sources.oauth.views.callback import OAuthCallback | ||||||
| from authentik.sources.oauth.views.redirect import OAuthRedirect | from authentik.sources.oauth.views.redirect import OAuthRedirect | ||||||
|  |  | ||||||
| @ -37,6 +41,19 @@ class SourceType: | |||||||
|         """Get Icon URL for login""" |         """Get Icon URL for login""" | ||||||
|         return static(f"authentik/sources/{self.slug}.svg") |         return static(f"authentik/sources/{self.slug}.svg") | ||||||
|  |  | ||||||
|  |     # pylint: disable=unused-argument | ||||||
|  |     def login_challenge(self, source: OAuthSource, request: HttpRequest) -> Challenge: | ||||||
|  |         """Allow types to return custom challenges""" | ||||||
|  |         return RedirectChallenge( | ||||||
|  |             instance={ | ||||||
|  |                 "type": ChallengeTypes.REDIRECT.value, | ||||||
|  |                 "to": reverse( | ||||||
|  |                     "authentik_sources_oauth:oauth-client-login", | ||||||
|  |                     kwargs={"source_slug": source.slug}, | ||||||
|  |                 ), | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class SourceTypeManager: | class SourceTypeManager: | ||||||
|     """Manager to hold all Source types.""" |     """Manager to hold all Source types.""" | ||||||
|  | |||||||
							
								
								
									
										51
									
								
								authentik/sources/oauth/types/okta.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								authentik/sources/oauth/types/okta.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,51 @@ | |||||||
|  | """Okta OAuth Views""" | ||||||
|  | from typing import Any | ||||||
|  |  | ||||||
|  | from authentik.sources.oauth.models import OAuthSource | ||||||
|  | from authentik.sources.oauth.types.azure_ad import AzureADClient | ||||||
|  | from authentik.sources.oauth.types.manager import MANAGER, SourceType | ||||||
|  | from authentik.sources.oauth.views.callback import OAuthCallback | ||||||
|  | from authentik.sources.oauth.views.redirect import OAuthRedirect | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class OktaOAuthRedirect(OAuthRedirect): | ||||||
|  |     """Okta OAuth2 Redirect""" | ||||||
|  |  | ||||||
|  |     def get_additional_parameters(self, source: OAuthSource):  # pragma: no cover | ||||||
|  |         return { | ||||||
|  |             "scope": "openid email profile", | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class OktaOAuth2Callback(OAuthCallback): | ||||||
|  |     """Okta OAuth2 Callback""" | ||||||
|  |  | ||||||
|  |     # Okta has the same quirk as azure and throws an error if the access token | ||||||
|  |     # is set via query parameter, so we re-use the azure client | ||||||
|  |     # see https://github.com/goauthentik/authentik/issues/1910 | ||||||
|  |     client_class = AzureADClient | ||||||
|  |  | ||||||
|  |     def get_user_id(self, info: dict[str, str]) -> str: | ||||||
|  |         return info.get("sub", "") | ||||||
|  |  | ||||||
|  |     def get_user_enroll_context( | ||||||
|  |         self, | ||||||
|  |         info: dict[str, Any], | ||||||
|  |     ) -> dict[str, Any]: | ||||||
|  |         return { | ||||||
|  |             "username": info.get("nickname"), | ||||||
|  |             "email": info.get("email"), | ||||||
|  |             "name": info.get("name"), | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @MANAGER.type() | ||||||
|  | class OktaType(SourceType): | ||||||
|  |     """Okta Type definition""" | ||||||
|  |  | ||||||
|  |     callback_view = OktaOAuth2Callback | ||||||
|  |     redirect_view = OktaOAuthRedirect | ||||||
|  |     name = "Okta" | ||||||
|  |     slug = "okta" | ||||||
|  |  | ||||||
|  |     urls_customizable = True | ||||||
| @ -1,10 +1,9 @@ | |||||||
| """Plex Source connection Serializer""" | """Plex Source connection Serializer""" | ||||||
| from django_filters.rest_framework import DjangoFilterBackend | from django_filters.rest_framework import DjangoFilterBackend | ||||||
| from rest_framework import mixins |  | ||||||
| from rest_framework.filters import OrderingFilter, SearchFilter | from rest_framework.filters import OrderingFilter, SearchFilter | ||||||
| from rest_framework.viewsets import GenericViewSet | from rest_framework.viewsets import ModelViewSet | ||||||
|  |  | ||||||
| from authentik.api.authorization import OwnerFilter, OwnerPermissions | from authentik.api.authorization import OwnerFilter, OwnerSuperuserPermissions | ||||||
| from authentik.core.api.sources import SourceSerializer | from authentik.core.api.sources import SourceSerializer | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.sources.plex.models import PlexSourceConnection | from authentik.sources.plex.models import PlexSourceConnection | ||||||
| @ -27,19 +26,12 @@ class PlexSourceConnectionSerializer(SourceSerializer): | |||||||
|         } |         } | ||||||
|  |  | ||||||
|  |  | ||||||
| class PlexSourceConnectionViewSet( | class PlexSourceConnectionViewSet(UsedByMixin, ModelViewSet): | ||||||
|     mixins.RetrieveModelMixin, |  | ||||||
|     mixins.UpdateModelMixin, |  | ||||||
|     mixins.DestroyModelMixin, |  | ||||||
|     UsedByMixin, |  | ||||||
|     mixins.ListModelMixin, |  | ||||||
|     GenericViewSet, |  | ||||||
| ): |  | ||||||
|     """Plex Source connection Serializer""" |     """Plex Source connection Serializer""" | ||||||
|  |  | ||||||
|     queryset = PlexSourceConnection.objects.all() |     queryset = PlexSourceConnection.objects.all() | ||||||
|     serializer_class = PlexSourceConnectionSerializer |     serializer_class = PlexSourceConnectionSerializer | ||||||
|     filterset_fields = ["source__slug"] |     filterset_fields = ["source__slug"] | ||||||
|     permission_classes = [OwnerPermissions] |     permission_classes = [OwnerSuperuserPermissions] | ||||||
|     filter_backends = [OwnerFilter, DjangoFilterBackend, OrderingFilter, SearchFilter] |     filter_backends = [OwnerFilter, DjangoFilterBackend, OrderingFilter, SearchFilter] | ||||||
|     ordering = ["pk"] |     ordering = ["pk"] | ||||||
|  | |||||||
| @ -3,6 +3,7 @@ from typing import Optional | |||||||
|  |  | ||||||
| from django.contrib.postgres.fields import ArrayField | from django.contrib.postgres.fields import ArrayField | ||||||
| from django.db import models | from django.db import models | ||||||
|  | from django.http.request import HttpRequest | ||||||
| from django.templatetags.static import static | from django.templatetags.static import static | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
| from rest_framework.fields import CharField | from rest_framework.fields import CharField | ||||||
| @ -62,8 +63,7 @@ class PlexSource(Source): | |||||||
|  |  | ||||||
|         return PlexSourceSerializer |         return PlexSourceSerializer | ||||||
|  |  | ||||||
|     @property |     def ui_login_button(self, request: HttpRequest) -> UILoginButton: | ||||||
|     def ui_login_button(self) -> UILoginButton: |  | ||||||
|         return UILoginButton( |         return UILoginButton( | ||||||
|             challenge=PlexAuthenticationChallenge( |             challenge=PlexAuthenticationChallenge( | ||||||
|                 { |                 { | ||||||
| @ -77,7 +77,6 @@ class PlexSource(Source): | |||||||
|             name=self.name, |             name=self.name, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def ui_user_settings(self) -> Optional[UserSettingSerializer]: |     def ui_user_settings(self) -> Optional[UserSettingSerializer]: | ||||||
|         return UserSettingSerializer( |         return UserSettingSerializer( | ||||||
|             data={ |             data={ | ||||||
|  | |||||||
| @ -167,8 +167,7 @@ class SAMLSource(Source): | |||||||
|             reverse(f"authentik_sources_saml:{view}", kwargs={"source_slug": self.slug}) |             reverse(f"authentik_sources_saml:{view}", kwargs={"source_slug": self.slug}) | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     @property |     def ui_login_button(self, request: HttpRequest) -> UILoginButton: | ||||||
|     def ui_login_button(self) -> UILoginButton: |  | ||||||
|         return UILoginButton( |         return UILoginButton( | ||||||
|             challenge=RedirectChallenge( |             challenge=RedirectChallenge( | ||||||
|                 instance={ |                 instance={ | ||||||
|  | |||||||
| @ -3,7 +3,12 @@ from django.utils.timezone import now | |||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.core.models import AuthenticatedSession, User | from authentik.core.models import AuthenticatedSession, User | ||||||
| from authentik.events.monitored_tasks import PrefilledMonitoredTask, TaskResult, TaskResultStatus | from authentik.events.monitored_tasks import ( | ||||||
|  |     MonitoredTask, | ||||||
|  |     TaskResult, | ||||||
|  |     TaskResultStatus, | ||||||
|  |     prefill_task, | ||||||
|  | ) | ||||||
| from authentik.lib.utils.time import timedelta_from_string | from authentik.lib.utils.time import timedelta_from_string | ||||||
| from authentik.root.celery import CELERY_APP | from authentik.root.celery import CELERY_APP | ||||||
| from authentik.sources.saml.models import SAMLSource | from authentik.sources.saml.models import SAMLSource | ||||||
| @ -11,8 +16,9 @@ from authentik.sources.saml.models import SAMLSource | |||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(bind=True, base=PrefilledMonitoredTask) | @CELERY_APP.task(bind=True, base=MonitoredTask) | ||||||
| def clean_temporary_users(self: PrefilledMonitoredTask): | @prefill_task | ||||||
|  | def clean_temporary_users(self: MonitoredTask): | ||||||
|     """Remove temporary users created by SAML Sources""" |     """Remove temporary users created by SAML Sources""" | ||||||
|     _now = now() |     _now = now() | ||||||
|     messages = [] |     messages = [] | ||||||
|  | |||||||
| @ -48,7 +48,6 @@ class AuthenticatorDuoStage(ConfigurableStage, Stage): | |||||||
|     def component(self) -> str: |     def component(self) -> str: | ||||||
|         return "ak-stage-authenticator-duo-form" |         return "ak-stage-authenticator-duo-form" | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def ui_user_settings(self) -> Optional[UserSettingSerializer]: |     def ui_user_settings(self) -> Optional[UserSettingSerializer]: | ||||||
|         return UserSettingSerializer( |         return UserSettingSerializer( | ||||||
|             data={ |             data={ | ||||||
|  | |||||||
| @ -141,7 +141,6 @@ class AuthenticatorSMSStage(ConfigurableStage, Stage): | |||||||
|     def component(self) -> str: |     def component(self) -> str: | ||||||
|         return "ak-stage-authenticator-sms-form" |         return "ak-stage-authenticator-sms-form" | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def ui_user_settings(self) -> Optional[UserSettingSerializer]: |     def ui_user_settings(self) -> Optional[UserSettingSerializer]: | ||||||
|         return UserSettingSerializer( |         return UserSettingSerializer( | ||||||
|             data={ |             data={ | ||||||
|  | |||||||
| @ -90,6 +90,5 @@ class AuthenticatorSMSStageTests(APITestCase): | |||||||
|                     "code": int(self.client.session[SESSION_SMS_DEVICE].token), |                     "code": int(self.client.session[SESSION_SMS_DEVICE].token), | ||||||
|                 }, |                 }, | ||||||
|             ) |             ) | ||||||
|             print(response.content) |  | ||||||
|             self.assertEqual(response.status_code, 200) |             self.assertEqual(response.status_code, 200) | ||||||
|             sms_send_mock.assert_not_called() |             sms_send_mock.assert_not_called() | ||||||
|  | |||||||
| @ -31,7 +31,6 @@ class AuthenticatorStaticStage(ConfigurableStage, Stage): | |||||||
|     def component(self) -> str: |     def component(self) -> str: | ||||||
|         return "ak-stage-authenticator-static-form" |         return "ak-stage-authenticator-static-form" | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def ui_user_settings(self) -> Optional[UserSettingSerializer]: |     def ui_user_settings(self) -> Optional[UserSettingSerializer]: | ||||||
|         return UserSettingSerializer( |         return UserSettingSerializer( | ||||||
|             data={ |             data={ | ||||||
|  | |||||||
| @ -38,7 +38,6 @@ class AuthenticatorTOTPStage(ConfigurableStage, Stage): | |||||||
|     def component(self) -> str: |     def component(self) -> str: | ||||||
|         return "ak-stage-authenticator-totp-form" |         return "ak-stage-authenticator-totp-form" | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def ui_user_settings(self) -> Optional[UserSettingSerializer]: |     def ui_user_settings(self) -> Optional[UserSettingSerializer]: | ||||||
|         return UserSettingSerializer( |         return UserSettingSerializer( | ||||||
|             data={ |             data={ | ||||||
|  | |||||||
| @ -18,7 +18,7 @@ class AuthenticateWebAuthnStageSerializer(StageSerializer): | |||||||
|     class Meta: |     class Meta: | ||||||
|  |  | ||||||
|         model = AuthenticateWebAuthnStage |         model = AuthenticateWebAuthnStage | ||||||
|         fields = StageSerializer.Meta.fields + ["configure_flow"] |         fields = StageSerializer.Meta.fields + ["configure_flow", "user_verification"] | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuthenticateWebAuthnStageViewSet(UsedByMixin, ModelViewSet): | class AuthenticateWebAuthnStageViewSet(UsedByMixin, ModelViewSet): | ||||||
|  | |||||||
| @ -0,0 +1,25 @@ | |||||||
|  | # Generated by Django 4.0 on 2021-12-14 09:05 | ||||||
|  |  | ||||||
|  | from django.db import migrations, models | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Migration(migrations.Migration): | ||||||
|  |  | ||||||
|  |     dependencies = [ | ||||||
|  |         ("authentik_stages_authenticator_webauthn", "0004_auto_20210304_1850"), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     operations = [ | ||||||
|  |         migrations.AddField( | ||||||
|  |             model_name="authenticatewebauthnstage", | ||||||
|  |             name="user_verification", | ||||||
|  |             field=models.TextField( | ||||||
|  |                 choices=[ | ||||||
|  |                     ("required", "Required"), | ||||||
|  |                     ("preferred", "Preferred"), | ||||||
|  |                     ("discouraged", "Discouraged"), | ||||||
|  |                 ], | ||||||
|  |                 default="preferred", | ||||||
|  |             ), | ||||||
|  |         ), | ||||||
|  |     ] | ||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	