Compare commits
	
		
			247 Commits
		
	
	
		
			version/20
			...
			version/20
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 2cfba36cb7 | |||
| 73bfe19f0e | |||
| 64e211c3a9 | |||
| 3c354db858 | |||
| 99df72d944 | |||
| 374b10b6e5 | |||
| 4c606fb0ba | |||
| f8502edd2b | |||
| aa48f6dd9d | |||
| 49b6aabb02 | |||
| f9d9b2716d | |||
| 7932b390dc | |||
| 241e36b2a6 | |||
| 81e820b6e6 | |||
| b16a3d5697 | |||
| 1583d53e54 | |||
| 909a7772dc | |||
| cd42b013ca | |||
| 22a2e65d30 | |||
| f3fe2c1b4a | |||
| 133accd033 | |||
| 47daaf969a | |||
| 9fb5092fdc | |||
| 05ccff4651 | |||
| 2d965afc5f | |||
| 522abfd2fd | |||
| 2d1bcf1aa7 | |||
| 24dbfcea3f | |||
| f008c03524 | |||
| 9ab066a863 | |||
| daa0417c38 | |||
| 067166d420 | |||
| 2bd10dbdee | |||
| 09795fa6fb | |||
| be64296494 | |||
| 02e2c117ac | |||
| 75434cc23f | |||
| 778e316690 | |||
| 3e0778fe31 | |||
| d2390eef89 | |||
| 5c542d5dc2 | |||
| 6b5b72ab4c | |||
| 93ae3c19b0 | |||
| a2ccdaca05 | |||
| 4a0e051c0b | |||
| 7ad2992fe7 | |||
| 223000804e | |||
| 7b0cb38c4b | |||
| 94d8465e92 | |||
| 4a91a7d2e2 | |||
| 369440652c | |||
| 493cdd5c0f | |||
| 9f5c019daa | |||
| ab28370f20 | |||
| 84c08dca41 | |||
| 6b8b596c92 | |||
| dc622a836f | |||
| 7b4faa0170 | |||
| 65f5f21de2 | |||
| d7e2b2e8a0 | |||
| c12c3877f6 | |||
| 9cc1b1213f | |||
| d0f5e77f77 | |||
| 68950ee8b7 | |||
| 53f224300b | |||
| 73019c0732 | |||
| 359da6db81 | |||
| 34928572db | |||
| 7f8afad528 | |||
| c1ad1e5c8b | |||
| b35b225453 | |||
| 0ff2ac7dc2 | |||
| 8b4a7666f0 | |||
| e477615b0f | |||
| e99b90912f | |||
| ae9dbf3014 | |||
| 7a50d5a4f8 | |||
| 4c4d87d3bd | |||
| a407334d3b | |||
| 5026cebf02 | |||
| 9770ba07c2 | |||
| 2e2ab55f9e | |||
| 28835fbca7 | |||
| aabb8af486 | |||
| 371f1878a8 | |||
| 662bd8af96 | |||
| 1b57cc3bf0 | |||
| 7517d612d0 | |||
| a29fabac42 | |||
| a0ff4ac038 | |||
| 4f6e3516b9 | |||
| 220f123b29 | |||
| 3e70b6443a | |||
| 792531a968 | |||
| de17fdb3ff | |||
| 19fdcba308 | |||
| 62f93c83d4 | |||
| 03a3f1bd6f | |||
| c2cb804ace | |||
| 11334cf638 | |||
| 60266b3345 | |||
| 2a4679e390 | |||
| 34e71351a6 | |||
| f13af32102 | |||
| 832b5f1058 | |||
| df789265f9 | |||
| b45fbbf20f | |||
| b4c500ce15 | |||
| 114226dc22 | |||
| 897446e5ac | |||
| eed958b132 | |||
| d0a69557d4 | |||
| 712d2b2aee | |||
| 2293cecd24 | |||
| 4e8f0bfb66 | |||
| 578139811c | |||
| f806dea0f6 | |||
| 02ec42738c | |||
| 86468163f0 | |||
| 159e533e4a | |||
| 24ee2c6c05 | |||
| dd383d763f | |||
| df9d8e9d25 | |||
| 12c318f0b1 | |||
| fc6ed8e7f9 | |||
| f68ed3562e | |||
| 621245aece | |||
| f2f22719f8 | |||
| 242423cf3c | |||
| 8e7a456f74 | |||
| 3987f8e371 | |||
| 6fb531c482 | |||
| 159798a7d8 | |||
| 89f8962a23 | |||
| 3db0a5b3d1 | |||
| feb20c371c | |||
| 73081e4947 | |||
| d7c2d0c7f9 | |||
| d5c57ab251 | |||
| 5511458757 | |||
| d9775f2822 | |||
| 398eb23d31 | |||
| 14a7c9f967 | |||
| 3e11f0c0b3 | |||
| c055245a45 | |||
| b1ba8f60a1 | |||
| f8f37dc52c | |||
| 19c36d20b5 | |||
| abca435337 | |||
| b5ee81f4de | |||
| 13bb7682b2 | |||
| 99f0f556e1 | |||
| 54ba3e9616 | |||
| 58e3ca28be | |||
| c6bb41890e | |||
| a4556b3692 | |||
| d0b52812d5 | |||
| f7e689ff03 | |||
| 2c419bee09 | |||
| 3c4c1e9d65 | |||
| ef3be13fdb | |||
| d3466ceef8 | |||
| 5886688fae | |||
| c3c8cbf7ef | |||
| 251ab71c3a | |||
| a0c546023f | |||
| 83eaac375d | |||
| 2868331976 | |||
| 5c0986c538 | |||
| ac21dcc44f | |||
| e9e7e25b27 | |||
| 7dc7d51cfa | |||
| 3eb3a9eab9 | |||
| 30db3b543b | |||
| 11aee98e0b | |||
| a099b21671 | |||
| b9294fd9ad | |||
| 13a302cdad | |||
| e994a01e80 | |||
| d49431cfc7 | |||
| ce2ce38b59 | |||
| 2af4f28239 | |||
| 1419910b29 | |||
| 649835cc61 | |||
| 917c4ae835 | |||
| ca2fce8be2 | |||
| e70fcd1a6f | |||
| fd461d9a00 | |||
| 8ea45b4dfe | |||
| 514d8f4569 | |||
| 2a43326da5 | |||
| 06f4c0608f | |||
| dc8b8e8f13 | |||
| 9f172e7ad0 | |||
| 832d76ec2a | |||
| 2545d85e8e | |||
| 9264acd00e | |||
| 51dd9473ce | |||
| 15c34c6f1f | |||
| 0ab8f4eed7 | |||
| 070714abe4 | |||
| 810c04bacf | |||
| b624bf1cb7 | |||
| f56a4b00ce | |||
| 6ec0411930 | |||
| fb59969bce | |||
| eec9c46533 | |||
| a3afd47850 | |||
| 8ffae4505f | |||
| 0cc83c23c4 | |||
| 514c48a986 | |||
| fdb8fb4b4c | |||
| d8a68407f9 | |||
| 417156d659 | |||
| 9d58407e25 | |||
| f4441c9fcf | |||
| 0e9762072a | |||
| 0cfffa28ad | |||
| 1ad4c8fc29 | |||
| dd2facdc57 | |||
| 549dfa4c3a | |||
| 7f8ae24e8d | |||
| d05aeb91f2 | |||
| ab25216c1f | |||
| fb5eb7b868 | |||
| bda218f7fc | |||
| 198c940a80 | |||
| c900411d5a | |||
| 1adc6948b4 | |||
| e87236b285 | |||
| 846b63a17b | |||
| 48fb3c98ff | |||
| b488a6fec9 | |||
| f014bd5f30 | |||
| 03dd079e8c | |||
| 1281e842d1 | |||
| f7601d9571 | |||
| 4d9c9160e7 | |||
| ad1f913e54 | |||
| 3da0233c40 | |||
| ff788edd9b | |||
| aea0958f3f | |||
| 1f9e9f9ca0 | |||
| 98ffec87c0 | |||
| b52d5dccac | |||
| e96a4fa181 | |||
| c53b0830c4 | 
| @ -1,5 +1,5 @@ | |||||||
| [bumpversion] | [bumpversion] | ||||||
| current_version = 2022.8.1 | current_version = 2022.9.0 | ||||||
| tag = True | tag = True | ||||||
| commit = True | commit = True | ||||||
| parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+) | parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+) | ||||||
| @ -17,4 +17,4 @@ tag_name = version/{new_version} | |||||||
|  |  | ||||||
| [bumpversion:file:internal/constants/constants.go] | [bumpversion:file:internal/constants/constants.go] | ||||||
|  |  | ||||||
| [bumpversion:file:web/src/constants.ts] | [bumpversion:file:web/src/common/constants.ts] | ||||||
|  | |||||||
| @ -11,38 +11,7 @@ runs: | |||||||
|   steps: |   steps: | ||||||
|     - name: Generate config |     - name: Generate config | ||||||
|       id: ev |       id: ev | ||||||
|       shell: python |       uses: ./.github/actions/docker-push-variables | ||||||
|       run: | |  | ||||||
|         """Helper script to get the actual branch name, docker safe""" |  | ||||||
|         import os |  | ||||||
|         from time import time |  | ||||||
|  |  | ||||||
|         env_pr_branch = "GITHUB_HEAD_REF" |  | ||||||
|         default_branch = "GITHUB_REF" |  | ||||||
|         sha = "GITHUB_SHA" |  | ||||||
|  |  | ||||||
|         branch_name = os.environ[default_branch] |  | ||||||
|         if os.environ.get(env_pr_branch, "") != "": |  | ||||||
|             branch_name = os.environ[env_pr_branch] |  | ||||||
|  |  | ||||||
|         should_build = str(os.environ.get("DOCKER_USERNAME", "") != "").lower() |  | ||||||
|  |  | ||||||
|         print("##[set-output name=branchName]%s" % branch_name) |  | ||||||
|         print( |  | ||||||
|             "##[set-output name=branchNameContainer]%s" |  | ||||||
|             % branch_name.replace("refs/heads/", "").replace("/", "-") |  | ||||||
|         ) |  | ||||||
|         print("##[set-output name=timestamp]%s" % int(time())) |  | ||||||
|         print("##[set-output name=sha]%s" % os.environ[sha]) |  | ||||||
|         print("##[set-output name=shouldBuild]%s" % should_build) |  | ||||||
|  |  | ||||||
|         import configparser |  | ||||||
|         parser = configparser.ConfigParser() |  | ||||||
|         parser.read(".bumpversion.cfg") |  | ||||||
|         version = parser.get("bumpversion", "current_version") |  | ||||||
|         version_family = ".".join(version.split(".")[:-1]) |  | ||||||
|         print("##[set-output name=version]%s" % version) |  | ||||||
|         print("##[set-output name=versionFamily]%s" % version_family) |  | ||||||
|     - name: Find Comment |     - name: Find Comment | ||||||
|       uses: peter-evans/find-comment@v2 |       uses: peter-evans/find-comment@v2 | ||||||
|       id: fc |       id: fc | ||||||
| @ -83,8 +52,6 @@ runs: | |||||||
|             image: |             image: | ||||||
|                 repository: ghcr.io/goauthentik/dev-server |                 repository: ghcr.io/goauthentik/dev-server | ||||||
|                 tag: ${{ inputs.tag }} |                 tag: ${{ inputs.tag }} | ||||||
|                 # pullPolicy: Always to ensure you always get the latest version |  | ||||||
|                 pullPolicy: Always |  | ||||||
|             ``` |             ``` | ||||||
|  |  | ||||||
|             Afterwards, run the upgrade commands from the latest release notes. |             Afterwards, run the upgrade commands from the latest release notes. | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								.github/actions/setup/action.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/actions/setup/action.yml
									
									
									
									
										vendored
									
									
								
							| @ -27,7 +27,7 @@ runs: | |||||||
|         docker-compose -f .github/actions/setup/docker-compose.yml up -d |         docker-compose -f .github/actions/setup/docker-compose.yml up -d | ||||||
|         poetry env use python3.10 |         poetry env use python3.10 | ||||||
|         poetry install |         poetry install | ||||||
|         npm install -g pyright@1.1.136 |         cd web && npm ci | ||||||
|     - name: Generate config |     - name: Generate config | ||||||
|       shell: poetry run python {0} |       shell: poetry run python {0} | ||||||
|       run: | |       run: | | ||||||
|  | |||||||
							
								
								
									
										1
									
								
								.github/codespell-words.txt
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/codespell-words.txt
									
									
									
									
										vendored
									
									
								
							| @ -1,3 +1,4 @@ | |||||||
| keypair | keypair | ||||||
| keypairs | keypairs | ||||||
| hass | hass | ||||||
|  | warmup | ||||||
|  | |||||||
							
								
								
									
										14
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										14
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							| @ -96,6 +96,8 @@ jobs: | |||||||
|           testspace [unittest]unittest.xml --link=codecov |           testspace [unittest]unittest.xml --link=codecov | ||||||
|       - if: ${{ always() }} |       - if: ${{ always() }} | ||||||
|         uses: codecov/codecov-action@v3 |         uses: codecov/codecov-action@v3 | ||||||
|  |         with: | ||||||
|  |           flags: unit | ||||||
|   test-integration: |   test-integration: | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
|     steps: |     steps: | ||||||
| @ -117,6 +119,8 @@ jobs: | |||||||
|           testspace [integration]unittest.xml --link=codecov |           testspace [integration]unittest.xml --link=codecov | ||||||
|       - if: ${{ always() }} |       - if: ${{ always() }} | ||||||
|         uses: codecov/codecov-action@v3 |         uses: codecov/codecov-action@v3 | ||||||
|  |         with: | ||||||
|  |           flags: integration | ||||||
|   test-e2e-provider: |   test-e2e-provider: | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
|     steps: |     steps: | ||||||
| @ -139,7 +143,7 @@ jobs: | |||||||
|         working-directory: web |         working-directory: web | ||||||
|         run: | |         run: | | ||||||
|           npm ci |           npm ci | ||||||
|           make -C .. gen-client-web |           make -C .. gen-client-ts | ||||||
|           npm run build |           npm run build | ||||||
|       - name: run e2e |       - name: run e2e | ||||||
|         run: | |         run: | | ||||||
| @ -151,6 +155,8 @@ jobs: | |||||||
|           testspace [e2e-provider]unittest.xml --link=codecov |           testspace [e2e-provider]unittest.xml --link=codecov | ||||||
|       - if: ${{ always() }} |       - if: ${{ always() }} | ||||||
|         uses: codecov/codecov-action@v3 |         uses: codecov/codecov-action@v3 | ||||||
|  |         with: | ||||||
|  |           flags: e2e | ||||||
|   test-e2e-rest: |   test-e2e-rest: | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
|     steps: |     steps: | ||||||
| @ -173,7 +179,7 @@ jobs: | |||||||
|         working-directory: web/ |         working-directory: web/ | ||||||
|         run: | |         run: | | ||||||
|           npm ci |           npm ci | ||||||
|           make -C .. gen-client-web |           make -C .. gen-client-ts | ||||||
|           npm run build |           npm run build | ||||||
|       - name: run e2e |       - name: run e2e | ||||||
|         run: | |         run: | | ||||||
| @ -185,6 +191,8 @@ jobs: | |||||||
|           testspace [e2e-rest]unittest.xml --link=codecov |           testspace [e2e-rest]unittest.xml --link=codecov | ||||||
|       - if: ${{ always() }} |       - if: ${{ always() }} | ||||||
|         uses: codecov/codecov-action@v3 |         uses: codecov/codecov-action@v3 | ||||||
|  |         with: | ||||||
|  |           flags: e2e | ||||||
|   ci-core-mark: |   ci-core-mark: | ||||||
|     needs: |     needs: | ||||||
|       - lint |       - lint | ||||||
| @ -240,4 +248,4 @@ jobs: | |||||||
|         continue-on-error: true |         continue-on-error: true | ||||||
|         uses: ./.github/actions/comment-pr-instructions |         uses: ./.github/actions/comment-pr-instructions | ||||||
|         with: |         with: | ||||||
|           tag: gh-${{ steps.ev.outputs.branchNameContainer }} |           tag: gh-${{ steps.ev.outputs.branchNameContainer }}-${{ steps.ev.outputs.timestamp }}-${{ steps.ev.outputs.sha }} | ||||||
|  | |||||||
							
								
								
									
										8
									
								
								.github/workflows/ci-web.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/ci-web.yml
									
									
									
									
										vendored
									
									
								
							| @ -23,7 +23,7 @@ jobs: | |||||||
|       - working-directory: web/ |       - working-directory: web/ | ||||||
|         run: npm ci |         run: npm ci | ||||||
|       - name: Generate API |       - name: Generate API | ||||||
|         run: make gen-client-web |         run: make gen-client-ts | ||||||
|       - name: Eslint |       - name: Eslint | ||||||
|         working-directory: web/ |         working-directory: web/ | ||||||
|         run: npm run lint |         run: npm run lint | ||||||
| @ -39,7 +39,7 @@ jobs: | |||||||
|       - working-directory: web/ |       - working-directory: web/ | ||||||
|         run: npm ci |         run: npm ci | ||||||
|       - name: Generate API |       - name: Generate API | ||||||
|         run: make gen-client-web |         run: make gen-client-ts | ||||||
|       - name: prettier |       - name: prettier | ||||||
|         working-directory: web/ |         working-directory: web/ | ||||||
|         run: npm run prettier-check |         run: npm run prettier-check | ||||||
| @ -60,7 +60,7 @@ jobs: | |||||||
|           cd node_modules/@goauthentik |           cd node_modules/@goauthentik | ||||||
|           ln -s ../../src/ web |           ln -s ../../src/ web | ||||||
|       - name: Generate API |       - name: Generate API | ||||||
|         run: make gen-client-web |         run: make gen-client-ts | ||||||
|       - name: lit-analyse |       - name: lit-analyse | ||||||
|         working-directory: web/ |         working-directory: web/ | ||||||
|         run: npm run lit-analyse |         run: npm run lit-analyse | ||||||
| @ -86,7 +86,7 @@ jobs: | |||||||
|       - working-directory: web/ |       - working-directory: web/ | ||||||
|         run: npm ci |         run: npm ci | ||||||
|       - name: Generate API |       - name: Generate API | ||||||
|         run: make gen-client-web |         run: make gen-client-ts | ||||||
|       - name: build |       - name: build | ||||||
|         working-directory: web/ |         working-directory: web/ | ||||||
|         run: npm run build |         run: npm run build | ||||||
|  | |||||||
							
								
								
									
										3
									
								
								.github/workflows/release-publish.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.github/workflows/release-publish.yml
									
									
									
									
										vendored
									
									
								
							| @ -138,7 +138,7 @@ jobs: | |||||||
|           docker-compose pull -q |           docker-compose pull -q | ||||||
|           docker-compose up --no-start |           docker-compose up --no-start | ||||||
|           docker-compose start postgresql redis |           docker-compose start postgresql redis | ||||||
|           docker-compose run -u root server test |           docker-compose run -u root server test-all | ||||||
|   sentry-release: |   sentry-release: | ||||||
|     needs: |     needs: | ||||||
|       - build-server |       - build-server | ||||||
| @ -157,6 +157,7 @@ jobs: | |||||||
|           docker cp ${container}:web/ . |           docker cp ${container}:web/ . | ||||||
|       - name: Create a Sentry.io release |       - name: Create a Sentry.io release | ||||||
|         uses: getsentry/action-release@v1 |         uses: getsentry/action-release@v1 | ||||||
|  |         continue-on-error: true | ||||||
|         if: ${{ github.event_name == 'release' }} |         if: ${{ github.event_name == 'release' }} | ||||||
|         env: |         env: | ||||||
|           SENTRY_AUTH_TOKEN: ${{ secrets.SENTRY_AUTH_TOKEN }} |           SENTRY_AUTH_TOKEN: ${{ secrets.SENTRY_AUTH_TOKEN }} | ||||||
|  | |||||||
							
								
								
									
										3
									
								
								.github/workflows/web-api-publish.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.github/workflows/web-api-publish.yml
									
									
									
									
										vendored
									
									
								
							| @ -10,13 +10,12 @@ jobs: | |||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
|     steps: |     steps: | ||||||
|       - uses: actions/checkout@v3 |       - uses: actions/checkout@v3 | ||||||
|       # Setup .npmrc file to publish to npm |  | ||||||
|       - uses: actions/setup-node@v3.4.1 |       - uses: actions/setup-node@v3.4.1 | ||||||
|         with: |         with: | ||||||
|           node-version: '16' |           node-version: '16' | ||||||
|           registry-url: 'https://registry.npmjs.org' |           registry-url: 'https://registry.npmjs.org' | ||||||
|       - name: Generate API Client |       - name: Generate API Client | ||||||
|         run: make gen-client-web |         run: make gen-client-ts | ||||||
|       - name: Publish package |       - name: Publish package | ||||||
|         working-directory: gen-ts-api/ |         working-directory: gen-ts-api/ | ||||||
|         run: | |         run: | | ||||||
|  | |||||||
| @ -19,7 +19,7 @@ WORKDIR /work/web | |||||||
| RUN npm ci && npm run build | RUN npm ci && npm run build | ||||||
|  |  | ||||||
| # Stage 3: Poetry to requirements.txt export | # Stage 3: Poetry to requirements.txt export | ||||||
| FROM docker.io/python:3.10.6-slim-bullseye AS poetry-locker | FROM docker.io/python:3.10.7-slim-bullseye AS poetry-locker | ||||||
|  |  | ||||||
| WORKDIR /work | WORKDIR /work | ||||||
| COPY ./pyproject.toml /work | COPY ./pyproject.toml /work | ||||||
| @ -30,7 +30,7 @@ RUN pip install --no-cache-dir poetry && \ | |||||||
|     poetry export -f requirements.txt --dev --output requirements-dev.txt |     poetry export -f requirements.txt --dev --output requirements-dev.txt | ||||||
|  |  | ||||||
| # Stage 4: Build go proxy | # Stage 4: Build go proxy | ||||||
| FROM docker.io/golang:1.19.0-bullseye AS go-builder | FROM docker.io/golang:1.19.1-bullseye AS go-builder | ||||||
|  |  | ||||||
| WORKDIR /work | WORKDIR /work | ||||||
|  |  | ||||||
| @ -46,7 +46,7 @@ COPY ./go.sum /work/go.sum | |||||||
| RUN go build -o /work/authentik ./cmd/server/main.go | RUN go build -o /work/authentik ./cmd/server/main.go | ||||||
|  |  | ||||||
| # Stage 5: Run | # Stage 5: Run | ||||||
| FROM docker.io/python:3.10.6-slim-bullseye AS final-image | FROM docker.io/python:3.10.7-slim-bullseye AS final-image | ||||||
|  |  | ||||||
| LABEL org.opencontainers.image.url https://goauthentik.io | LABEL org.opencontainers.image.url https://goauthentik.io | ||||||
| LABEL org.opencontainers.image.description goauthentik.io Main server image, see https://goauthentik.io for more info. | LABEL org.opencontainers.image.description goauthentik.io Main server image, see https://goauthentik.io for more info. | ||||||
|  | |||||||
							
								
								
									
										52
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										52
									
								
								Makefile
									
									
									
									
									
								
							| @ -28,7 +28,7 @@ test-docker: | |||||||
| 	rm -f .env | 	rm -f .env | ||||||
|  |  | ||||||
| test: | test: | ||||||
| 	coverage run manage.py test authentik | 	coverage run manage.py test --keepdb authentik | ||||||
| 	coverage html | 	coverage html | ||||||
| 	coverage report | 	coverage report | ||||||
|  |  | ||||||
| @ -49,28 +49,50 @@ lint: | |||||||
| 	bandit -r authentik tests lifecycle -x node_modules | 	bandit -r authentik tests lifecycle -x node_modules | ||||||
| 	golangci-lint run -v | 	golangci-lint run -v | ||||||
|  |  | ||||||
|  | migrate: | ||||||
|  | 	python -m lifecycle.migrate | ||||||
|  |  | ||||||
|  | run: | ||||||
|  | 	go run -v cmd/server/main.go | ||||||
|  |  | ||||||
| i18n-extract: i18n-extract-core web-extract | i18n-extract: i18n-extract-core web-extract | ||||||
|  |  | ||||||
| i18n-extract-core: | i18n-extract-core: | ||||||
| 	ak makemessages --ignore web --ignore internal --ignore web --ignore web-api --ignore website -l en | 	ak makemessages --ignore web --ignore internal --ignore web --ignore web-api --ignore website -l en | ||||||
|  |  | ||||||
|  | ######################### | ||||||
|  | ## API Schema | ||||||
|  | ######################### | ||||||
|  |  | ||||||
| gen-build: | gen-build: | ||||||
| 	AUTHENTIK_DEBUG=true ak make_blueprint_schema > blueprints/schema.json | 	AUTHENTIK_DEBUG=true ak make_blueprint_schema > blueprints/schema.json | ||||||
| 	AUTHENTIK_DEBUG=true ak spectacular --file schema.yml | 	AUTHENTIK_DEBUG=true ak spectacular --file schema.yml | ||||||
|  |  | ||||||
|  | gen-diff: | ||||||
|  | 	git show $(shell git tag -l | tail -n 1):schema.yml > old_schema.yml | ||||||
|  | 	docker run \ | ||||||
|  | 		--rm -v ${PWD}:/local \ | ||||||
|  | 		--user ${UID}:${GID} \ | ||||||
|  | 		docker.io/openapitools/openapi-diff:2.0.1 \ | ||||||
|  | 		--markdown /local/diff.md \ | ||||||
|  | 		/local/old_schema.yml /local/schema.yml | ||||||
|  | 	rm old_schema.yml | ||||||
|  |  | ||||||
| gen-clean: | gen-clean: | ||||||
| 	rm -rf web/api/src/ | 	rm -rf web/api/src/ | ||||||
| 	rm -rf api/ | 	rm -rf api/ | ||||||
|  |  | ||||||
| gen-client-web: | gen-client-ts: | ||||||
| 	docker run \ | 	docker run \ | ||||||
| 		--rm -v ${PWD}:/local \ | 		--rm -v ${PWD}:/local \ | ||||||
| 		--user ${UID}:${GID} \ | 		--user ${UID}:${GID} \ | ||||||
| 		openapitools/openapi-generator-cli:v6.0.0 generate \ | 		docker.io/openapitools/openapi-generator-cli:v6.0.0 generate \ | ||||||
| 		-i /local/schema.yml \ | 		-i /local/schema.yml \ | ||||||
| 		-g typescript-fetch \ | 		-g typescript-fetch \ | ||||||
| 		-o /local/gen-ts-api \ | 		-o /local/gen-ts-api \ | ||||||
| 		--additional-properties=typescriptThreePlus=true,supportsES6=true,npmName=@goauthentik/api,npmVersion=${NPM_VERSION} | 		--additional-properties=typescriptThreePlus=true,supportsES6=true,npmName=@goauthentik/api,npmVersion=${NPM_VERSION} \ | ||||||
|  | 		--git-repo-id authentik \ | ||||||
|  | 		--git-user-id goauthentik | ||||||
| 	mkdir -p web/node_modules/@goauthentik/api | 	mkdir -p web/node_modules/@goauthentik/api | ||||||
| 	\cp -fv scripts/web_api_readme.md gen-ts-api/README.md | 	\cp -fv scripts/web_api_readme.md gen-ts-api/README.md | ||||||
| 	cd gen-ts-api && npm i | 	cd gen-ts-api && npm i | ||||||
| @ -84,7 +106,7 @@ gen-client-go: | |||||||
| 	docker run \ | 	docker run \ | ||||||
| 		--rm -v ${PWD}:/local \ | 		--rm -v ${PWD}:/local \ | ||||||
| 		--user ${UID}:${GID} \ | 		--user ${UID}:${GID} \ | ||||||
| 		openapitools/openapi-generator-cli:v6.0.0 generate \ | 		docker.io/openapitools/openapi-generator-cli:v6.0.0 generate \ | ||||||
| 		-i /local/schema.yml \ | 		-i /local/schema.yml \ | ||||||
| 		-g go \ | 		-g go \ | ||||||
| 		-o /local/gen-go-api \ | 		-o /local/gen-go-api \ | ||||||
| @ -95,13 +117,7 @@ gen-client-go: | |||||||
| gen-dev-config: | gen-dev-config: | ||||||
| 	python -m scripts.generate_config | 	python -m scripts.generate_config | ||||||
|  |  | ||||||
| gen: gen-build gen-clean gen-client-web | gen: gen-build gen-clean gen-client-ts | ||||||
|  |  | ||||||
| migrate: |  | ||||||
| 	python -m lifecycle.migrate |  | ||||||
|  |  | ||||||
| run: |  | ||||||
| 	go run -v cmd/server/main.go |  | ||||||
|  |  | ||||||
| ######################### | ######################### | ||||||
| ## Web | ## Web | ||||||
| @ -148,25 +164,25 @@ website-watch: | |||||||
|  |  | ||||||
| # These targets are use by GitHub actions to allow usage of matrix | # These targets are use by GitHub actions to allow usage of matrix | ||||||
| # which makes the YAML File a lot smaller | # which makes the YAML File a lot smaller | ||||||
|  | PY_SOURCES=authentik tests lifecycle | ||||||
| ci--meta-debug: | ci--meta-debug: | ||||||
| 	python -V | 	python -V | ||||||
| 	node --version | 	node --version | ||||||
|  |  | ||||||
| ci-pylint: ci--meta-debug | ci-pylint: ci--meta-debug | ||||||
| 	pylint authentik tests lifecycle | 	pylint $(PY_SOURCES) | ||||||
|  |  | ||||||
| ci-black: ci--meta-debug | ci-black: ci--meta-debug | ||||||
| 	black --check authentik tests lifecycle | 	black --check $(PY_SOURCES) | ||||||
|  |  | ||||||
| ci-isort: ci--meta-debug | ci-isort: ci--meta-debug | ||||||
| 	isort --check authentik tests lifecycle | 	isort --check $(PY_SOURCES) | ||||||
|  |  | ||||||
| ci-bandit: ci--meta-debug | ci-bandit: ci--meta-debug | ||||||
| 	bandit -r authentik tests lifecycle | 	bandit -r $(PY_SOURCES) | ||||||
|  |  | ||||||
| ci-pyright: ci--meta-debug | ci-pyright: ci--meta-debug | ||||||
| 	pyright e2e lifecycle | 	./web/node_modules/.bin/pyright $(PY_SOURCES) | ||||||
|  |  | ||||||
| ci-pending-migrations: ci--meta-debug | ci-pending-migrations: ci--meta-debug | ||||||
| 	ak makemigrations --check | 	ak makemigrations --check | ||||||
|  | |||||||
| @ -2,7 +2,7 @@ | |||||||
| from os import environ | from os import environ | ||||||
| from typing import Optional | from typing import Optional | ||||||
|  |  | ||||||
| __version__ = "2022.8.1" | __version__ = "2022.9.0" | ||||||
| ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -18,13 +18,13 @@ class AppSerializer(PassiveSerializer): | |||||||
|  |  | ||||||
|  |  | ||||||
| class AppsViewSet(ViewSet): | class AppsViewSet(ViewSet): | ||||||
|     """Read-only view set list all installed apps""" |     """Read-only view list all installed apps""" | ||||||
|  |  | ||||||
|     permission_classes = [IsAdminUser] |     permission_classes = [IsAdminUser] | ||||||
|  |  | ||||||
|     @extend_schema(responses={200: AppSerializer(many=True)}) |     @extend_schema(responses={200: AppSerializer(many=True)}) | ||||||
|     def list(self, request: Request) -> Response: |     def list(self, request: Request) -> Response: | ||||||
|         """List current messages and pass into Serializer""" |         """Read-only view list all installed apps""" | ||||||
|         data = [] |         data = [] | ||||||
|         for app in sorted(get_apps(), key=lambda app: app.name): |         for app in sorted(get_apps(), key=lambda app: app.name): | ||||||
|             data.append({"name": app.name, "label": app.verbose_name}) |             data.append({"name": app.name, "label": app.verbose_name}) | ||||||
|  | |||||||
| @ -5,7 +5,7 @@ from django.contrib import messages | |||||||
| from django.http.response import Http404 | from django.http.response import Http404 | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
| from drf_spectacular.types import OpenApiTypes | from drf_spectacular.types import OpenApiTypes | ||||||
| from drf_spectacular.utils import OpenApiResponse, extend_schema | from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema | ||||||
| from rest_framework.decorators import action | from rest_framework.decorators import action | ||||||
| from rest_framework.fields import CharField, ChoiceField, DateTimeField, ListField | from rest_framework.fields import CharField, ChoiceField, DateTimeField, ListField | ||||||
| from rest_framework.permissions import IsAdminUser | from rest_framework.permissions import IsAdminUser | ||||||
| @ -58,7 +58,15 @@ class TaskViewSet(ViewSet): | |||||||
|         responses={ |         responses={ | ||||||
|             200: TaskSerializer(many=False), |             200: TaskSerializer(many=False), | ||||||
|             404: OpenApiResponse(description="Task not found"), |             404: OpenApiResponse(description="Task not found"), | ||||||
|         } |         }, | ||||||
|  |         parameters=[ | ||||||
|  |             OpenApiParameter( | ||||||
|  |                 "id", | ||||||
|  |                 type=OpenApiTypes.STR, | ||||||
|  |                 location=OpenApiParameter.PATH, | ||||||
|  |                 required=True, | ||||||
|  |             ), | ||||||
|  |         ], | ||||||
|     ) |     ) | ||||||
|     # pylint: disable=invalid-name |     # pylint: disable=invalid-name | ||||||
|     def retrieve(self, request: Request, pk=None) -> Response: |     def retrieve(self, request: Request, pk=None) -> Response: | ||||||
| @ -81,6 +89,14 @@ class TaskViewSet(ViewSet): | |||||||
|             404: OpenApiResponse(description="Task not found"), |             404: OpenApiResponse(description="Task not found"), | ||||||
|             500: OpenApiResponse(description="Failed to retry task"), |             500: OpenApiResponse(description="Failed to retry task"), | ||||||
|         }, |         }, | ||||||
|  |         parameters=[ | ||||||
|  |             OpenApiParameter( | ||||||
|  |                 "id", | ||||||
|  |                 type=OpenApiTypes.STR, | ||||||
|  |                 location=OpenApiParameter.PATH, | ||||||
|  |                 required=True, | ||||||
|  |             ), | ||||||
|  |         ], | ||||||
|     ) |     ) | ||||||
|     @action(detail=True, methods=["post"]) |     @action(detail=True, methods=["post"]) | ||||||
|     # pylint: disable=invalid-name |     # pylint: disable=invalid-name | ||||||
|  | |||||||
| @ -1,7 +1,7 @@ | |||||||
| """authentik admin app config""" | """authentik admin app config""" | ||||||
| from prometheus_client import Gauge, Info | from prometheus_client import Gauge, Info | ||||||
|  |  | ||||||
| from authentik.blueprints.manager import ManagedAppConfig | from authentik.blueprints.apps import ManagedAppConfig | ||||||
|  |  | ||||||
| PROM_INFO = Info("authentik_version", "Currently running authentik version") | PROM_INFO = Info("authentik_version", "Currently running authentik version") | ||||||
| GAUGE_WORKERS = Gauge("authentik_admin_workers", "Currently connected workers") | GAUGE_WORKERS = Gauge("authentik_admin_workers", "Currently connected workers") | ||||||
|  | |||||||
| @ -16,7 +16,7 @@ from authentik.providers.oauth2.models import RefreshToken | |||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
| def validate_auth(header: bytes) -> str: | def validate_auth(header: bytes) -> Optional[str]: | ||||||
|     """Validate that the header is in a correct format, |     """Validate that the header is in a correct format, | ||||||
|     returns type and credentials""" |     returns type and credentials""" | ||||||
|     auth_credentials = header.decode().strip() |     auth_credentials = header.decode().strip() | ||||||
|  | |||||||
| @ -60,8 +60,28 @@ def postprocess_schema_responses(result, generator, **kwargs):  # noqa: W0613 | |||||||
|  |  | ||||||
|     for path in result["paths"].values(): |     for path in result["paths"].values(): | ||||||
|         for method in path.values(): |         for method in path.values(): | ||||||
|             method["responses"].setdefault("400", validation_error.ref) |             method["responses"].setdefault( | ||||||
|             method["responses"].setdefault("403", generic_error.ref) |                 "400", | ||||||
|  |                 { | ||||||
|  |                     "content": { | ||||||
|  |                         "application/json": { | ||||||
|  |                             "schema": validation_error.ref, | ||||||
|  |                         } | ||||||
|  |                     }, | ||||||
|  |                     "description": "", | ||||||
|  |                 }, | ||||||
|  |             ) | ||||||
|  |             method["responses"].setdefault( | ||||||
|  |                 "403", | ||||||
|  |                 { | ||||||
|  |                     "content": { | ||||||
|  |                         "application/json": { | ||||||
|  |                             "schema": generic_error.ref, | ||||||
|  |                         } | ||||||
|  |                     }, | ||||||
|  |                     "description": "", | ||||||
|  |                 }, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|     result["components"] = generator.registry.build(spectacular_settings.APPEND_COMPONENTS) |     result["components"] = generator.registry.build(spectacular_settings.APPEND_COMPONENTS) | ||||||
|  |  | ||||||
|  | |||||||
| @ -1,8 +1,7 @@ | |||||||
| """Serializer mixin for managed models""" | """Serializer mixin for managed models""" | ||||||
| from dataclasses import asdict |  | ||||||
|  |  | ||||||
| from drf_spectacular.utils import extend_schema, inline_serializer | from drf_spectacular.utils import extend_schema, inline_serializer | ||||||
| from rest_framework.decorators import action | from rest_framework.decorators import action | ||||||
|  | from rest_framework.exceptions import ValidationError | ||||||
| from rest_framework.fields import CharField, DateTimeField, JSONField | from rest_framework.fields import CharField, DateTimeField, JSONField | ||||||
| from rest_framework.permissions import IsAdminUser | from rest_framework.permissions import IsAdminUser | ||||||
| from rest_framework.request import Request | from rest_framework.request import Request | ||||||
| @ -11,11 +10,10 @@ from rest_framework.serializers import ListSerializer, ModelSerializer | |||||||
| from rest_framework.viewsets import ModelViewSet | from rest_framework.viewsets import ModelViewSet | ||||||
|  |  | ||||||
| from authentik.api.decorators import permission_required | from authentik.api.decorators import permission_required | ||||||
| from authentik.blueprints.models import BlueprintInstance | from authentik.blueprints.models import BlueprintInstance, BlueprintRetrievalFailed | ||||||
| from authentik.blueprints.v1.tasks import BlueprintFile, apply_blueprint, blueprints_find | from authentik.blueprints.v1.tasks import apply_blueprint, blueprints_find_dict | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.core.api.utils import PassiveSerializer | from authentik.core.api.utils import PassiveSerializer | ||||||
| from authentik.events.utils import sanitize_dict |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ManagedSerializer: | class ManagedSerializer: | ||||||
| @ -34,6 +32,14 @@ class MetadataSerializer(PassiveSerializer): | |||||||
| class BlueprintInstanceSerializer(ModelSerializer): | class BlueprintInstanceSerializer(ModelSerializer): | ||||||
|     """Info about a single blueprint instance file""" |     """Info about a single blueprint instance file""" | ||||||
|  |  | ||||||
|  |     def validate_path(self, path: str) -> str: | ||||||
|  |         """Ensure the path specified is retrievable""" | ||||||
|  |         try: | ||||||
|  |             BlueprintInstance(path=path).retrieve() | ||||||
|  |         except BlueprintRetrievalFailed as exc: | ||||||
|  |             raise ValidationError(exc) from exc | ||||||
|  |         return path | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|  |  | ||||||
|         model = BlueprintInstance |         model = BlueprintInstance | ||||||
| @ -85,8 +91,8 @@ class BlueprintInstanceViewSet(UsedByMixin, ModelViewSet): | |||||||
|     @action(detail=False, pagination_class=None, filter_backends=[]) |     @action(detail=False, pagination_class=None, filter_backends=[]) | ||||||
|     def available(self, request: Request) -> Response: |     def available(self, request: Request) -> Response: | ||||||
|         """Get blueprints""" |         """Get blueprints""" | ||||||
|         files: list[BlueprintFile] = blueprints_find.delay().get() |         files: list[dict] = blueprints_find_dict.delay().get() | ||||||
|         return Response([sanitize_dict(asdict(file)) for file in files]) |         return Response(files) | ||||||
|  |  | ||||||
|     @permission_required("authentik_blueprints.view_blueprintinstance") |     @permission_required("authentik_blueprints.view_blueprintinstance") | ||||||
|     @extend_schema( |     @extend_schema( | ||||||
|  | |||||||
| @ -1,6 +1,46 @@ | |||||||
| """authentik Blueprints app""" | """authentik Blueprints app""" | ||||||
|  |  | ||||||
| from authentik.blueprints.manager import ManagedAppConfig | from importlib import import_module | ||||||
|  | from inspect import ismethod | ||||||
|  |  | ||||||
|  | from django.apps import AppConfig | ||||||
|  | from django.db import DatabaseError, InternalError, ProgrammingError | ||||||
|  | from structlog.stdlib import BoundLogger, get_logger | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ManagedAppConfig(AppConfig): | ||||||
|  |     """Basic reconciliation logic for apps""" | ||||||
|  |  | ||||||
|  |     _logger: BoundLogger | ||||||
|  |  | ||||||
|  |     def __init__(self, app_name: str, *args, **kwargs) -> None: | ||||||
|  |         super().__init__(app_name, *args, **kwargs) | ||||||
|  |         self._logger = get_logger().bind(app_name=app_name) | ||||||
|  |  | ||||||
|  |     def ready(self) -> None: | ||||||
|  |         self.reconcile() | ||||||
|  |         return super().ready() | ||||||
|  |  | ||||||
|  |     def import_module(self, path: str): | ||||||
|  |         """Load module""" | ||||||
|  |         import_module(path) | ||||||
|  |  | ||||||
|  |     def reconcile(self) -> None: | ||||||
|  |         """reconcile ourselves""" | ||||||
|  |         prefix = "reconcile_" | ||||||
|  |         for meth_name in dir(self): | ||||||
|  |             meth = getattr(self, meth_name) | ||||||
|  |             if not ismethod(meth): | ||||||
|  |                 continue | ||||||
|  |             if not meth_name.startswith(prefix): | ||||||
|  |                 continue | ||||||
|  |             name = meth_name.replace(prefix, "") | ||||||
|  |             try: | ||||||
|  |                 self._logger.debug("Starting reconciler", name=name) | ||||||
|  |                 meth() | ||||||
|  |                 self._logger.debug("Successfully reconciled", name=name) | ||||||
|  |             except (DatabaseError, ProgrammingError, InternalError) as exc: | ||||||
|  |                 self._logger.debug("Failed to run reconcile", name=name, exc=exc) | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuthentikBlueprintsConfig(ManagedAppConfig): | class AuthentikBlueprintsConfig(ManagedAppConfig): | ||||||
| @ -20,3 +60,7 @@ class AuthentikBlueprintsConfig(ManagedAppConfig): | |||||||
|         from authentik.blueprints.v1.tasks import blueprints_discover |         from authentik.blueprints.v1.tasks import blueprints_discover | ||||||
|  |  | ||||||
|         blueprints_discover.delay() |         blueprints_discover.delay() | ||||||
|  |  | ||||||
|  |     def import_models(self): | ||||||
|  |         super().import_models() | ||||||
|  |         self.import_module("authentik.blueprints.v1.meta.apply_blueprint") | ||||||
|  | |||||||
| @ -1,7 +1,10 @@ | |||||||
| """Apply blueprint from commandline""" | """Apply blueprint from commandline""" | ||||||
|  | from sys import exit as sys_exit | ||||||
|  |  | ||||||
| from django.core.management.base import BaseCommand, no_translations | from django.core.management.base import BaseCommand, no_translations | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
|  | from authentik.blueprints.models import BlueprintInstance | ||||||
| from authentik.blueprints.v1.importer import Importer | from authentik.blueprints.v1.importer import Importer | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| @ -14,13 +17,14 @@ class Command(BaseCommand): | |||||||
|     def handle(self, *args, **options): |     def handle(self, *args, **options): | ||||||
|         """Apply all blueprints in order, abort when one fails to import""" |         """Apply all blueprints in order, abort when one fails to import""" | ||||||
|         for blueprint_path in options.get("blueprints", []): |         for blueprint_path in options.get("blueprints", []): | ||||||
|             with open(blueprint_path, "r", encoding="utf8") as blueprint_file: |             content = BlueprintInstance(path=blueprint_path).retrieve() | ||||||
|                 importer = Importer(blueprint_file.read()) |             importer = Importer(content) | ||||||
|             valid, logs = importer.validate() |             valid, logs = importer.validate() | ||||||
|             if not valid: |             if not valid: | ||||||
|                 for log in logs: |                 for log in logs: | ||||||
|                         LOGGER.debug(**log) |                     getattr(LOGGER, log.pop("log_level"))(**log) | ||||||
|                     raise ValueError("blueprint invalid") |                 self.stderr.write("blueprint invalid") | ||||||
|  |                 sys_exit(1) | ||||||
|             importer.apply() |             importer.apply() | ||||||
|  |  | ||||||
|     def add_arguments(self, parser): |     def add_arguments(self, parser): | ||||||
|  | |||||||
							
								
								
									
										17
									
								
								authentik/blueprints/management/commands/export_blueprint.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								authentik/blueprints/management/commands/export_blueprint.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,17 @@ | |||||||
|  | """Export blueprint of current authentik install""" | ||||||
|  | from django.core.management.base import BaseCommand, no_translations | ||||||
|  | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
|  | from authentik.blueprints.v1.exporter import Exporter | ||||||
|  |  | ||||||
|  | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Command(BaseCommand): | ||||||
|  |     """Export blueprint of current authentik install""" | ||||||
|  |  | ||||||
|  |     @no_translations | ||||||
|  |     def handle(self, *args, **options): | ||||||
|  |         """Export blueprint of current authentik install""" | ||||||
|  |         exporter = Exporter() | ||||||
|  |         self.stdout.write(exporter.export_to_string()) | ||||||
| @ -2,11 +2,11 @@ | |||||||
| from json import dumps, loads | from json import dumps, loads | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
|  |  | ||||||
| from django.apps import apps |  | ||||||
| from django.core.management.base import BaseCommand, no_translations | from django.core.management.base import BaseCommand, no_translations | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.blueprints.v1.importer import is_model_allowed | from authentik.blueprints.v1.importer import is_model_allowed | ||||||
|  | from authentik.blueprints.v1.meta.registry import registry | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
| @ -28,8 +28,9 @@ class Command(BaseCommand): | |||||||
|     def set_model_allowed(self): |     def set_model_allowed(self): | ||||||
|         """Set model enum""" |         """Set model enum""" | ||||||
|         model_names = [] |         model_names = [] | ||||||
|         for model in apps.get_models(): |         for model in registry.get_models(): | ||||||
|             if not is_model_allowed(model): |             if not is_model_allowed(model): | ||||||
|                 continue |                 continue | ||||||
|             model_names.append(f"{model._meta.app_label}.{model._meta.model_name}") |             model_names.append(f"{model._meta.app_label}.{model._meta.model_name}") | ||||||
|  |         model_names.sort() | ||||||
|         self.schema["properties"]["entries"]["items"]["properties"]["model"]["enum"] = model_names |         self.schema["properties"]["entries"]["items"]["properties"]["model"]["enum"] = model_names | ||||||
|  | |||||||
| @ -41,8 +41,7 @@ | |||||||
|                 "$id": "#entry", |                 "$id": "#entry", | ||||||
|                 "type": "object", |                 "type": "object", | ||||||
|                 "required": [ |                 "required": [ | ||||||
|                     "model", |                     "model" | ||||||
|                     "identifiers" |  | ||||||
|                 ], |                 ], | ||||||
|                 "properties": { |                 "properties": { | ||||||
|                     "model": { |                     "model": { | ||||||
| @ -67,6 +66,7 @@ | |||||||
|                     }, |                     }, | ||||||
|                     "identifiers": { |                     "identifiers": { | ||||||
|                         "type": "object", |                         "type": "object", | ||||||
|  |                         "default": {}, | ||||||
|                         "properties": { |                         "properties": { | ||||||
|                             "pk": { |                             "pk": { | ||||||
|                                 "description": "Commonly available field, may not exist on all models", |                                 "description": "Commonly available field, may not exist on all models", | ||||||
|  | |||||||
| @ -1,44 +0,0 @@ | |||||||
| """Managed objects manager""" |  | ||||||
| from importlib import import_module |  | ||||||
| from inspect import ismethod |  | ||||||
|  |  | ||||||
| from django.apps import AppConfig |  | ||||||
| from django.db import DatabaseError, InternalError, ProgrammingError |  | ||||||
| from structlog.stdlib import BoundLogger, get_logger |  | ||||||
|  |  | ||||||
| LOGGER = get_logger() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ManagedAppConfig(AppConfig): |  | ||||||
|     """Basic reconciliation logic for apps""" |  | ||||||
|  |  | ||||||
|     _logger: BoundLogger |  | ||||||
|  |  | ||||||
|     def __init__(self, app_name: str, *args, **kwargs) -> None: |  | ||||||
|         super().__init__(app_name, *args, **kwargs) |  | ||||||
|         self._logger = get_logger().bind(app_name=app_name) |  | ||||||
|  |  | ||||||
|     def ready(self) -> None: |  | ||||||
|         self.reconcile() |  | ||||||
|         return super().ready() |  | ||||||
|  |  | ||||||
|     def import_module(self, path: str): |  | ||||||
|         """Load module""" |  | ||||||
|         import_module(path) |  | ||||||
|  |  | ||||||
|     def reconcile(self) -> None: |  | ||||||
|         """reconcile ourselves""" |  | ||||||
|         prefix = "reconcile_" |  | ||||||
|         for meth_name in dir(self): |  | ||||||
|             meth = getattr(self, meth_name) |  | ||||||
|             if not ismethod(meth): |  | ||||||
|                 continue |  | ||||||
|             if not meth_name.startswith(prefix): |  | ||||||
|                 continue |  | ||||||
|             name = meth_name.replace(prefix, "") |  | ||||||
|             try: |  | ||||||
|                 self._logger.debug("Starting reconciler", name=name) |  | ||||||
|                 meth() |  | ||||||
|                 self._logger.debug("Successfully reconciled", name=name) |  | ||||||
|             except (DatabaseError, ProgrammingError, InternalError) as exc: |  | ||||||
|                 self._logger.debug("Failed to run reconcile", name=name, exc=exc) |  | ||||||
| @ -4,7 +4,7 @@ from glob import glob | |||||||
| from pathlib import Path | from pathlib import Path | ||||||
|  |  | ||||||
| import django.contrib.postgres.fields | import django.contrib.postgres.fields | ||||||
| from dacite import from_dict | from dacite.core import from_dict | ||||||
| from django.apps.registry import Apps | from django.apps.registry import Apps | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
| from django.db import migrations, models | from django.db import migrations, models | ||||||
| @ -113,7 +113,8 @@ class Migration(migrations.Migration): | |||||||
|                             ("error", "Error"), |                             ("error", "Error"), | ||||||
|                             ("orphaned", "Orphaned"), |                             ("orphaned", "Orphaned"), | ||||||
|                             ("unknown", "Unknown"), |                             ("unknown", "Unknown"), | ||||||
|                         ] |                         ], | ||||||
|  |                         default="unknown", | ||||||
|                     ), |                     ), | ||||||
|                 ), |                 ), | ||||||
|                 ("enabled", models.BooleanField(default=True)), |                 ("enabled", models.BooleanField(default=True)), | ||||||
|  | |||||||
| @ -1,12 +1,36 @@ | |||||||
| """Managed Object models""" | """blueprint models""" | ||||||
|  | from pathlib import Path | ||||||
|  | from urllib.parse import urlparse | ||||||
| from uuid import uuid4 | from uuid import uuid4 | ||||||
|  |  | ||||||
| from django.contrib.postgres.fields import ArrayField | from django.contrib.postgres.fields import ArrayField | ||||||
| from django.db import models | from django.db import models | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
|  | from opencontainers.distribution.reggie import ( | ||||||
|  |     NewClient, | ||||||
|  |     WithDebug, | ||||||
|  |     WithDefaultName, | ||||||
|  |     WithDigest, | ||||||
|  |     WithReference, | ||||||
|  |     WithUserAgent, | ||||||
|  |     WithUsernamePassword, | ||||||
|  | ) | ||||||
|  | from requests.exceptions import RequestException | ||||||
| from rest_framework.serializers import Serializer | from rest_framework.serializers import Serializer | ||||||
|  | from structlog import get_logger | ||||||
|  |  | ||||||
|  | from authentik.lib.config import CONFIG | ||||||
| from authentik.lib.models import CreatedUpdatedModel, SerializerModel | from authentik.lib.models import CreatedUpdatedModel, SerializerModel | ||||||
|  | from authentik.lib.sentry import SentryIgnoredException | ||||||
|  | from authentik.lib.utils.http import authentik_user_agent | ||||||
|  |  | ||||||
|  | OCI_MEDIA_TYPE = "application/vnd.goauthentik.blueprint.v1+yaml" | ||||||
|  | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class BlueprintRetrievalFailed(SentryIgnoredException): | ||||||
|  |     """Error raised when we're unable to fetch the blueprint contents, whether it be HTTP files | ||||||
|  |     not being accessible or local files not being readable""" | ||||||
|  |  | ||||||
|  |  | ||||||
| class ManagedModel(models.Model): | class ManagedModel(models.Model): | ||||||
| @ -54,10 +78,70 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel): | |||||||
|     context = models.JSONField(default=dict) |     context = models.JSONField(default=dict) | ||||||
|     last_applied = models.DateTimeField(auto_now=True) |     last_applied = models.DateTimeField(auto_now=True) | ||||||
|     last_applied_hash = models.TextField() |     last_applied_hash = models.TextField() | ||||||
|     status = models.TextField(choices=BlueprintInstanceStatus.choices) |     status = models.TextField( | ||||||
|  |         choices=BlueprintInstanceStatus.choices, default=BlueprintInstanceStatus.UNKNOWN | ||||||
|  |     ) | ||||||
|     enabled = models.BooleanField(default=True) |     enabled = models.BooleanField(default=True) | ||||||
|     managed_models = ArrayField(models.TextField(), default=list) |     managed_models = ArrayField(models.TextField(), default=list) | ||||||
|  |  | ||||||
|  |     def retrieve_oci(self) -> str: | ||||||
|  |         """Get blueprint from an OCI registry""" | ||||||
|  |         url = urlparse(self.path) | ||||||
|  |         ref = "latest" | ||||||
|  |         path = url.path[1:] | ||||||
|  |         if ":" in url.path: | ||||||
|  |             path, _, ref = path.partition(":") | ||||||
|  |         client = NewClient( | ||||||
|  |             f"{url.scheme}://{url.hostname}", | ||||||
|  |             WithUserAgent(authentik_user_agent()), | ||||||
|  |             WithUsernamePassword(url.username, url.password), | ||||||
|  |             WithDefaultName(path), | ||||||
|  |             WithDebug(True), | ||||||
|  |         ) | ||||||
|  |         LOGGER.info("Fetching OCI manifests for blueprint", instance=self) | ||||||
|  |         manifest_request = client.NewRequest( | ||||||
|  |             "GET", | ||||||
|  |             "/v2/<name>/manifests/<reference>", | ||||||
|  |             WithReference(ref), | ||||||
|  |         ).SetHeader("Accept", "application/vnd.oci.image.manifest.v1+json") | ||||||
|  |         try: | ||||||
|  |             manifest_response = client.Do(manifest_request) | ||||||
|  |             manifest_response.raise_for_status() | ||||||
|  |         except RequestException as exc: | ||||||
|  |             raise BlueprintRetrievalFailed(exc) from exc | ||||||
|  |         manifest = manifest_response.json() | ||||||
|  |         if "errors" in manifest: | ||||||
|  |             raise BlueprintRetrievalFailed(manifest["errors"]) | ||||||
|  |  | ||||||
|  |         blob = None | ||||||
|  |         for layer in manifest.get("layers", []): | ||||||
|  |             if layer.get("mediaType", "") == OCI_MEDIA_TYPE: | ||||||
|  |                 blob = layer.get("digest") | ||||||
|  |                 LOGGER.debug("Found layer with matching media type", instance=self, blob=blob) | ||||||
|  |         if not blob: | ||||||
|  |             raise BlueprintRetrievalFailed("Blob not found") | ||||||
|  |  | ||||||
|  |         blob_request = client.NewRequest( | ||||||
|  |             "GET", | ||||||
|  |             "/v2/<name>/blobs/<digest>", | ||||||
|  |             WithDigest(blob), | ||||||
|  |         ) | ||||||
|  |         try: | ||||||
|  |             blob_response = client.Do(blob_request) | ||||||
|  |             blob_response.raise_for_status() | ||||||
|  |             return blob_response.text | ||||||
|  |         except RequestException as exc: | ||||||
|  |             raise BlueprintRetrievalFailed(exc) from exc | ||||||
|  |  | ||||||
|  |     def retrieve(self) -> str: | ||||||
|  |         """Retrieve blueprint contents""" | ||||||
|  |         full_path = Path(CONFIG.y("blueprints_dir")).joinpath(Path(self.path)) | ||||||
|  |         if full_path.exists(): | ||||||
|  |             LOGGER.debug("Blueprint path exists locally", instance=self) | ||||||
|  |             with full_path.open("r", encoding="utf-8") as _file: | ||||||
|  |                 return _file.read() | ||||||
|  |         return self.retrieve_oci() | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def serializer(self) -> Serializer: |     def serializer(self) -> Serializer: | ||||||
|         from authentik.blueprints.api import BlueprintInstanceSerializer |         from authentik.blueprints.api import BlueprintInstanceSerializer | ||||||
|  | |||||||
| @ -5,7 +5,8 @@ from typing import Callable | |||||||
|  |  | ||||||
| from django.apps import apps | from django.apps import apps | ||||||
|  |  | ||||||
| from authentik.blueprints.manager import ManagedAppConfig | from authentik.blueprints.apps import ManagedAppConfig | ||||||
|  | from authentik.blueprints.models import BlueprintInstance | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -19,11 +20,9 @@ def apply_blueprint(*files: str): | |||||||
|  |  | ||||||
|         @wraps(func) |         @wraps(func) | ||||||
|         def wrapper(*args, **kwargs): |         def wrapper(*args, **kwargs): | ||||||
|             base_path = Path(CONFIG.y("blueprints_dir")) |  | ||||||
|             for file in files: |             for file in files: | ||||||
|                 full_path = Path(base_path, file) |                 content = BlueprintInstance(path=file).retrieve() | ||||||
|                 with full_path.open("r", encoding="utf-8") as _file: |                 Importer(content).apply() | ||||||
|                     Importer(_file.read()).apply() |  | ||||||
|             return func(*args, **kwargs) |             return func(*args, **kwargs) | ||||||
|  |  | ||||||
|         return wrapper |         return wrapper | ||||||
|  | |||||||
							
								
								
									
										97
									
								
								authentik/blueprints/tests/test_oci.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										97
									
								
								authentik/blueprints/tests/test_oci.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,97 @@ | |||||||
|  | """Test blueprints OCI""" | ||||||
|  | from django.test import TransactionTestCase | ||||||
|  | from requests_mock import Mocker | ||||||
|  |  | ||||||
|  | from authentik.blueprints.models import OCI_MEDIA_TYPE, BlueprintInstance, BlueprintRetrievalFailed | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TestBlueprintOCI(TransactionTestCase): | ||||||
|  |     """Test Blueprints OCI Tasks""" | ||||||
|  |  | ||||||
|  |     def test_successful(self): | ||||||
|  |         """Successful retrieval""" | ||||||
|  |         with Mocker() as mocker: | ||||||
|  |             mocker.get( | ||||||
|  |                 "https://ghcr.io/v2/goauthentik/blueprints/test/manifests/latest", | ||||||
|  |                 json={ | ||||||
|  |                     "layers": [ | ||||||
|  |                         { | ||||||
|  |                             "mediaType": OCI_MEDIA_TYPE, | ||||||
|  |                             "digest": "foo", | ||||||
|  |                         } | ||||||
|  |                     ] | ||||||
|  |                 }, | ||||||
|  |             ) | ||||||
|  |             mocker.get("https://ghcr.io/v2/goauthentik/blueprints/test/blobs/foo", text="foo") | ||||||
|  |  | ||||||
|  |             self.assertEqual( | ||||||
|  |                 BlueprintInstance( | ||||||
|  |                     path="https://ghcr.io/goauthentik/blueprints/test:latest" | ||||||
|  |                 ).retrieve_oci(), | ||||||
|  |                 "foo", | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def test_manifests_error(self): | ||||||
|  |         """Test manifests request erroring""" | ||||||
|  |         with Mocker() as mocker: | ||||||
|  |             mocker.get( | ||||||
|  |                 "https://ghcr.io/v2/goauthentik/blueprints/test/manifests/latest", status_code=401 | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |             with self.assertRaises(BlueprintRetrievalFailed): | ||||||
|  |                 BlueprintInstance( | ||||||
|  |                     path="https://ghcr.io/goauthentik/blueprints/test:latest" | ||||||
|  |                 ).retrieve_oci() | ||||||
|  |  | ||||||
|  |     def test_manifests_error_response(self): | ||||||
|  |         """Test manifests request erroring""" | ||||||
|  |         with Mocker() as mocker: | ||||||
|  |             mocker.get( | ||||||
|  |                 "https://ghcr.io/v2/goauthentik/blueprints/test/manifests/latest", | ||||||
|  |                 json={"errors": ["foo"]}, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |             with self.assertRaises(BlueprintRetrievalFailed): | ||||||
|  |                 BlueprintInstance( | ||||||
|  |                     path="https://ghcr.io/goauthentik/blueprints/test:latest" | ||||||
|  |                 ).retrieve_oci() | ||||||
|  |  | ||||||
|  |     def test_no_matching_blob(self): | ||||||
|  |         """Successful retrieval""" | ||||||
|  |         with Mocker() as mocker: | ||||||
|  |             mocker.get( | ||||||
|  |                 "https://ghcr.io/v2/goauthentik/blueprints/test/manifests/latest", | ||||||
|  |                 json={ | ||||||
|  |                     "layers": [ | ||||||
|  |                         { | ||||||
|  |                             "mediaType": OCI_MEDIA_TYPE + "foo", | ||||||
|  |                             "digest": "foo", | ||||||
|  |                         } | ||||||
|  |                     ] | ||||||
|  |                 }, | ||||||
|  |             ) | ||||||
|  |             with self.assertRaises(BlueprintRetrievalFailed): | ||||||
|  |                 BlueprintInstance( | ||||||
|  |                     path="https://ghcr.io/goauthentik/blueprints/test:latest" | ||||||
|  |                 ).retrieve_oci() | ||||||
|  |  | ||||||
|  |     def test_blob_error(self): | ||||||
|  |         """Successful retrieval""" | ||||||
|  |         with Mocker() as mocker: | ||||||
|  |             mocker.get( | ||||||
|  |                 "https://ghcr.io/v2/goauthentik/blueprints/test/manifests/latest", | ||||||
|  |                 json={ | ||||||
|  |                     "layers": [ | ||||||
|  |                         { | ||||||
|  |                             "mediaType": OCI_MEDIA_TYPE, | ||||||
|  |                             "digest": "foo", | ||||||
|  |                         } | ||||||
|  |                     ] | ||||||
|  |                 }, | ||||||
|  |             ) | ||||||
|  |             mocker.get("https://ghcr.io/v2/goauthentik/blueprints/test/blobs/foo", status_code=401) | ||||||
|  |  | ||||||
|  |             with self.assertRaises(BlueprintRetrievalFailed): | ||||||
|  |                 BlueprintInstance( | ||||||
|  |                     path="https://ghcr.io/goauthentik/blueprints/test:latest" | ||||||
|  |                 ).retrieve_oci() | ||||||
| @ -1,17 +1,16 @@ | |||||||
| """test packaged blueprints""" | """test packaged blueprints""" | ||||||
| from glob import glob |  | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from typing import Callable | from typing import Callable | ||||||
| 
 | 
 | ||||||
| from django.test import TransactionTestCase | from django.test import TransactionTestCase | ||||||
| from django.utils.text import slugify |  | ||||||
| 
 | 
 | ||||||
|  | from authentik.blueprints.models import BlueprintInstance | ||||||
| from authentik.blueprints.tests import apply_blueprint | from authentik.blueprints.tests import apply_blueprint | ||||||
| from authentik.blueprints.v1.importer import Importer | from authentik.blueprints.v1.importer import Importer | ||||||
| from authentik.tenants.models import Tenant | from authentik.tenants.models import Tenant | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class TestBundled(TransactionTestCase): | class TestPackaged(TransactionTestCase): | ||||||
|     """Empty class, test methods are added dynamically""" |     """Empty class, test methods are added dynamically""" | ||||||
| 
 | 
 | ||||||
|     @apply_blueprint("default/90-default-tenant.yaml") |     @apply_blueprint("default/90-default-tenant.yaml") | ||||||
| @ -20,18 +19,20 @@ class TestBundled(TransactionTestCase): | |||||||
|         self.assertTrue(Tenant.objects.filter(domain="authentik-default").exists()) |         self.assertTrue(Tenant.objects.filter(domain="authentik-default").exists()) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def blueprint_tester(file_name: str) -> Callable: | def blueprint_tester(file_name: Path) -> Callable: | ||||||
|     """This is used instead of subTest for better visibility""" |     """This is used instead of subTest for better visibility""" | ||||||
| 
 | 
 | ||||||
|     def tester(self: TestBundled): |     def tester(self: TestPackaged): | ||||||
|         with open(file_name, "r", encoding="utf8") as flow_yaml: |         base = Path("blueprints/") | ||||||
|             importer = Importer(flow_yaml.read()) |         rel_path = Path(file_name).relative_to(base) | ||||||
|  |         importer = Importer(BlueprintInstance(path=str(rel_path)).retrieve()) | ||||||
|         self.assertTrue(importer.validate()[0]) |         self.assertTrue(importer.validate()[0]) | ||||||
|         self.assertTrue(importer.apply()) |         self.assertTrue(importer.apply()) | ||||||
| 
 | 
 | ||||||
|     return tester |     return tester | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| for flow_file in glob("blueprints/**/*.yaml", recursive=True): | for blueprint_file in Path("blueprints/").glob("**/*.yaml"): | ||||||
|     method_name = slugify(Path(flow_file).stem).replace("-", "_").replace(".", "_") |     if "local" in str(blueprint_file): | ||||||
|     setattr(TestBundled, f"test_flow_{method_name}", blueprint_tester(flow_file)) |         continue | ||||||
|  |     setattr(TestPackaged, f"test_blueprint_{blueprint_file}", blueprint_tester(blueprint_file)) | ||||||
| @ -1,7 +1,7 @@ | |||||||
| """Test blueprints v1""" | """Test blueprints v1""" | ||||||
| from django.test import TransactionTestCase | from django.test import TransactionTestCase | ||||||
|  |  | ||||||
| from authentik.blueprints.v1.exporter import Exporter | from authentik.blueprints.v1.exporter import FlowExporter | ||||||
| from authentik.blueprints.v1.importer import Importer, transaction_rollback | from authentik.blueprints.v1.importer import Importer, transaction_rollback | ||||||
| from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding | from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| @ -70,7 +70,7 @@ class TestBlueprintsV1(TransactionTestCase): | |||||||
|                 order=0, |                 order=0, | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|             exporter = Exporter(flow) |             exporter = FlowExporter(flow) | ||||||
|             export = exporter.export() |             export = exporter.export() | ||||||
|             self.assertEqual(len(export.entries), 3) |             self.assertEqual(len(export.entries), 3) | ||||||
|             export_yaml = exporter.export_to_string() |             export_yaml = exporter.export_to_string() | ||||||
| @ -126,7 +126,7 @@ class TestBlueprintsV1(TransactionTestCase): | |||||||
|             fsb = FlowStageBinding.objects.create(target=flow, stage=user_login, order=0) |             fsb = FlowStageBinding.objects.create(target=flow, stage=user_login, order=0) | ||||||
|             PolicyBinding.objects.create(policy=flow_policy, target=fsb, order=0) |             PolicyBinding.objects.create(policy=flow_policy, target=fsb, order=0) | ||||||
|  |  | ||||||
|             exporter = Exporter(flow) |             exporter = FlowExporter(flow) | ||||||
|             export_yaml = exporter.export_to_string() |             export_yaml = exporter.export_to_string() | ||||||
|  |  | ||||||
|         importer = Importer(export_yaml) |         importer = Importer(export_yaml) | ||||||
| @ -169,7 +169,7 @@ class TestBlueprintsV1(TransactionTestCase): | |||||||
|  |  | ||||||
|             FlowStageBinding.objects.create(target=flow, stage=first_stage, order=0) |             FlowStageBinding.objects.create(target=flow, stage=first_stage, order=0) | ||||||
|  |  | ||||||
|             exporter = Exporter(flow) |             exporter = FlowExporter(flow) | ||||||
|             export_yaml = exporter.export_to_string() |             export_yaml = exporter.export_to_string() | ||||||
|  |  | ||||||
|         importer = Importer(export_yaml) |         importer = Importer(export_yaml) | ||||||
|  | |||||||
| @ -1,4 +1,5 @@ | |||||||
| """Test blueprints v1 tasks""" | """Test blueprints v1 tasks""" | ||||||
|  | from hashlib import sha512 | ||||||
| from tempfile import NamedTemporaryFile, mkdtemp | from tempfile import NamedTemporaryFile, mkdtemp | ||||||
|  |  | ||||||
| from django.test import TransactionTestCase | from django.test import TransactionTestCase | ||||||
| @ -36,25 +37,32 @@ class TestBlueprintsV1Tasks(TransactionTestCase): | |||||||
|     @CONFIG.patch("blueprints_dir", TMP) |     @CONFIG.patch("blueprints_dir", TMP) | ||||||
|     def test_valid(self): |     def test_valid(self): | ||||||
|         """Test valid file""" |         """Test valid file""" | ||||||
|  |         blueprint_id = generate_id() | ||||||
|         with NamedTemporaryFile(mode="w+", suffix=".yaml", dir=TMP) as file: |         with NamedTemporaryFile(mode="w+", suffix=".yaml", dir=TMP) as file: | ||||||
|             file.write( |             file.write( | ||||||
|                 dump( |                 dump( | ||||||
|                     { |                     { | ||||||
|                         "version": 1, |                         "version": 1, | ||||||
|                         "entries": [], |                         "entries": [], | ||||||
|  |                         "metadata": { | ||||||
|  |                             "name": blueprint_id, | ||||||
|  |                         }, | ||||||
|                     } |                     } | ||||||
|                 ) |                 ) | ||||||
|             ) |             ) | ||||||
|  |             file.seek(0) | ||||||
|  |             file_hash = sha512(file.read().encode()).hexdigest() | ||||||
|             file.flush() |             file.flush() | ||||||
|             blueprints_discover()  # pylint: disable=no-value-for-parameter |             blueprints_discover()  # pylint: disable=no-value-for-parameter | ||||||
|  |             instance = BlueprintInstance.objects.filter(name=blueprint_id).first() | ||||||
|  |             self.assertEqual(instance.last_applied_hash, file_hash) | ||||||
|             self.assertEqual( |             self.assertEqual( | ||||||
|                 BlueprintInstance.objects.first().last_applied_hash, |                 instance.metadata, | ||||||
|                 ( |                 { | ||||||
|                     "e52bb445b03cd36057258dc9f0ce0fbed8278498ee1470e45315293e5f026d1b" |                     "name": blueprint_id, | ||||||
|                     "d1f9b3526871c0003f5c07be5c3316d9d4a08444bd8fed1b3f03294e51e44522" |                     "labels": {}, | ||||||
|                 ), |                 }, | ||||||
|             ) |             ) | ||||||
|             self.assertEqual(BlueprintInstance.objects.first().metadata, {}) |  | ||||||
|  |  | ||||||
|     @CONFIG.patch("blueprints_dir", TMP) |     @CONFIG.patch("blueprints_dir", TMP) | ||||||
|     def test_valid_updated(self): |     def test_valid_updated(self): | ||||||
|  | |||||||
| @ -27,7 +27,7 @@ def get_attrs(obj: SerializerModel) -> dict[str, Any]: | |||||||
|             continue |             continue | ||||||
|         if _field.read_only: |         if _field.read_only: | ||||||
|             data.pop(field_name, None) |             data.pop(field_name, None) | ||||||
|         if _field.default == data.get(field_name, None): |         if _field.get_initial() == data.get(field_name, None): | ||||||
|             data.pop(field_name, None) |             data.pop(field_name, None) | ||||||
|         if field_name.endswith("_set"): |         if field_name.endswith("_set"): | ||||||
|             data.pop(field_name, None) |             data.pop(field_name, None) | ||||||
| @ -35,21 +35,28 @@ def get_attrs(obj: SerializerModel) -> dict[str, Any]: | |||||||
|  |  | ||||||
|  |  | ||||||
| @dataclass | @dataclass | ||||||
| class BlueprintEntry: | class BlueprintEntryState: | ||||||
|     """Single entry of a bundle""" |     """State of a single instance""" | ||||||
|  |  | ||||||
|  |     instance: Optional[Model] = None | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @dataclass | ||||||
|  | class BlueprintEntry: | ||||||
|  |     """Single entry of a blueprint""" | ||||||
|  |  | ||||||
|     identifiers: dict[str, Any] |  | ||||||
|     model: str |     model: str | ||||||
|  |     identifiers: dict[str, Any] = field(default_factory=dict) | ||||||
|     attrs: Optional[dict[str, Any]] = field(default_factory=dict) |     attrs: Optional[dict[str, Any]] = field(default_factory=dict) | ||||||
|  |  | ||||||
|     # pylint: disable=invalid-name |     # pylint: disable=invalid-name | ||||||
|     id: Optional[str] = None |     id: Optional[str] = None | ||||||
|  |  | ||||||
|     _instance: Optional[Model] = None |     _state: BlueprintEntryState = field(default_factory=BlueprintEntryState) | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def from_model(model: SerializerModel, *extra_identifier_names: str) -> "BlueprintEntry": |     def from_model(model: SerializerModel, *extra_identifier_names: str) -> "BlueprintEntry": | ||||||
|         """Convert a SerializerModel instance to a Bundle Entry""" |         """Convert a SerializerModel instance to a blueprint Entry""" | ||||||
|         identifiers = { |         identifiers = { | ||||||
|             "pk": model.pk, |             "pk": model.pk, | ||||||
|         } |         } | ||||||
| @ -98,9 +105,9 @@ class Blueprint: | |||||||
|  |  | ||||||
|     version: int = field(default=1) |     version: int = field(default=1) | ||||||
|     entries: list[BlueprintEntry] = field(default_factory=list) |     entries: list[BlueprintEntry] = field(default_factory=list) | ||||||
|  |     context: dict = field(default_factory=dict) | ||||||
|  |  | ||||||
|     metadata: Optional[BlueprintMetadata] = field(default=None) |     metadata: Optional[BlueprintMetadata] = field(default=None) | ||||||
|     context: Optional[dict] = field(default_factory=dict) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class YAMLTag: | class YAMLTag: | ||||||
| @ -123,15 +130,15 @@ class KeyOf(YAMLTag): | |||||||
|  |  | ||||||
|     def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any: |     def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any: | ||||||
|         for _entry in blueprint.entries: |         for _entry in blueprint.entries: | ||||||
|             if _entry.id == self.id_from and _entry._instance: |             if _entry.id == self.id_from and _entry._state.instance: | ||||||
|                 # Special handling for PolicyBindingModels, as they'll have a different PK |                 # Special handling for PolicyBindingModels, as they'll have a different PK | ||||||
|                 # which is used when creating policy bindings |                 # which is used when creating policy bindings | ||||||
|                 if ( |                 if ( | ||||||
|                     isinstance(_entry._instance, PolicyBindingModel) |                     isinstance(_entry._state.instance, PolicyBindingModel) | ||||||
|                     and entry.model.lower() == "authentik_policies.policybinding" |                     and entry.model.lower() == "authentik_policies.policybinding" | ||||||
|                 ): |                 ): | ||||||
|                     return _entry._instance.pbm_uuid |                     return _entry._state.instance.pbm_uuid | ||||||
|                 return _entry._instance.pk |                 return _entry._state.instance.pk | ||||||
|         raise ValueError( |         raise ValueError( | ||||||
|             f"KeyOf: failed to find entry with `id` of `{self.id_from}` and a model instance" |             f"KeyOf: failed to find entry with `id` of `{self.id_from}` and a model instance" | ||||||
|         ) |         ) | ||||||
| @ -176,8 +183,6 @@ class Format(YAMLTag): | |||||||
|  |  | ||||||
|     def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any: |     def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any: | ||||||
|         try: |         try: | ||||||
|             print(self.format_string) |  | ||||||
|             print(self.args) |  | ||||||
|             return self.format_string % tuple(self.args) |             return self.format_string % tuple(self.args) | ||||||
|         except TypeError as exc: |         except TypeError as exc: | ||||||
|             raise EntryInvalidError(exc) |             raise EntryInvalidError(exc) | ||||||
| @ -225,7 +230,13 @@ class BlueprintDumper(SafeDumper): | |||||||
|  |  | ||||||
|     def represent(self, data) -> None: |     def represent(self, data) -> None: | ||||||
|         if is_dataclass(data): |         if is_dataclass(data): | ||||||
|             data = asdict(data) |  | ||||||
|  |             def factory(items): | ||||||
|  |                 final_dict = dict(items) | ||||||
|  |                 final_dict.pop("_state", None) | ||||||
|  |                 return final_dict | ||||||
|  |  | ||||||
|  |             data = asdict(data, dict_factory=factory) | ||||||
|         return super().represent(data) |         return super().represent(data) | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -242,3 +253,9 @@ class BlueprintLoader(SafeLoader): | |||||||
|  |  | ||||||
| class EntryInvalidError(SentryIgnoredException): | class EntryInvalidError(SentryIgnoredException): | ||||||
|     """Error raised when an entry is invalid""" |     """Error raised when an entry is invalid""" | ||||||
|  |  | ||||||
|  |     serializer_errors: Optional[dict] | ||||||
|  |  | ||||||
|  |     def __init__(self, *args: object, serializer_errors: Optional[dict] = None) -> None: | ||||||
|  |         super().__init__(*args) | ||||||
|  |         self.serializer_errors = serializer_errors | ||||||
|  | |||||||
| @ -1,11 +1,24 @@ | |||||||
| """Flow exporter""" | """Blueprint exporter""" | ||||||
| from typing import Iterator | from typing import Iterable | ||||||
| from uuid import UUID | from uuid import UUID | ||||||
|  |  | ||||||
| from django.db.models import Q | from django.apps import apps | ||||||
|  | from django.contrib.auth import get_user_model | ||||||
|  | from django.db.models import Model, Q, QuerySet | ||||||
|  | from django.utils.timezone import now | ||||||
|  | from django.utils.translation import gettext as _ | ||||||
|  | from guardian.shortcuts import get_anonymous_user | ||||||
| from yaml import dump | from yaml import dump | ||||||
|  |  | ||||||
| from authentik.blueprints.v1.common import Blueprint, BlueprintDumper, BlueprintEntry | from authentik.blueprints.v1.common import ( | ||||||
|  |     Blueprint, | ||||||
|  |     BlueprintDumper, | ||||||
|  |     BlueprintEntry, | ||||||
|  |     BlueprintMetadata, | ||||||
|  | ) | ||||||
|  | from authentik.blueprints.v1.importer import is_model_allowed | ||||||
|  | from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_GENERATED | ||||||
|  | from authentik.events.models import Event | ||||||
| from authentik.flows.models import Flow, FlowStageBinding, Stage | from authentik.flows.models import Flow, FlowStageBinding, Stage | ||||||
| from authentik.policies.models import Policy, PolicyBinding | from authentik.policies.models import Policy, PolicyBinding | ||||||
| from authentik.stages.prompt.models import PromptStage | from authentik.stages.prompt.models import PromptStage | ||||||
| @ -14,6 +27,55 @@ from authentik.stages.prompt.models import PromptStage | |||||||
| class Exporter: | class Exporter: | ||||||
|     """Export flow with attached stages into yaml""" |     """Export flow with attached stages into yaml""" | ||||||
|  |  | ||||||
|  |     excluded_models: list[type[Model]] = [] | ||||||
|  |  | ||||||
|  |     def __init__(self): | ||||||
|  |         self.excluded_models = [ | ||||||
|  |             Event, | ||||||
|  |         ] | ||||||
|  |  | ||||||
|  |     def get_entries(self) -> Iterable[BlueprintEntry]: | ||||||
|  |         """Get blueprint entries""" | ||||||
|  |         for model in apps.get_models(): | ||||||
|  |             if not is_model_allowed(model): | ||||||
|  |                 continue | ||||||
|  |             if model in self.excluded_models: | ||||||
|  |                 continue | ||||||
|  |             for obj in self.get_model_instances(model): | ||||||
|  |                 yield BlueprintEntry.from_model(obj) | ||||||
|  |  | ||||||
|  |     def get_model_instances(self, model: type[Model]) -> QuerySet: | ||||||
|  |         """Return a queryset for `model`. Can be used to filter some | ||||||
|  |         objects on some models""" | ||||||
|  |         if model == get_user_model(): | ||||||
|  |             return model.objects.exclude(pk=get_anonymous_user().pk) | ||||||
|  |         return model.objects.all() | ||||||
|  |  | ||||||
|  |     def _pre_export(self, blueprint: Blueprint): | ||||||
|  |         """Hook to run anything pre-export""" | ||||||
|  |  | ||||||
|  |     def export(self) -> Blueprint: | ||||||
|  |         """Create a list of all objects and create a blueprint""" | ||||||
|  |         blueprint = Blueprint() | ||||||
|  |         self._pre_export(blueprint) | ||||||
|  |         blueprint.metadata = BlueprintMetadata( | ||||||
|  |             name=_("authentik Export - %(date)s" % {"date": str(now())}), | ||||||
|  |             labels={ | ||||||
|  |                 LABEL_AUTHENTIK_GENERATED: "true", | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|  |         blueprint.entries = list(self.get_entries()) | ||||||
|  |         return blueprint | ||||||
|  |  | ||||||
|  |     def export_to_string(self) -> str: | ||||||
|  |         """Call export and convert it to yaml""" | ||||||
|  |         blueprint = self.export() | ||||||
|  |         return dump(blueprint, Dumper=BlueprintDumper) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class FlowExporter(Exporter): | ||||||
|  |     """Exporter customised to only return objects related to `flow`""" | ||||||
|  |  | ||||||
|     flow: Flow |     flow: Flow | ||||||
|     with_policies: bool |     with_policies: bool | ||||||
|     with_stage_prompts: bool |     with_stage_prompts: bool | ||||||
| @ -21,17 +83,20 @@ class Exporter: | |||||||
|     pbm_uuids: list[UUID] |     pbm_uuids: list[UUID] | ||||||
|  |  | ||||||
|     def __init__(self, flow: Flow): |     def __init__(self, flow: Flow): | ||||||
|  |         super().__init__() | ||||||
|         self.flow = flow |         self.flow = flow | ||||||
|         self.with_policies = True |         self.with_policies = True | ||||||
|         self.with_stage_prompts = True |         self.with_stage_prompts = True | ||||||
|  |  | ||||||
|     def _prepare_pbm(self): |     def _pre_export(self, blueprint: Blueprint): | ||||||
|  |         if not self.with_policies: | ||||||
|  |             return | ||||||
|         self.pbm_uuids = [self.flow.pbm_uuid] |         self.pbm_uuids = [self.flow.pbm_uuid] | ||||||
|         self.pbm_uuids += FlowStageBinding.objects.filter(target=self.flow).values_list( |         self.pbm_uuids += FlowStageBinding.objects.filter(target=self.flow).values_list( | ||||||
|             "pbm_uuid", flat=True |             "pbm_uuid", flat=True | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def walk_stages(self) -> Iterator[BlueprintEntry]: |     def walk_stages(self) -> Iterable[BlueprintEntry]: | ||||||
|         """Convert all stages attached to self.flow into BlueprintEntry objects""" |         """Convert all stages attached to self.flow into BlueprintEntry objects""" | ||||||
|         stages = Stage.objects.filter(flow=self.flow).select_related().select_subclasses() |         stages = Stage.objects.filter(flow=self.flow).select_related().select_subclasses() | ||||||
|         for stage in stages: |         for stage in stages: | ||||||
| @ -39,13 +104,13 @@ class Exporter: | |||||||
|                 pass |                 pass | ||||||
|             yield BlueprintEntry.from_model(stage, "name") |             yield BlueprintEntry.from_model(stage, "name") | ||||||
|  |  | ||||||
|     def walk_stage_bindings(self) -> Iterator[BlueprintEntry]: |     def walk_stage_bindings(self) -> Iterable[BlueprintEntry]: | ||||||
|         """Convert all bindings attached to self.flow into BlueprintEntry objects""" |         """Convert all bindings attached to self.flow into BlueprintEntry objects""" | ||||||
|         bindings = FlowStageBinding.objects.filter(target=self.flow).select_related() |         bindings = FlowStageBinding.objects.filter(target=self.flow).select_related() | ||||||
|         for binding in bindings: |         for binding in bindings: | ||||||
|             yield BlueprintEntry.from_model(binding, "target", "stage", "order") |             yield BlueprintEntry.from_model(binding, "target", "stage", "order") | ||||||
|  |  | ||||||
|     def walk_policies(self) -> Iterator[BlueprintEntry]: |     def walk_policies(self) -> Iterable[BlueprintEntry]: | ||||||
|         """Walk over all policies. This is done at the beginning of the export for stages that have |         """Walk over all policies. This is done at the beginning of the export for stages that have | ||||||
|         a direct foreign key to a policy.""" |         a direct foreign key to a policy.""" | ||||||
|         # Special case for PromptStage as that has a direct M2M to policy, we have to ensure |         # Special case for PromptStage as that has a direct M2M to policy, we have to ensure | ||||||
| @ -56,37 +121,29 @@ class Exporter: | |||||||
|         for policy in policies: |         for policy in policies: | ||||||
|             yield BlueprintEntry.from_model(policy) |             yield BlueprintEntry.from_model(policy) | ||||||
|  |  | ||||||
|     def walk_policy_bindings(self) -> Iterator[BlueprintEntry]: |     def walk_policy_bindings(self) -> Iterable[BlueprintEntry]: | ||||||
|         """Walk over all policybindings relative to us. This is run at the end of the export, as |         """Walk over all policybindings relative to us. This is run at the end of the export, as | ||||||
|         we are sure all objects exist now.""" |         we are sure all objects exist now.""" | ||||||
|         bindings = PolicyBinding.objects.filter(target__in=self.pbm_uuids).select_related() |         bindings = PolicyBinding.objects.filter(target__in=self.pbm_uuids).select_related() | ||||||
|         for binding in bindings: |         for binding in bindings: | ||||||
|             yield BlueprintEntry.from_model(binding, "policy", "target", "order") |             yield BlueprintEntry.from_model(binding, "policy", "target", "order") | ||||||
|  |  | ||||||
|     def walk_stage_prompts(self) -> Iterator[BlueprintEntry]: |     def walk_stage_prompts(self) -> Iterable[BlueprintEntry]: | ||||||
|         """Walk over all prompts associated with any PromptStages""" |         """Walk over all prompts associated with any PromptStages""" | ||||||
|         prompt_stages = PromptStage.objects.filter(flow=self.flow) |         prompt_stages = PromptStage.objects.filter(flow=self.flow) | ||||||
|         for stage in prompt_stages: |         for stage in prompt_stages: | ||||||
|             for prompt in stage.fields.all(): |             for prompt in stage.fields.all(): | ||||||
|                 yield BlueprintEntry.from_model(prompt) |                 yield BlueprintEntry.from_model(prompt) | ||||||
|  |  | ||||||
|     def export(self) -> Blueprint: |     def get_entries(self) -> Iterable[BlueprintEntry]: | ||||||
|         """Create a list of all objects including the flow""" |         entries = [] | ||||||
|         if self.with_policies: |         entries.append(BlueprintEntry.from_model(self.flow, "slug")) | ||||||
|             self._prepare_pbm() |  | ||||||
|         bundle = Blueprint() |  | ||||||
|         bundle.entries.append(BlueprintEntry.from_model(self.flow, "slug")) |  | ||||||
|         if self.with_stage_prompts: |         if self.with_stage_prompts: | ||||||
|             bundle.entries.extend(self.walk_stage_prompts()) |             entries.extend(self.walk_stage_prompts()) | ||||||
|         if self.with_policies: |         if self.with_policies: | ||||||
|             bundle.entries.extend(self.walk_policies()) |             entries.extend(self.walk_policies()) | ||||||
|         bundle.entries.extend(self.walk_stages()) |         entries.extend(self.walk_stages()) | ||||||
|         bundle.entries.extend(self.walk_stage_bindings()) |         entries.extend(self.walk_stage_bindings()) | ||||||
|         if self.with_policies: |         if self.with_policies: | ||||||
|             bundle.entries.extend(self.walk_policy_bindings()) |             entries.extend(self.walk_policy_bindings()) | ||||||
|         return bundle |         return entries | ||||||
|  |  | ||||||
|     def export_to_string(self) -> str: |  | ||||||
|         """Call export and convert it to yaml""" |  | ||||||
|         bundle = self.export() |  | ||||||
|         return dump(bundle, Dumper=BlueprintDumper) |  | ||||||
|  | |||||||
| @ -3,10 +3,9 @@ from contextlib import contextmanager | |||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from typing import Any, Optional | from typing import Any, Optional | ||||||
|  |  | ||||||
| from dacite import from_dict | from dacite.core import from_dict | ||||||
| from dacite.exceptions import DaciteError | from dacite.exceptions import DaciteError | ||||||
| from deepmerge import always_merger | from deepmerge import always_merger | ||||||
| from django.apps import apps |  | ||||||
| from django.db import transaction | from django.db import transaction | ||||||
| from django.db.models import Model | from django.db.models import Model | ||||||
| from django.db.models.query_utils import Q | from django.db.models.query_utils import Q | ||||||
| @ -21,9 +20,11 @@ from yaml import load | |||||||
| from authentik.blueprints.v1.common import ( | from authentik.blueprints.v1.common import ( | ||||||
|     Blueprint, |     Blueprint, | ||||||
|     BlueprintEntry, |     BlueprintEntry, | ||||||
|  |     BlueprintEntryState, | ||||||
|     BlueprintLoader, |     BlueprintLoader, | ||||||
|     EntryInvalidError, |     EntryInvalidError, | ||||||
| ) | ) | ||||||
|  | from authentik.blueprints.v1.meta.registry import BaseMetaModel, registry | ||||||
| from authentik.core.models import ( | from authentik.core.models import ( | ||||||
|     AuthenticatedSession, |     AuthenticatedSession, | ||||||
|     PropertyMapping, |     PropertyMapping, | ||||||
| @ -58,7 +59,7 @@ def is_model_allowed(model: type[Model]) -> bool: | |||||||
|         # Classes that have other dependencies |         # Classes that have other dependencies | ||||||
|         AuthenticatedSession, |         AuthenticatedSession, | ||||||
|     ) |     ) | ||||||
|     return model not in excluded_models |     return model not in excluded_models and issubclass(model, (SerializerModel, BaseMetaModel)) | ||||||
|  |  | ||||||
|  |  | ||||||
| @contextmanager | @contextmanager | ||||||
| @ -137,10 +138,20 @@ class Importer: | |||||||
|     def _validate_single(self, entry: BlueprintEntry) -> BaseSerializer: |     def _validate_single(self, entry: BlueprintEntry) -> BaseSerializer: | ||||||
|         """Validate a single entry""" |         """Validate a single entry""" | ||||||
|         model_app_label, model_name = entry.model.split(".") |         model_app_label, model_name = entry.model.split(".") | ||||||
|         model: type[SerializerModel] = apps.get_model(model_app_label, model_name) |         model: type[SerializerModel] = registry.get_model(model_app_label, model_name) | ||||||
|         # Don't use isinstance since we don't want to check for inheritance |         # Don't use isinstance since we don't want to check for inheritance | ||||||
|         if not is_model_allowed(model): |         if not is_model_allowed(model): | ||||||
|             raise EntryInvalidError(f"Model {model} not allowed") |             raise EntryInvalidError(f"Model {model} not allowed") | ||||||
|  |         if issubclass(model, BaseMetaModel): | ||||||
|  |             serializer_class: type[Serializer] = model.serializer() | ||||||
|  |             serializer = serializer_class(data=entry.get_attrs(self.__import)) | ||||||
|  |             try: | ||||||
|  |                 serializer.is_valid(raise_exception=True) | ||||||
|  |             except ValidationError as exc: | ||||||
|  |                 raise EntryInvalidError( | ||||||
|  |                     f"Serializer errors {serializer.errors}", serializer_errors=serializer.errors | ||||||
|  |                 ) from exc | ||||||
|  |             return serializer | ||||||
|         if entry.identifiers == {}: |         if entry.identifiers == {}: | ||||||
|             raise EntryInvalidError("No identifiers") |             raise EntryInvalidError("No identifiers") | ||||||
|  |  | ||||||
| @ -157,7 +168,7 @@ class Importer: | |||||||
|         existing_models = model.objects.filter(self.__query_from_identifier(updated_identifiers)) |         existing_models = model.objects.filter(self.__query_from_identifier(updated_identifiers)) | ||||||
|  |  | ||||||
|         serializer_kwargs = {} |         serializer_kwargs = {} | ||||||
|         if existing_models.exists(): |         if not isinstance(model(), BaseMetaModel) and existing_models.exists(): | ||||||
|             model_instance = existing_models.first() |             model_instance = existing_models.first() | ||||||
|             self.logger.debug( |             self.logger.debug( | ||||||
|                 "initialise serializer with instance", |                 "initialise serializer with instance", | ||||||
| @ -168,7 +179,9 @@ class Importer: | |||||||
|             serializer_kwargs["instance"] = model_instance |             serializer_kwargs["instance"] = model_instance | ||||||
|             serializer_kwargs["partial"] = True |             serializer_kwargs["partial"] = True | ||||||
|         else: |         else: | ||||||
|             self.logger.debug("initialise new instance", model=model, **updated_identifiers) |             self.logger.debug( | ||||||
|  |                 "initialised new serializer instance", model=model, **updated_identifiers | ||||||
|  |             ) | ||||||
|             model_instance = model() |             model_instance = model() | ||||||
|             # pk needs to be set on the model instance otherwise a new one will be generated |             # pk needs to be set on the model instance otherwise a new one will be generated | ||||||
|             if "pk" in updated_identifiers: |             if "pk" in updated_identifiers: | ||||||
| @ -182,7 +195,9 @@ class Importer: | |||||||
|         try: |         try: | ||||||
|             serializer.is_valid(raise_exception=True) |             serializer.is_valid(raise_exception=True) | ||||||
|         except ValidationError as exc: |         except ValidationError as exc: | ||||||
|             raise EntryInvalidError(f"Serializer errors {serializer.errors}") from exc |             raise EntryInvalidError( | ||||||
|  |                 f"Serializer errors {serializer.errors}", serializer_errors=serializer.errors | ||||||
|  |             ) from exc | ||||||
|         return serializer |         return serializer | ||||||
|  |  | ||||||
|     def apply(self) -> bool: |     def apply(self) -> bool: | ||||||
| @ -204,7 +219,7 @@ class Importer: | |||||||
|         for entry in self.__import.entries: |         for entry in self.__import.entries: | ||||||
|             model_app_label, model_name = entry.model.split(".") |             model_app_label, model_name = entry.model.split(".") | ||||||
|             try: |             try: | ||||||
|                 model: SerializerModel = apps.get_model(model_app_label, model_name) |                 model: type[SerializerModel] = registry.get_model(model_app_label, model_name) | ||||||
|             except LookupError: |             except LookupError: | ||||||
|                 self.logger.warning( |                 self.logger.warning( | ||||||
|                     "app or model does not exist", app=model_app_label, model=model_name |                     "app or model does not exist", app=model_app_label, model=model_name | ||||||
| @ -214,14 +229,14 @@ class Importer: | |||||||
|             try: |             try: | ||||||
|                 serializer = self._validate_single(entry) |                 serializer = self._validate_single(entry) | ||||||
|             except EntryInvalidError as exc: |             except EntryInvalidError as exc: | ||||||
|                 self.logger.warning("entry invalid", entry=entry, error=exc) |                 self.logger.warning(f"entry invalid: {exc}", entry=entry, error=exc) | ||||||
|                 return False |                 return False | ||||||
|  |  | ||||||
|             model = serializer.save() |             model = serializer.save() | ||||||
|             if "pk" in entry.identifiers: |             if "pk" in entry.identifiers: | ||||||
|                 self.__pk_map[entry.identifiers["pk"]] = model.pk |                 self.__pk_map[entry.identifiers["pk"]] = model.pk | ||||||
|             entry._instance = model |             entry._state = BlueprintEntryState(model) | ||||||
|             self.logger.debug("updated model", model=model, pk=model.pk) |             self.logger.debug("updated model", model=model) | ||||||
|         return True |         return True | ||||||
|  |  | ||||||
|     def validate(self) -> tuple[bool, list[EventDict]]: |     def validate(self) -> tuple[bool, list[EventDict]]: | ||||||
| @ -230,7 +245,7 @@ class Importer: | |||||||
|         self.logger.debug("Starting blueprint import validation") |         self.logger.debug("Starting blueprint import validation") | ||||||
|         orig_import = deepcopy(self.__import) |         orig_import = deepcopy(self.__import) | ||||||
|         if self.__import.version != 1: |         if self.__import.version != 1: | ||||||
|             self.logger.warning("Invalid bundle version") |             self.logger.warning("Invalid blueprint version") | ||||||
|             return False, [] |             return False, [] | ||||||
|         with ( |         with ( | ||||||
|             transaction_rollback(), |             transaction_rollback(), | ||||||
| @ -238,8 +253,8 @@ class Importer: | |||||||
|         ): |         ): | ||||||
|             successful = self._apply_models() |             successful = self._apply_models() | ||||||
|             if not successful: |             if not successful: | ||||||
|                 self.logger.debug("blueprint validation failed") |                 self.logger.debug("Blueprint validation failed") | ||||||
|         for log in logs: |         for log in logs: | ||||||
|             self.logger.debug(**log) |             getattr(self.logger, log.get("log_level"))(**log) | ||||||
|         self.__import = orig_import |         self.__import = orig_import | ||||||
|         return successful, logs |         return successful, logs | ||||||
|  | |||||||
| @ -2,3 +2,4 @@ | |||||||
|  |  | ||||||
| LABEL_AUTHENTIK_SYSTEM = "blueprints.goauthentik.io/system" | LABEL_AUTHENTIK_SYSTEM = "blueprints.goauthentik.io/system" | ||||||
| LABEL_AUTHENTIK_INSTANTIATE = "blueprints.goauthentik.io/instantiate" | LABEL_AUTHENTIK_INSTANTIATE = "blueprints.goauthentik.io/instantiate" | ||||||
|  | LABEL_AUTHENTIK_GENERATED = "blueprints.goauthentik.io/generated" | ||||||
|  | |||||||
							
								
								
									
										0
									
								
								authentik/blueprints/v1/meta/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								authentik/blueprints/v1/meta/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										60
									
								
								authentik/blueprints/v1/meta/apply_blueprint.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								authentik/blueprints/v1/meta/apply_blueprint.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,60 @@ | |||||||
|  | """Apply Blueprint meta model""" | ||||||
|  | from typing import TYPE_CHECKING | ||||||
|  |  | ||||||
|  | from rest_framework.exceptions import ValidationError | ||||||
|  | from rest_framework.fields import BooleanField, JSONField | ||||||
|  | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
|  | from authentik.blueprints.v1.meta.registry import BaseMetaModel, MetaResult, registry | ||||||
|  | from authentik.core.api.utils import PassiveSerializer, is_dict | ||||||
|  |  | ||||||
|  | if TYPE_CHECKING: | ||||||
|  |     from authentik.blueprints.models import BlueprintInstance | ||||||
|  |  | ||||||
|  | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ApplyBlueprintMetaSerializer(PassiveSerializer): | ||||||
|  |     """Serializer for meta apply blueprint model""" | ||||||
|  |  | ||||||
|  |     identifiers = JSONField(validators=[is_dict]) | ||||||
|  |     required = BooleanField(default=True) | ||||||
|  |  | ||||||
|  |     # We cannot override `instance` as that will confuse rest_framework | ||||||
|  |     # and make it attempt to update the instance | ||||||
|  |     blueprint_instance: "BlueprintInstance" | ||||||
|  |  | ||||||
|  |     def validate(self, attrs): | ||||||
|  |         from authentik.blueprints.models import BlueprintInstance | ||||||
|  |  | ||||||
|  |         identifiers = attrs["identifiers"] | ||||||
|  |         required = attrs["required"] | ||||||
|  |         instance = BlueprintInstance.objects.filter(**identifiers).first() | ||||||
|  |         if not instance and required: | ||||||
|  |             raise ValidationError("Required blueprint does not exist") | ||||||
|  |         self.blueprint_instance = instance | ||||||
|  |         return super().validate(attrs) | ||||||
|  |  | ||||||
|  |     def create(self, validated_data: dict) -> MetaResult: | ||||||
|  |         from authentik.blueprints.v1.tasks import apply_blueprint | ||||||
|  |  | ||||||
|  |         if not self.blueprint_instance: | ||||||
|  |             LOGGER.info("Blueprint does not exist, but not required") | ||||||
|  |             return MetaResult() | ||||||
|  |         LOGGER.debug("Applying blueprint from meta model", blueprint=self.blueprint_instance) | ||||||
|  |         # pylint: disable=no-value-for-parameter | ||||||
|  |         apply_blueprint(str(self.blueprint_instance.pk)) | ||||||
|  |         return MetaResult() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @registry.register("metaapplyblueprint") | ||||||
|  | class MetaApplyBlueprint(BaseMetaModel): | ||||||
|  |     """Meta model to apply another blueprint""" | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def serializer() -> ApplyBlueprintMetaSerializer: | ||||||
|  |         return ApplyBlueprintMetaSerializer | ||||||
|  |  | ||||||
|  |     class Meta: | ||||||
|  |  | ||||||
|  |         abstract = True | ||||||
							
								
								
									
										61
									
								
								authentik/blueprints/v1/meta/registry.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								authentik/blueprints/v1/meta/registry.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,61 @@ | |||||||
|  | """Base models""" | ||||||
|  | from django.apps import apps | ||||||
|  | from django.db.models import Model | ||||||
|  | from rest_framework.serializers import Serializer | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class BaseMetaModel(Model): | ||||||
|  |     """Base models""" | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def serializer() -> Serializer: | ||||||
|  |         """Serializer similar to SerializerModel, but as a static method since | ||||||
|  |         this is an abstract model""" | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     class Meta: | ||||||
|  |  | ||||||
|  |         abstract = True | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class MetaResult: | ||||||
|  |     """Result returned by Meta Models' serializers. Empty class but we can't return none as | ||||||
|  |     the framework doesn't allow that""" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class MetaModelRegistry: | ||||||
|  |     """Registry for pseudo meta models""" | ||||||
|  |  | ||||||
|  |     models: dict[str, BaseMetaModel] | ||||||
|  |     virtual_prefix: str | ||||||
|  |  | ||||||
|  |     def __init__(self, prefix: str) -> None: | ||||||
|  |         self.models = {} | ||||||
|  |         self.virtual_prefix = prefix | ||||||
|  |  | ||||||
|  |     def register(self, model_id: str): | ||||||
|  |         """Register model class under `model_id`""" | ||||||
|  |  | ||||||
|  |         def inner_wrapper(cls): | ||||||
|  |             self.models[model_id] = cls | ||||||
|  |             return cls | ||||||
|  |  | ||||||
|  |         return inner_wrapper | ||||||
|  |  | ||||||
|  |     def get_models(self): | ||||||
|  |         """Wrapper for django's `get_models` to list all models""" | ||||||
|  |         models = apps.get_models() | ||||||
|  |         for _, value in self.models.items(): | ||||||
|  |             models.append(value) | ||||||
|  |         return models | ||||||
|  |  | ||||||
|  |     def get_model(self, app_label: str, model_id: str) -> type[Model]: | ||||||
|  |         """Get model checks if any virtual models are registered, and falls back | ||||||
|  |         to actual django models""" | ||||||
|  |         if app_label.lower() == self.virtual_prefix: | ||||||
|  |             if model_id.lower() in self.models: | ||||||
|  |                 return self.models[model_id] | ||||||
|  |         return apps.get_model(app_label, model_id) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | registry = MetaModelRegistry("authentik_blueprints") | ||||||
| @ -4,15 +4,21 @@ from hashlib import sha512 | |||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from typing import Optional | from typing import Optional | ||||||
|  |  | ||||||
| from dacite import from_dict | from dacite.core import from_dict | ||||||
| from django.db import DatabaseError, InternalError, ProgrammingError | from django.db import DatabaseError, InternalError, ProgrammingError | ||||||
|  | from django.utils.text import slugify | ||||||
| from django.utils.timezone import now | from django.utils.timezone import now | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
|  | from structlog.stdlib import get_logger | ||||||
| from yaml import load | from yaml import load | ||||||
| from yaml.error import YAMLError | from yaml.error import YAMLError | ||||||
|  |  | ||||||
| from authentik.blueprints.models import BlueprintInstance, BlueprintInstanceStatus | from authentik.blueprints.models import ( | ||||||
| from authentik.blueprints.v1.common import BlueprintLoader, BlueprintMetadata |     BlueprintInstance, | ||||||
|  |     BlueprintInstanceStatus, | ||||||
|  |     BlueprintRetrievalFailed, | ||||||
|  | ) | ||||||
|  | from authentik.blueprints.v1.common import BlueprintLoader, BlueprintMetadata, EntryInvalidError | ||||||
| from authentik.blueprints.v1.importer import Importer | from authentik.blueprints.v1.importer import Importer | ||||||
| from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_INSTANTIATE | from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_INSTANTIATE | ||||||
| from authentik.events.monitored_tasks import ( | from authentik.events.monitored_tasks import ( | ||||||
| @ -21,9 +27,12 @@ from authentik.events.monitored_tasks import ( | |||||||
|     TaskResultStatus, |     TaskResultStatus, | ||||||
|     prefill_task, |     prefill_task, | ||||||
| ) | ) | ||||||
|  | from authentik.events.utils import sanitize_dict | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
| from authentik.root.celery import CELERY_APP | from authentik.root.celery import CELERY_APP | ||||||
|  |  | ||||||
|  | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
| @dataclass | @dataclass | ||||||
| class BlueprintFile: | class BlueprintFile: | ||||||
| @ -39,27 +48,45 @@ class BlueprintFile: | |||||||
| @CELERY_APP.task( | @CELERY_APP.task( | ||||||
|     throws=(DatabaseError, ProgrammingError, InternalError), |     throws=(DatabaseError, ProgrammingError, InternalError), | ||||||
| ) | ) | ||||||
|  | def blueprints_find_dict(): | ||||||
|  |     """Find blueprints as `blueprints_find` does, but return a safe dict""" | ||||||
|  |     blueprints = [] | ||||||
|  |     for blueprint in blueprints_find(): | ||||||
|  |         blueprints.append(sanitize_dict(asdict(blueprint))) | ||||||
|  |     return blueprints | ||||||
|  |  | ||||||
|  |  | ||||||
| def blueprints_find(): | def blueprints_find(): | ||||||
|     """Find blueprints and return valid ones""" |     """Find blueprints and return valid ones""" | ||||||
|     blueprints = [] |     blueprints = [] | ||||||
|     root = Path(CONFIG.y("blueprints_dir")) |     root = Path(CONFIG.y("blueprints_dir")) | ||||||
|     for file in root.glob("**/*.yaml"): |     for file in root.glob("**/*.yaml"): | ||||||
|         path = Path(file) |         path = Path(file) | ||||||
|  |         LOGGER.debug("found blueprint", path=str(path)) | ||||||
|         with open(path, "r", encoding="utf-8") as blueprint_file: |         with open(path, "r", encoding="utf-8") as blueprint_file: | ||||||
|             try: |             try: | ||||||
|                 raw_blueprint = load(blueprint_file.read(), BlueprintLoader) |                 raw_blueprint = load(blueprint_file.read(), BlueprintLoader) | ||||||
|             except YAMLError: |             except YAMLError as exc: | ||||||
|                 raw_blueprint = None |                 raw_blueprint = None | ||||||
|  |                 LOGGER.warning("failed to parse blueprint", exc=exc, path=str(path)) | ||||||
|             if not raw_blueprint: |             if not raw_blueprint: | ||||||
|                 continue |                 continue | ||||||
|             metadata = raw_blueprint.get("metadata", None) |             metadata = raw_blueprint.get("metadata", None) | ||||||
|             version = raw_blueprint.get("version", 1) |             version = raw_blueprint.get("version", 1) | ||||||
|             if version != 1: |             if version != 1: | ||||||
|  |                 LOGGER.warning("invalid blueprint version", version=version, path=str(path)) | ||||||
|                 continue |                 continue | ||||||
|         file_hash = sha512(path.read_bytes()).hexdigest() |         file_hash = sha512(path.read_bytes()).hexdigest() | ||||||
|         blueprint = BlueprintFile(path.relative_to(root), version, file_hash, path.stat().st_mtime) |         blueprint = BlueprintFile( | ||||||
|  |             str(path.relative_to(root)), version, file_hash, int(path.stat().st_mtime) | ||||||
|  |         ) | ||||||
|         blueprint.meta = from_dict(BlueprintMetadata, metadata) if metadata else None |         blueprint.meta = from_dict(BlueprintMetadata, metadata) if metadata else None | ||||||
|         blueprints.append(blueprint) |         blueprints.append(blueprint) | ||||||
|  |         LOGGER.info( | ||||||
|  |             "parsed & loaded blueprint", | ||||||
|  |             hash=file_hash, | ||||||
|  |             path=str(path), | ||||||
|  |         ) | ||||||
|     return blueprints |     return blueprints | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -101,9 +128,7 @@ def check_blueprint_v1_file(blueprint: BlueprintFile): | |||||||
|         ) |         ) | ||||||
|         instance.save() |         instance.save() | ||||||
|     if instance.last_applied_hash != blueprint.hash: |     if instance.last_applied_hash != blueprint.hash: | ||||||
|         instance.metadata = asdict(blueprint.meta) if blueprint.meta else {} |         apply_blueprint.delay(str(instance.pk)) | ||||||
|         instance.save() |  | ||||||
|         apply_blueprint.delay(instance.pk.hex) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task( | @CELERY_APP.task( | ||||||
| @ -112,16 +137,18 @@ def check_blueprint_v1_file(blueprint: BlueprintFile): | |||||||
| ) | ) | ||||||
| def apply_blueprint(self: MonitoredTask, instance_pk: str): | def apply_blueprint(self: MonitoredTask, instance_pk: str): | ||||||
|     """Apply single blueprint""" |     """Apply single blueprint""" | ||||||
|     self.set_uid(instance_pk) |  | ||||||
|     self.save_on_success = False |     self.save_on_success = False | ||||||
|  |     instance: Optional[BlueprintInstance] = None | ||||||
|     try: |     try: | ||||||
|         instance: BlueprintInstance = BlueprintInstance.objects.filter(pk=instance_pk).first() |         instance: BlueprintInstance = BlueprintInstance.objects.filter(pk=instance_pk).first() | ||||||
|  |         self.set_uid(slugify(instance.name)) | ||||||
|         if not instance or not instance.enabled: |         if not instance or not instance.enabled: | ||||||
|             return |             return | ||||||
|         full_path = Path(CONFIG.y("blueprints_dir")).joinpath(Path(instance.path)) |         blueprint_content = instance.retrieve() | ||||||
|         file_hash = sha512(full_path.read_bytes()).hexdigest() |         file_hash = sha512(blueprint_content.encode()).hexdigest() | ||||||
|         with open(full_path, "r", encoding="utf-8") as blueprint_file: |         importer = Importer(blueprint_content, instance.context) | ||||||
|             importer = Importer(blueprint_file.read(), instance.context) |         if importer.blueprint.metadata: | ||||||
|  |             instance.metadata = asdict(importer.blueprint.metadata) | ||||||
|         valid, logs = importer.validate() |         valid, logs = importer.validate() | ||||||
|         if not valid: |         if not valid: | ||||||
|             instance.status = BlueprintInstanceStatus.ERROR |             instance.status = BlueprintInstanceStatus.ERROR | ||||||
| @ -137,9 +164,18 @@ def apply_blueprint(self: MonitoredTask, instance_pk: str): | |||||||
|         instance.status = BlueprintInstanceStatus.SUCCESSFUL |         instance.status = BlueprintInstanceStatus.SUCCESSFUL | ||||||
|         instance.last_applied_hash = file_hash |         instance.last_applied_hash = file_hash | ||||||
|         instance.last_applied = now() |         instance.last_applied = now() | ||||||
|         instance.save() |  | ||||||
|         self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL)) |         self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL)) | ||||||
|     except (DatabaseError, ProgrammingError, InternalError, IOError) as exc: |     except ( | ||||||
|  |         DatabaseError, | ||||||
|  |         ProgrammingError, | ||||||
|  |         InternalError, | ||||||
|  |         IOError, | ||||||
|  |         BlueprintRetrievalFailed, | ||||||
|  |         EntryInvalidError, | ||||||
|  |     ) as exc: | ||||||
|  |         if instance: | ||||||
|             instance.status = BlueprintInstanceStatus.ERROR |             instance.status = BlueprintInstanceStatus.ERROR | ||||||
|         instance.save() |  | ||||||
|         self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc)) |         self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc)) | ||||||
|  |     finally: | ||||||
|  |         if instance: | ||||||
|  |             instance.save() | ||||||
|  | |||||||
| @ -50,6 +50,8 @@ class ApplicationSerializer(ModelSerializer): | |||||||
|  |  | ||||||
|     def get_launch_url(self, app: Application) -> Optional[str]: |     def get_launch_url(self, app: Application) -> Optional[str]: | ||||||
|         """Allow formatting of launch URL""" |         """Allow formatting of launch URL""" | ||||||
|  |         user = None | ||||||
|  |         if "request" in self.context: | ||||||
|             user = self.context["request"].user |             user = self.context["request"].user | ||||||
|         return app.get_launch_url(user) |         return app.get_launch_url(user) | ||||||
|  |  | ||||||
|  | |||||||
| @ -17,7 +17,7 @@ from authentik.api.decorators import permission_required | |||||||
| from authentik.blueprints.api import ManagedSerializer | from authentik.blueprints.api import ManagedSerializer | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.core.api.utils import MetaNameSerializer, PassiveSerializer, TypeCreateSerializer | from authentik.core.api.utils import MetaNameSerializer, PassiveSerializer, TypeCreateSerializer | ||||||
| from authentik.core.expression import PropertyMappingEvaluator | from authentik.core.expression.evaluator import PropertyMappingEvaluator | ||||||
| from authentik.core.models import PropertyMapping | from authentik.core.models import PropertyMapping | ||||||
| from authentik.lib.utils.reflection import all_subclasses | from authentik.lib.utils.reflection import all_subclasses | ||||||
| from authentik.policies.api.exec import PolicyTestSerializer | from authentik.policies.api.exec import PolicyTestSerializer | ||||||
| @ -41,7 +41,9 @@ class PropertyMappingSerializer(ManagedSerializer, ModelSerializer, MetaNameSeri | |||||||
|  |  | ||||||
|     def validate_expression(self, expression: str) -> str: |     def validate_expression(self, expression: str) -> str: | ||||||
|         """Test Syntax""" |         """Test Syntax""" | ||||||
|         evaluator = PropertyMappingEvaluator() |         evaluator = PropertyMappingEvaluator( | ||||||
|  |             self.instance, | ||||||
|  |         ) | ||||||
|         evaluator.validate(expression) |         evaluator.validate(expression) | ||||||
|         return expression |         return expression | ||||||
|  |  | ||||||
|  | |||||||
| @ -1,7 +1,7 @@ | |||||||
| """authentik core app config""" | """authentik core app config""" | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
|  |  | ||||||
| from authentik.blueprints.manager import ManagedAppConfig | from authentik.blueprints.apps import ManagedAppConfig | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuthentikCoreConfig(ManagedAppConfig): | class AuthentikCoreConfig(ManagedAppConfig): | ||||||
|  | |||||||
							
								
								
									
										0
									
								
								authentik/core/expression/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								authentik/core/expression/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @ -2,28 +2,33 @@ | |||||||
| from traceback import format_tb | from traceback import format_tb | ||||||
| from typing import Optional | from typing import Optional | ||||||
| 
 | 
 | ||||||
|  | from django.db.models import Model | ||||||
| from django.http import HttpRequest | from django.http import HttpRequest | ||||||
| from guardian.utils import get_anonymous_user | from guardian.utils import get_anonymous_user | ||||||
| 
 | 
 | ||||||
| from authentik.core.models import PropertyMapping, User | from authentik.core.models import User | ||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
| from authentik.lib.expression.evaluator import BaseEvaluator | from authentik.lib.expression.evaluator import BaseEvaluator | ||||||
| from authentik.policies.types import PolicyRequest | from authentik.policies.types import PolicyRequest | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class PropertyMappingEvaluator(BaseEvaluator): | class PropertyMappingEvaluator(BaseEvaluator): | ||||||
|     """Custom Evalautor that adds some different context variables.""" |     """Custom Evaluator that adds some different context variables.""" | ||||||
| 
 | 
 | ||||||
|     def set_context( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         user: Optional[User], |         model: Model, | ||||||
|         request: Optional[HttpRequest], |         user: Optional[User] = None, | ||||||
|         mapping: PropertyMapping, |         request: Optional[HttpRequest] = None, | ||||||
|         **kwargs, |         **kwargs, | ||||||
|     ): |     ): | ||||||
|         """Update context with context from PropertyMapping's evaluate""" |         if hasattr(model, "name"): | ||||||
|  |             _filename = model.name | ||||||
|  |         else: | ||||||
|  |             _filename = str(model) | ||||||
|  |         super().__init__(filename=_filename) | ||||||
|         req = PolicyRequest(user=get_anonymous_user()) |         req = PolicyRequest(user=get_anonymous_user()) | ||||||
|         req.obj = mapping |         req.obj = model | ||||||
|         if user: |         if user: | ||||||
|             req.user = user |             req.user = user | ||||||
|             self._context["user"] = user |             self._context["user"] = user | ||||||
| @ -7,9 +7,9 @@ from django.core.management.base import BaseCommand | |||||||
| from django.db.models import Model | from django.db.models import Model | ||||||
| from django.db.models.signals import post_save, pre_delete | from django.db.models.signals import post_save, pre_delete | ||||||
|  |  | ||||||
| from authentik import __version__ | from authentik import get_full_version | ||||||
| from authentik.core.models import User | from authentik.core.models import User | ||||||
| from authentik.events.middleware import IGNORED_MODELS | from authentik.events.middleware import should_log_model | ||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
| from authentik.events.utils import model_to_dict | from authentik.events.utils import model_to_dict | ||||||
|  |  | ||||||
| @ -18,7 +18,7 @@ BANNER_TEXT = """### authentik shell ({authentik}) | |||||||
|     node=platform.node(), |     node=platform.node(), | ||||||
|     python=platform.python_version(), |     python=platform.python_version(), | ||||||
|     arch=platform.machine(), |     arch=platform.machine(), | ||||||
|     authentik=__version__, |     authentik=get_full_version(), | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -50,7 +50,7 @@ class Command(BaseCommand): | |||||||
|     # pylint: disable=unused-argument |     # pylint: disable=unused-argument | ||||||
|     def post_save_handler(sender, instance: Model, created: bool, **_): |     def post_save_handler(sender, instance: Model, created: bool, **_): | ||||||
|         """Signal handler for all object's post_save""" |         """Signal handler for all object's post_save""" | ||||||
|         if isinstance(instance, IGNORED_MODELS): |         if not should_log_model(instance): | ||||||
|             return |             return | ||||||
|  |  | ||||||
|         action = EventAction.MODEL_CREATED if created else EventAction.MODEL_UPDATED |         action = EventAction.MODEL_CREATED if created else EventAction.MODEL_UPDATED | ||||||
| @ -66,7 +66,7 @@ class Command(BaseCommand): | |||||||
|     # pylint: disable=unused-argument |     # pylint: disable=unused-argument | ||||||
|     def pre_delete_handler(sender, instance: Model, **_): |     def pre_delete_handler(sender, instance: Model, **_): | ||||||
|         """Signal handler for all object's pre_delete""" |         """Signal handler for all object's pre_delete""" | ||||||
|         if isinstance(instance, IGNORED_MODELS):  # pragma: no cover |         if not should_log_model(instance):  # pragma: no cover | ||||||
|             return |             return | ||||||
|  |  | ||||||
|         Event.new(EventAction.MODEL_DELETED, model=model_to_dict(instance)).set_user( |         Event.new(EventAction.MODEL_DELETED, model=model_to_dict(instance)).set_user( | ||||||
|  | |||||||
| @ -1,6 +1,6 @@ | |||||||
| """authentik admin Middleware to impersonate users""" | """authentik admin Middleware to impersonate users""" | ||||||
| from contextvars import ContextVar | from contextvars import ContextVar | ||||||
| from typing import Callable | from typing import Callable, Optional | ||||||
| from uuid import uuid4 | from uuid import uuid4 | ||||||
|  |  | ||||||
| from django.http import HttpRequest, HttpResponse | from django.http import HttpRequest, HttpResponse | ||||||
| @ -13,9 +13,9 @@ RESPONSE_HEADER_ID = "X-authentik-id" | |||||||
| KEY_AUTH_VIA = "auth_via" | KEY_AUTH_VIA = "auth_via" | ||||||
| KEY_USER = "user" | KEY_USER = "user" | ||||||
|  |  | ||||||
| CTX_REQUEST_ID = ContextVar(STRUCTLOG_KEY_PREFIX + "request_id", default=None) | CTX_REQUEST_ID = ContextVar[Optional[str]](STRUCTLOG_KEY_PREFIX + "request_id", default=None) | ||||||
| CTX_HOST = ContextVar(STRUCTLOG_KEY_PREFIX + "host", default=None) | CTX_HOST = ContextVar[Optional[str]](STRUCTLOG_KEY_PREFIX + "host", default=None) | ||||||
| CTX_AUTH_VIA = ContextVar(STRUCTLOG_KEY_PREFIX + KEY_AUTH_VIA, default=None) | CTX_AUTH_VIA = ContextVar[Optional[str]](STRUCTLOG_KEY_PREFIX + KEY_AUTH_VIA, default=None) | ||||||
|  |  | ||||||
|  |  | ||||||
| class ImpersonateMiddleware: | class ImpersonateMiddleware: | ||||||
|  | |||||||
| @ -617,10 +617,9 @@ class PropertyMapping(SerializerModel, ManagedModel): | |||||||
|  |  | ||||||
|     def evaluate(self, user: Optional[User], request: Optional[HttpRequest], **kwargs) -> Any: |     def evaluate(self, user: Optional[User], request: Optional[HttpRequest], **kwargs) -> Any: | ||||||
|         """Evaluate `self.expression` using `**kwargs` as Context.""" |         """Evaluate `self.expression` using `**kwargs` as Context.""" | ||||||
|         from authentik.core.expression import PropertyMappingEvaluator |         from authentik.core.expression.evaluator import PropertyMappingEvaluator | ||||||
|  |  | ||||||
|         evaluator = PropertyMappingEvaluator() |         evaluator = PropertyMappingEvaluator(self, user, request, **kwargs) | ||||||
|         evaluator.set_context(user, request, self, **kwargs) |  | ||||||
|         try: |         try: | ||||||
|             return evaluator.evaluate(self.expression) |             return evaluator.evaluate(self.expression) | ||||||
|         except Exception as exc: |         except Exception as exc: | ||||||
|  | |||||||
| @ -1,31 +0,0 @@ | |||||||
| {% extends 'base/skeleton.html' %} |  | ||||||
|  |  | ||||||
| {% load i18n %} |  | ||||||
|  |  | ||||||
| {% block head %} |  | ||||||
| {{ block.super }} |  | ||||||
| <style> |  | ||||||
|     .pf-c-empty-state { |  | ||||||
|         height: 100vh; |  | ||||||
|     } |  | ||||||
| </style> |  | ||||||
| {% endblock %} |  | ||||||
|  |  | ||||||
| {% block body %} |  | ||||||
| <section class="ak-static-page pf-c-page__main-section pf-m-no-padding-mobile pf-m-xl"> |  | ||||||
|     <div class="pf-c-empty-state"> |  | ||||||
|         <div class="pf-c-empty-state__content"> |  | ||||||
|             <i class="fas fa-exclamation-circle pf-c-empty-state__icon" aria-hidden="true"></i> |  | ||||||
|             <h1 class="pf-c-title pf-m-lg"> |  | ||||||
|                 {% trans title %} |  | ||||||
|             </h1> |  | ||||||
|             <div class="pf-c-empty-state__body"> |  | ||||||
|                 {% if message %} |  | ||||||
|                 <h3>{% trans message %}</h3> |  | ||||||
|                 {% endif %} |  | ||||||
|             </div> |  | ||||||
|             <a href="/" class="pf-c-button pf-m-primary pf-m-block">{% trans 'Go to home' %}</a> |  | ||||||
|         </div> |  | ||||||
|     </div> |  | ||||||
| </section> |  | ||||||
| {% endblock %} |  | ||||||
| @ -4,14 +4,16 @@ | |||||||
| {% load i18n %} | {% load i18n %} | ||||||
|  |  | ||||||
| {% block head %} | {% block head %} | ||||||
| <script src="{% static 'dist/admin/AdminInterface.js' %}" type="module"></script> | <script src="{% static 'dist/admin/AdminInterface.js' %}?version={{ version }}" type="module"></script> | ||||||
| <meta name="theme-color" content="#18191a" media="(prefers-color-scheme: dark)"> | <meta name="theme-color" content="#18191a" media="(prefers-color-scheme: dark)"> | ||||||
| <meta name="theme-color" content="#ffffff" media="(prefers-color-scheme: light)"> | <meta name="theme-color" content="#ffffff" media="(prefers-color-scheme: light)"> | ||||||
|  | <link rel="icon" href="{{ tenant.branding_favicon }}"> | ||||||
|  | <link rel="shortcut icon" href="{{ tenant.branding_favicon }}"> | ||||||
| <script> | <script> | ||||||
| window.authentik = {}; | window.authentik = {}; | ||||||
| window.authentik.locale = "{{ tenant.default_locale }}"; | window.authentik.locale = "{{ tenant.default_locale }}"; | ||||||
| window.authentik.config = JSON.parse('{{ config_json|safe }}'); | window.authentik.config = JSON.parse('{{ config_json|escapejs }}'); | ||||||
| window.authentik.tenant = JSON.parse('{{ tenant_json|safe }}'); | window.authentik.tenant = JSON.parse('{{ tenant_json|escapejs }}'); | ||||||
| </script> | </script> | ||||||
| {% endblock %} | {% endblock %} | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
										21
									
								
								authentik/core/templates/if/error.html
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								authentik/core/templates/if/error.html
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,21 @@ | |||||||
|  | {% extends 'login/base_full.html' %} | ||||||
|  |  | ||||||
|  | {% load static %} | ||||||
|  | {% load i18n %} | ||||||
|  |  | ||||||
|  | {% block title %} | ||||||
|  | {% trans 'End session' %} - {{ tenant.branding_title }} | ||||||
|  | {% endblock %} | ||||||
|  |  | ||||||
|  | {% block card_title %} | ||||||
|  | {% trans title %} | ||||||
|  | {% endblock %} | ||||||
|  |  | ||||||
|  | {% block card %} | ||||||
|  | <form method="POST" class="pf-c-form"> | ||||||
|  |     <p>{% trans message %}</p> | ||||||
|  |     <a id="ak-back-home" href="{% url 'authentik_core:root-redirect' %}" class="pf-c-button pf-m-primary"> | ||||||
|  |         {% trans 'Go home' %} | ||||||
|  |     </a> | ||||||
|  | </form> | ||||||
|  | {% endblock %} | ||||||
| @ -6,14 +6,16 @@ | |||||||
| {% block head_before %} | {% block head_before %} | ||||||
| {{ block.super }} | {{ block.super }} | ||||||
| <link rel="prefetch" href="{{ flow.background_url }}" /> | <link rel="prefetch" href="{{ flow.background_url }}" /> | ||||||
|  | <link rel="icon" href="{{ tenant.branding_favicon }}"> | ||||||
|  | <link rel="shortcut icon" href="{{ tenant.branding_favicon }}"> | ||||||
| {% if flow.compatibility_mode and not inspector %} | {% if flow.compatibility_mode and not inspector %} | ||||||
| <script>ShadyDOM = { force: !navigator.webdriver };</script> | <script>ShadyDOM = { force: !navigator.webdriver };</script> | ||||||
| {% endif %} | {% endif %} | ||||||
| <script> | <script> | ||||||
| window.authentik = {}; | window.authentik = {}; | ||||||
| window.authentik.locale = "{{ tenant.default_locale }}"; | window.authentik.locale = "{{ tenant.default_locale }}"; | ||||||
| window.authentik.config = JSON.parse( '{{ config_json|safe }}'); | window.authentik.config = JSON.parse('{{ config_json|escapejs }}'); | ||||||
| window.authentik.tenant = JSON.parse('{{ tenant_json|safe }}'); | window.authentik.tenant = JSON.parse('{{ tenant_json|escapejs }}'); | ||||||
| window.authentik.flow = { | window.authentik.flow = { | ||||||
|     "layout": "{{ flow.layout }}", |     "layout": "{{ flow.layout }}", | ||||||
| }; | }; | ||||||
| @ -21,7 +23,7 @@ window.authentik.flow = { | |||||||
| {% endblock %} | {% endblock %} | ||||||
|  |  | ||||||
| {% block head %} | {% block head %} | ||||||
| <script src="{% static 'dist/flow/FlowInterface.js' %}" type="module"></script> | <script src="{% static 'dist/flow/FlowInterface.js' %}?version={{ version }}" type="module"></script> | ||||||
| <style> | <style> | ||||||
| :root { | :root { | ||||||
|     --ak-flow-background: url("{{ flow.background_url }}"); |     --ak-flow-background: url("{{ flow.background_url }}"); | ||||||
|  | |||||||
| @ -4,14 +4,16 @@ | |||||||
| {% load i18n %} | {% load i18n %} | ||||||
|  |  | ||||||
| {% block head %} | {% block head %} | ||||||
| <script src="{% static 'dist/user/UserInterface.js' %}" type="module"></script> | <script src="{% static 'dist/user/UserInterface.js' %}?version={{ version }}" type="module"></script> | ||||||
| <meta name="theme-color" content="#151515" media="(prefers-color-scheme: light)"> | <meta name="theme-color" content="#151515" media="(prefers-color-scheme: light)"> | ||||||
| <meta name="theme-color" content="#151515" media="(prefers-color-scheme: dark)"> | <meta name="theme-color" content="#151515" media="(prefers-color-scheme: dark)"> | ||||||
|  | <link rel="icon" href="{{ tenant.branding_favicon }}"> | ||||||
|  | <link rel="shortcut icon" href="{{ tenant.branding_favicon }}"> | ||||||
| <script> | <script> | ||||||
| window.authentik = {}; | window.authentik = {}; | ||||||
| window.authentik.locale = "{{ tenant.default_locale }}"; | window.authentik.locale = "{{ tenant.default_locale }}"; | ||||||
| window.authentik.config = JSON.parse('{{ config_json|safe }}'); | window.authentik.config = JSON.parse('{{ config_json|escapejs }}'); | ||||||
| window.authentik.tenant = JSON.parse('{{ tenant_json|safe }}'); | window.authentik.tenant = JSON.parse('{{ tenant_json|escapejs }}'); | ||||||
| </script> | </script> | ||||||
| {% endblock %} | {% endblock %} | ||||||
|  |  | ||||||
|  | |||||||
| @ -5,8 +5,7 @@ from django.urls import reverse | |||||||
| from rest_framework.test import APITestCase | from rest_framework.test import APITestCase | ||||||
|  |  | ||||||
| from authentik.core.models import Application | from authentik.core.models import Application | ||||||
| from authentik.core.tests.utils import create_test_admin_user | from authentik.core.tests.utils import create_test_admin_user, create_test_flow | ||||||
| from authentik.flows.models import Flow |  | ||||||
| from authentik.policies.dummy.models import DummyPolicy | from authentik.policies.dummy.models import DummyPolicy | ||||||
| from authentik.policies.models import PolicyBinding | from authentik.policies.models import PolicyBinding | ||||||
| from authentik.providers.oauth2.models import OAuth2Provider | from authentik.providers.oauth2.models import OAuth2Provider | ||||||
| @ -20,10 +19,7 @@ class TestApplicationsAPI(APITestCase): | |||||||
|         self.provider = OAuth2Provider.objects.create( |         self.provider = OAuth2Provider.objects.create( | ||||||
|             name="test", |             name="test", | ||||||
|             redirect_uris="http://some-other-domain", |             redirect_uris="http://some-other-domain", | ||||||
|             authorization_flow=Flow.objects.create( |             authorization_flow=create_test_flow(), | ||||||
|                 name="test", |  | ||||||
|                 slug="test", |  | ||||||
|             ), |  | ||||||
|         ) |         ) | ||||||
|         self.allowed = Application.objects.create( |         self.allowed = Application.objects.create( | ||||||
|             name="allowed", |             name="allowed", | ||||||
|  | |||||||
| @ -4,8 +4,7 @@ from unittest.mock import MagicMock, patch | |||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
|  |  | ||||||
| from authentik.core.models import Application | from authentik.core.models import Application | ||||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_tenant | from authentik.core.tests.utils import create_test_admin_user, create_test_flow, create_test_tenant | ||||||
| from authentik.flows.models import Flow, FlowDesignation |  | ||||||
| from authentik.flows.tests import FlowTestCase | from authentik.flows.tests import FlowTestCase | ||||||
| from authentik.tenants.models import Tenant | from authentik.tenants.models import Tenant | ||||||
|  |  | ||||||
| @ -21,11 +20,7 @@ class TestApplicationsViews(FlowTestCase): | |||||||
|  |  | ||||||
|     def test_check_redirect(self): |     def test_check_redirect(self): | ||||||
|         """Test redirect""" |         """Test redirect""" | ||||||
|         empty_flow = Flow.objects.create( |         empty_flow = create_test_flow() | ||||||
|             name="foo", |  | ||||||
|             slug="foo", |  | ||||||
|             designation=FlowDesignation.AUTHENTICATION, |  | ||||||
|         ) |  | ||||||
|         tenant: Tenant = create_test_tenant() |         tenant: Tenant = create_test_tenant() | ||||||
|         tenant.flow_authentication = empty_flow |         tenant.flow_authentication = empty_flow | ||||||
|         tenant.save() |         tenant.save() | ||||||
| @ -49,11 +44,7 @@ class TestApplicationsViews(FlowTestCase): | |||||||
|     def test_check_redirect_auth(self): |     def test_check_redirect_auth(self): | ||||||
|         """Test redirect""" |         """Test redirect""" | ||||||
|         self.client.force_login(self.user) |         self.client.force_login(self.user) | ||||||
|         empty_flow = Flow.objects.create( |         empty_flow = create_test_flow() | ||||||
|             name="foo", |  | ||||||
|             slug="foo", |  | ||||||
|             designation=FlowDesignation.AUTHENTICATION, |  | ||||||
|         ) |  | ||||||
|         tenant: Tenant = create_test_tenant() |         tenant: Tenant = create_test_tenant() | ||||||
|         tenant.flow_authentication = empty_flow |         tenant.flow_authentication = empty_flow | ||||||
|         tenant.save() |         tenant.save() | ||||||
|  | |||||||
| @ -6,7 +6,7 @@ from guardian.utils import get_anonymous_user | |||||||
|  |  | ||||||
| from authentik.core.models import SourceUserMatchingModes, User | from authentik.core.models import SourceUserMatchingModes, User | ||||||
| from authentik.core.sources.flow_manager import Action | from authentik.core.sources.flow_manager import Action | ||||||
| from authentik.flows.models import Flow, FlowDesignation | from authentik.core.tests.utils import create_test_flow | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.lib.tests.utils import get_request | from authentik.lib.tests.utils import get_request | ||||||
| from authentik.policies.denied import AccessDeniedResponse | from authentik.policies.denied import AccessDeniedResponse | ||||||
| @ -152,9 +152,7 @@ class TestSourceFlowManager(TestCase): | |||||||
|         """Test error handling when a source selected flow is non-applicable due to a policy""" |         """Test error handling when a source selected flow is non-applicable due to a policy""" | ||||||
|         self.source.user_matching_mode = SourceUserMatchingModes.USERNAME_LINK |         self.source.user_matching_mode = SourceUserMatchingModes.USERNAME_LINK | ||||||
|  |  | ||||||
|         flow = Flow.objects.create( |         flow = create_test_flow() | ||||||
|             name="test", slug="test", title="test", designation=FlowDesignation.ENROLLMENT |  | ||||||
|         ) |  | ||||||
|         policy = ExpressionPolicy.objects.create( |         policy = ExpressionPolicy.objects.create( | ||||||
|             name="false", expression="""ak_message("foo");return False""" |             name="false", expression="""ak_message("foo");return False""" | ||||||
|         ) |         ) | ||||||
|  | |||||||
| @ -159,7 +159,6 @@ class TestUsersAPI(APITestCase): | |||||||
|         response = self.client.get( |         response = self.client.get( | ||||||
|             reverse("authentik_api:user-paths"), |             reverse("authentik_api:user-paths"), | ||||||
|         ) |         ) | ||||||
|         print(response.content) |  | ||||||
|         self.assertEqual(response.status_code, 200) |         self.assertEqual(response.status_code, 200) | ||||||
|         self.assertJSONEqual(response.content.decode(), {"paths": ["users"]}) |         self.assertJSONEqual(response.content.decode(), {"paths": ["users"]}) | ||||||
|  |  | ||||||
|  | |||||||
| @ -52,5 +52,5 @@ def create_test_cert() -> CertificateKeyPair: | |||||||
|         subject_alt_names=["goauthentik.io"], |         subject_alt_names=["goauthentik.io"], | ||||||
|         validity_days=360, |         validity_days=360, | ||||||
|     ) |     ) | ||||||
|     builder.name = generate_id() |     builder.common_name = generate_id() | ||||||
|     return builder.save() |     return builder.save() | ||||||
|  | |||||||
| @ -32,7 +32,7 @@ class BadRequestView(TemplateView): | |||||||
|     extra_context = {"title": "Bad Request"} |     extra_context = {"title": "Bad Request"} | ||||||
|  |  | ||||||
|     response_class = BadRequestTemplateResponse |     response_class = BadRequestTemplateResponse | ||||||
|     template_name = "error/generic.html" |     template_name = "if/error.html" | ||||||
|  |  | ||||||
|  |  | ||||||
| class ForbiddenView(TemplateView): | class ForbiddenView(TemplateView): | ||||||
| @ -41,7 +41,7 @@ class ForbiddenView(TemplateView): | |||||||
|     extra_context = {"title": "Forbidden"} |     extra_context = {"title": "Forbidden"} | ||||||
|  |  | ||||||
|     response_class = ForbiddenTemplateResponse |     response_class = ForbiddenTemplateResponse | ||||||
|     template_name = "error/generic.html" |     template_name = "if/error.html" | ||||||
|  |  | ||||||
|  |  | ||||||
| class NotFoundView(TemplateView): | class NotFoundView(TemplateView): | ||||||
| @ -50,7 +50,7 @@ class NotFoundView(TemplateView): | |||||||
|     extra_context = {"title": "Not Found"} |     extra_context = {"title": "Not Found"} | ||||||
|  |  | ||||||
|     response_class = NotFoundTemplateResponse |     response_class = NotFoundTemplateResponse | ||||||
|     template_name = "error/generic.html" |     template_name = "if/error.html" | ||||||
|  |  | ||||||
|  |  | ||||||
| class ServerErrorView(TemplateView): | class ServerErrorView(TemplateView): | ||||||
| @ -59,7 +59,7 @@ class ServerErrorView(TemplateView): | |||||||
|     extra_context = {"title": "Server Error"} |     extra_context = {"title": "Server Error"} | ||||||
|  |  | ||||||
|     response_class = ServerErrorTemplateResponse |     response_class = ServerErrorTemplateResponse | ||||||
|     template_name = "error/generic.html" |     template_name = "if/error.html" | ||||||
|  |  | ||||||
|     # pylint: disable=useless-super-delegation |     # pylint: disable=useless-super-delegation | ||||||
|     def dispatch(self, *args, **kwargs):  # pragma: no cover |     def dispatch(self, *args, **kwargs):  # pragma: no cover | ||||||
|  | |||||||
| @ -12,10 +12,11 @@ from django_filters.filters import BooleanFilter | |||||||
| from drf_spectacular.types import OpenApiTypes | from drf_spectacular.types import OpenApiTypes | ||||||
| from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema | from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema | ||||||
| from rest_framework.decorators import action | from rest_framework.decorators import action | ||||||
|  | from rest_framework.exceptions import ValidationError | ||||||
| from rest_framework.fields import CharField, DateTimeField, IntegerField, SerializerMethodField | from rest_framework.fields import CharField, DateTimeField, IntegerField, SerializerMethodField | ||||||
| from rest_framework.request import Request | from rest_framework.request import Request | ||||||
| from rest_framework.response import Response | from rest_framework.response import Response | ||||||
| from rest_framework.serializers import ModelSerializer, ValidationError | from rest_framework.serializers import ModelSerializer | ||||||
| from rest_framework.viewsets import ModelViewSet | from rest_framework.viewsets import ModelViewSet | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
|  | |||||||
| @ -2,7 +2,7 @@ | |||||||
| from datetime import datetime | from datetime import datetime | ||||||
| from typing import TYPE_CHECKING, Optional | from typing import TYPE_CHECKING, Optional | ||||||
|  |  | ||||||
| from authentik.blueprints.manager import ManagedAppConfig | from authentik.blueprints.apps import ManagedAppConfig | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
|  |  | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
|  | |||||||
| @ -26,7 +26,7 @@ class CertificateBuilder: | |||||||
|         self.common_name = "authentik Self-signed Certificate" |         self.common_name = "authentik Self-signed Certificate" | ||||||
|         self.cert = CertificateKeyPair() |         self.cert = CertificateKeyPair() | ||||||
|  |  | ||||||
|     def save(self) -> Optional[CertificateKeyPair]: |     def save(self) -> CertificateKeyPair: | ||||||
|         """Save generated certificate as model""" |         """Save generated certificate as model""" | ||||||
|         if not self.__certificate: |         if not self.__certificate: | ||||||
|             raise ValueError("Certificated hasn't been built yet") |             raise ValueError("Certificated hasn't been built yet") | ||||||
|  | |||||||
							
								
								
									
										51
									
								
								authentik/crypto/management/commands/import_certificate.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								authentik/crypto/management/commands/import_certificate.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,51 @@ | |||||||
|  | """Import certificate""" | ||||||
|  | from sys import exit as sys_exit | ||||||
|  |  | ||||||
|  | from django.core.management.base import BaseCommand, no_translations | ||||||
|  | from rest_framework.exceptions import ValidationError | ||||||
|  | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
|  | from authentik.crypto.api import CertificateKeyPairSerializer | ||||||
|  | from authentik.crypto.models import CertificateKeyPair | ||||||
|  |  | ||||||
|  | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Command(BaseCommand): | ||||||
|  |     """Import certificate""" | ||||||
|  |  | ||||||
|  |     @no_translations | ||||||
|  |     def handle(self, *args, **options): | ||||||
|  |         """Import certificate""" | ||||||
|  |         keypair = CertificateKeyPair.objects.filter(name=options["name"]).first() | ||||||
|  |         dirty = False | ||||||
|  |         if not keypair: | ||||||
|  |             keypair = CertificateKeyPair(name=options["name"]) | ||||||
|  |             dirty = True | ||||||
|  |         with open(options["certificate"], mode="r", encoding="utf-8") as _cert: | ||||||
|  |             cert_data = _cert.read() | ||||||
|  |             if keypair.certificate_data != cert_data: | ||||||
|  |                 dirty = True | ||||||
|  |             keypair.certificate_data = cert_data | ||||||
|  |         if options["private_key"]: | ||||||
|  |             with open(options["private_key"], mode="r", encoding="utf-8") as _key: | ||||||
|  |                 key_data = _key.read() | ||||||
|  |                 if keypair.key_data != key_data: | ||||||
|  |                     dirty = True | ||||||
|  |                 keypair.key_data = key_data | ||||||
|  |         # Validate that cert and key are actually PEM and valid | ||||||
|  |         serializer = CertificateKeyPairSerializer(instance=keypair) | ||||||
|  |         try: | ||||||
|  |             serializer.validate_certificate_data(keypair.certificate_data) | ||||||
|  |             if keypair.key_data != "": | ||||||
|  |                 serializer.validate_certificate_data(keypair.key_data) | ||||||
|  |         except ValidationError as exc: | ||||||
|  |             self.stderr.write(exc) | ||||||
|  |             sys_exit(1) | ||||||
|  |         if dirty: | ||||||
|  |             keypair.save() | ||||||
|  |  | ||||||
|  |     def add_arguments(self, parser): | ||||||
|  |         parser.add_argument("--certificate", type=str, required=True) | ||||||
|  |         parser.add_argument("--private-key", type=str, required=False) | ||||||
|  |         parser.add_argument("--name", type=str, required=True) | ||||||
| @ -6,12 +6,7 @@ from uuid import uuid4 | |||||||
|  |  | ||||||
| from cryptography.hazmat.backends import default_backend | from cryptography.hazmat.backends import default_backend | ||||||
| from cryptography.hazmat.primitives import hashes | from cryptography.hazmat.primitives import hashes | ||||||
| from cryptography.hazmat.primitives.asymmetric.ec import ( | from cryptography.hazmat.primitives.asymmetric.types import PRIVATE_KEY_TYPES, PUBLIC_KEY_TYPES | ||||||
|     EllipticCurvePrivateKey, |  | ||||||
|     EllipticCurvePublicKey, |  | ||||||
| ) |  | ||||||
| from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey |  | ||||||
| from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey |  | ||||||
| from cryptography.hazmat.primitives.serialization import load_pem_private_key | from cryptography.hazmat.primitives.serialization import load_pem_private_key | ||||||
| from cryptography.x509 import Certificate, load_pem_x509_certificate | from cryptography.x509 import Certificate, load_pem_x509_certificate | ||||||
| from django.db import models | from django.db import models | ||||||
| @ -42,8 +37,8 @@ class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel): | |||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     _cert: Optional[Certificate] = None |     _cert: Optional[Certificate] = None | ||||||
|     _private_key: Optional[RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey] = None |     _private_key: Optional[PRIVATE_KEY_TYPES] = None | ||||||
|     _public_key: Optional[RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey] = None |     _public_key: Optional[PUBLIC_KEY_TYPES] = None | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def serializer(self) -> Serializer: |     def serializer(self) -> Serializer: | ||||||
| @ -61,7 +56,7 @@ class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel): | |||||||
|         return self._cert |         return self._cert | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def public_key(self) -> Optional[RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey]: |     def public_key(self) -> Optional[PUBLIC_KEY_TYPES]: | ||||||
|         """Get public key of the private key""" |         """Get public key of the private key""" | ||||||
|         if not self._public_key: |         if not self._public_key: | ||||||
|             self._public_key = self.private_key.public_key() |             self._public_key = self.private_key.public_key() | ||||||
| @ -70,7 +65,7 @@ class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel): | |||||||
|     @property |     @property | ||||||
|     def private_key( |     def private_key( | ||||||
|         self, |         self, | ||||||
|     ) -> Optional[RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey]: |     ) -> Optional[PRIVATE_KEY_TYPES]: | ||||||
|         """Get python cryptography PrivateKey instance""" |         """Get python cryptography PrivateKey instance""" | ||||||
|         if not self._private_key and self.key_data != "": |         if not self._private_key and self.key_data != "": | ||||||
|             try: |             try: | ||||||
|  | |||||||
| @ -85,16 +85,18 @@ class NotificationTransportViewSet(UsedByMixin, ModelViewSet): | |||||||
|         """Send example notification using selected transport. Requires |         """Send example notification using selected transport. Requires | ||||||
|         Modify permissions.""" |         Modify permissions.""" | ||||||
|         transport: NotificationTransport = self.get_object() |         transport: NotificationTransport = self.get_object() | ||||||
|  |         event = Event.new( | ||||||
|  |             action="notification_test", | ||||||
|  |             user=get_user(request.user), | ||||||
|  |             app=self.__class__.__module__, | ||||||
|  |             context={"foo": "bar"}, | ||||||
|  |         ) | ||||||
|  |         event.save() | ||||||
|         notification = Notification( |         notification = Notification( | ||||||
|             severity=NotificationSeverity.NOTICE, |             severity=NotificationSeverity.NOTICE, | ||||||
|             body=f"Test Notification from transport {transport.name}", |             body=f"Test Notification from transport {transport.name}", | ||||||
|             user=request.user, |             user=request.user, | ||||||
|             event=Event( |             event=event, | ||||||
|                 action="Test", |  | ||||||
|                 user=get_user(request.user), |  | ||||||
|                 app=self.__class__.__module__, |  | ||||||
|                 context={"foo": "bar"}, |  | ||||||
|             ), |  | ||||||
|         ) |         ) | ||||||
|         try: |         try: | ||||||
|             response = NotificationTransportTestSerializer( |             response = NotificationTransportTestSerializer( | ||||||
|  | |||||||
| @ -1,7 +1,7 @@ | |||||||
| """authentik events app""" | """authentik events app""" | ||||||
| from prometheus_client import Gauge | from prometheus_client import Gauge | ||||||
|  |  | ||||||
| from authentik.blueprints.manager import ManagedAppConfig | from authentik.blueprints.apps import ManagedAppConfig | ||||||
|  |  | ||||||
| GAUGE_TASKS = Gauge( | GAUGE_TASKS = Gauge( | ||||||
|     "authentik_system_tasks", |     "authentik_system_tasks", | ||||||
|  | |||||||
| @ -19,7 +19,7 @@ from authentik.flows.models import FlowToken | |||||||
| from authentik.lib.sentry import before_send | from authentik.lib.sentry import before_send | ||||||
| from authentik.lib.utils.errors import exception_to_string | from authentik.lib.utils.errors import exception_to_string | ||||||
|  |  | ||||||
| IGNORED_MODELS = [ | IGNORED_MODELS = ( | ||||||
|     Event, |     Event, | ||||||
|     Notification, |     Notification, | ||||||
|     UserObjectPermission, |     UserObjectPermission, | ||||||
| @ -27,12 +27,14 @@ IGNORED_MODELS = [ | |||||||
|     StaticToken, |     StaticToken, | ||||||
|     Session, |     Session, | ||||||
|     FlowToken, |     FlowToken, | ||||||
| ] | ) | ||||||
| if settings.DEBUG: |  | ||||||
|     from silk.models import Request, Response, SQLQuery |  | ||||||
|  |  | ||||||
|     IGNORED_MODELS += [Request, Response, SQLQuery] |  | ||||||
| IGNORED_MODELS = tuple(IGNORED_MODELS) | def should_log_model(model: Model) -> bool: | ||||||
|  |     """Return true if operation on `model` should be logged""" | ||||||
|  |     if model.__module__.startswith("silk"): | ||||||
|  |         return False | ||||||
|  |     return not isinstance(model, IGNORED_MODELS) | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuditMiddleware: | class AuditMiddleware: | ||||||
| @ -109,7 +111,7 @@ class AuditMiddleware: | |||||||
|         user: User, request: HttpRequest, sender, instance: Model, created: bool, **_ |         user: User, request: HttpRequest, sender, instance: Model, created: bool, **_ | ||||||
|     ): |     ): | ||||||
|         """Signal handler for all object's post_save""" |         """Signal handler for all object's post_save""" | ||||||
|         if isinstance(instance, IGNORED_MODELS): |         if not should_log_model(instance): | ||||||
|             return |             return | ||||||
|  |  | ||||||
|         action = EventAction.MODEL_CREATED if created else EventAction.MODEL_UPDATED |         action = EventAction.MODEL_CREATED if created else EventAction.MODEL_UPDATED | ||||||
| @ -119,7 +121,7 @@ class AuditMiddleware: | |||||||
|     # pylint: disable=unused-argument |     # pylint: disable=unused-argument | ||||||
|     def pre_delete_handler(user: User, request: HttpRequest, sender, instance: Model, **_): |     def pre_delete_handler(user: User, request: HttpRequest, sender, instance: Model, **_): | ||||||
|         """Signal handler for all object's pre_delete""" |         """Signal handler for all object's pre_delete""" | ||||||
|         if isinstance(instance, IGNORED_MODELS):  # pragma: no cover |         if not should_log_model(instance):  # pragma: no cover | ||||||
|             return |             return | ||||||
|  |  | ||||||
|         EventNewThread( |         EventNewThread( | ||||||
|  | |||||||
| @ -28,126 +28,6 @@ def convert_user_to_json(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | |||||||
|         event.save() |         event.save() | ||||||
|  |  | ||||||
|  |  | ||||||
| def notify_configuration_error(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): |  | ||||||
|     db_alias = schema_editor.connection.alias |  | ||||||
|     Group = apps.get_model("authentik_core", "Group") |  | ||||||
|     PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding") |  | ||||||
|     EventMatcherPolicy = apps.get_model("authentik_policies_event_matcher", "EventMatcherPolicy") |  | ||||||
|     NotificationRule = apps.get_model("authentik_events", "NotificationRule") |  | ||||||
|     NotificationTransport = apps.get_model("authentik_events", "NotificationTransport") |  | ||||||
|  |  | ||||||
|     admin_group = ( |  | ||||||
|         Group.objects.using(db_alias).filter(name="authentik Admins", is_superuser=True).first() |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     policy, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create( |  | ||||||
|         name="default-match-configuration-error", |  | ||||||
|         defaults={"action": EventAction.CONFIGURATION_ERROR}, |  | ||||||
|     ) |  | ||||||
|     trigger, _ = NotificationRule.objects.using(db_alias).update_or_create( |  | ||||||
|         name="default-notify-configuration-error", |  | ||||||
|         defaults={"group": admin_group, "severity": NotificationSeverity.ALERT}, |  | ||||||
|     ) |  | ||||||
|     trigger.transports.set( |  | ||||||
|         NotificationTransport.objects.using(db_alias).filter(name="default-email-transport") |  | ||||||
|     ) |  | ||||||
|     trigger.save() |  | ||||||
|     PolicyBinding.objects.using(db_alias).update_or_create( |  | ||||||
|         target=trigger, |  | ||||||
|         policy=policy, |  | ||||||
|         defaults={ |  | ||||||
|             "order": 0, |  | ||||||
|         }, |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def notify_update(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): |  | ||||||
|     db_alias = schema_editor.connection.alias |  | ||||||
|     Group = apps.get_model("authentik_core", "Group") |  | ||||||
|     PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding") |  | ||||||
|     EventMatcherPolicy = apps.get_model("authentik_policies_event_matcher", "EventMatcherPolicy") |  | ||||||
|     NotificationRule = apps.get_model("authentik_events", "NotificationRule") |  | ||||||
|     NotificationTransport = apps.get_model("authentik_events", "NotificationTransport") |  | ||||||
|  |  | ||||||
|     admin_group = ( |  | ||||||
|         Group.objects.using(db_alias).filter(name="authentik Admins", is_superuser=True).first() |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     policy, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create( |  | ||||||
|         name="default-match-update", |  | ||||||
|         defaults={"action": EventAction.UPDATE_AVAILABLE}, |  | ||||||
|     ) |  | ||||||
|     trigger, _ = NotificationRule.objects.using(db_alias).update_or_create( |  | ||||||
|         name="default-notify-update", |  | ||||||
|         defaults={"group": admin_group, "severity": NotificationSeverity.ALERT}, |  | ||||||
|     ) |  | ||||||
|     trigger.transports.set( |  | ||||||
|         NotificationTransport.objects.using(db_alias).filter(name="default-email-transport") |  | ||||||
|     ) |  | ||||||
|     trigger.save() |  | ||||||
|     PolicyBinding.objects.using(db_alias).update_or_create( |  | ||||||
|         target=trigger, |  | ||||||
|         policy=policy, |  | ||||||
|         defaults={ |  | ||||||
|             "order": 0, |  | ||||||
|         }, |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def notify_exception(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): |  | ||||||
|     db_alias = schema_editor.connection.alias |  | ||||||
|     Group = apps.get_model("authentik_core", "Group") |  | ||||||
|     PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding") |  | ||||||
|     EventMatcherPolicy = apps.get_model("authentik_policies_event_matcher", "EventMatcherPolicy") |  | ||||||
|     NotificationRule = apps.get_model("authentik_events", "NotificationRule") |  | ||||||
|     NotificationTransport = apps.get_model("authentik_events", "NotificationTransport") |  | ||||||
|  |  | ||||||
|     admin_group = ( |  | ||||||
|         Group.objects.using(db_alias).filter(name="authentik Admins", is_superuser=True).first() |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     policy_policy_exc, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create( |  | ||||||
|         name="default-match-policy-exception", |  | ||||||
|         defaults={"action": EventAction.POLICY_EXCEPTION}, |  | ||||||
|     ) |  | ||||||
|     policy_pm_exc, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create( |  | ||||||
|         name="default-match-property-mapping-exception", |  | ||||||
|         defaults={"action": EventAction.PROPERTY_MAPPING_EXCEPTION}, |  | ||||||
|     ) |  | ||||||
|     trigger, _ = NotificationRule.objects.using(db_alias).update_or_create( |  | ||||||
|         name="default-notify-exception", |  | ||||||
|         defaults={"group": admin_group, "severity": NotificationSeverity.ALERT}, |  | ||||||
|     ) |  | ||||||
|     trigger.transports.set( |  | ||||||
|         NotificationTransport.objects.using(db_alias).filter(name="default-email-transport") |  | ||||||
|     ) |  | ||||||
|     trigger.save() |  | ||||||
|     PolicyBinding.objects.using(db_alias).update_or_create( |  | ||||||
|         target=trigger, |  | ||||||
|         policy=policy_policy_exc, |  | ||||||
|         defaults={ |  | ||||||
|             "order": 0, |  | ||||||
|         }, |  | ||||||
|     ) |  | ||||||
|     PolicyBinding.objects.using(db_alias).update_or_create( |  | ||||||
|         target=trigger, |  | ||||||
|         policy=policy_pm_exc, |  | ||||||
|         defaults={ |  | ||||||
|             "order": 1, |  | ||||||
|         }, |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def transport_email_global(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): |  | ||||||
|     db_alias = schema_editor.connection.alias |  | ||||||
|     NotificationTransport = apps.get_model("authentik_events", "NotificationTransport") |  | ||||||
|  |  | ||||||
|     NotificationTransport.objects.using(db_alias).update_or_create( |  | ||||||
|         name="default-email-transport", |  | ||||||
|         defaults={"mode": TransportMode.EMAIL}, |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def token_view_to_secret_view(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | def token_view_to_secret_view(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||||
|     from authentik.events.models import EventAction |     from authentik.events.models import EventAction | ||||||
|  |  | ||||||
| @ -432,18 +312,6 @@ class Migration(migrations.Migration): | |||||||
|                 "verbose_name_plural": "Notifications", |                 "verbose_name_plural": "Notifications", | ||||||
|             }, |             }, | ||||||
|         ), |         ), | ||||||
|         migrations.RunPython( |  | ||||||
|             code=transport_email_global, |  | ||||||
|         ), |  | ||||||
|         migrations.RunPython( |  | ||||||
|             code=notify_configuration_error, |  | ||||||
|         ), |  | ||||||
|         migrations.RunPython( |  | ||||||
|             code=notify_update, |  | ||||||
|         ), |  | ||||||
|         migrations.RunPython( |  | ||||||
|             code=notify_exception, |  | ||||||
|         ), |  | ||||||
|         migrations.AddField( |         migrations.AddField( | ||||||
|             model_name="notificationtransport", |             model_name="notificationtransport", | ||||||
|             name="send_once", |             name="send_once", | ||||||
|  | |||||||
| @ -22,14 +22,20 @@ from django.utils.translation import gettext as _ | |||||||
| from requests import RequestException | from requests import RequestException | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik import __version__ | from authentik import get_full_version | ||||||
| from authentik.core.middleware import ( | from authentik.core.middleware import ( | ||||||
|     SESSION_KEY_IMPERSONATE_ORIGINAL_USER, |     SESSION_KEY_IMPERSONATE_ORIGINAL_USER, | ||||||
|     SESSION_KEY_IMPERSONATE_USER, |     SESSION_KEY_IMPERSONATE_USER, | ||||||
| ) | ) | ||||||
| from authentik.core.models import ExpiringModel, Group, PropertyMapping, User | from authentik.core.models import ExpiringModel, Group, PropertyMapping, User | ||||||
| from authentik.events.geo import GEOIP_READER | from authentik.events.geo import GEOIP_READER | ||||||
| from authentik.events.utils import cleanse_dict, get_user, model_to_dict, sanitize_dict | from authentik.events.utils import ( | ||||||
|  |     cleanse_dict, | ||||||
|  |     get_user, | ||||||
|  |     model_to_dict, | ||||||
|  |     sanitize_dict, | ||||||
|  |     sanitize_item, | ||||||
|  | ) | ||||||
| from authentik.lib.models import DomainlessURLValidator, SerializerModel | from authentik.lib.models import DomainlessURLValidator, SerializerModel | ||||||
| from authentik.lib.sentry import SentryIgnoredException | from authentik.lib.sentry import SentryIgnoredException | ||||||
| from authentik.lib.utils.http import get_client_ip, get_http_session | from authentik.lib.utils.http import get_client_ip, get_http_session | ||||||
| @ -355,11 +361,13 @@ class NotificationTransport(SerializerModel): | |||||||
|             "user_username": notification.user.username, |             "user_username": notification.user.username, | ||||||
|         } |         } | ||||||
|         if self.webhook_mapping: |         if self.webhook_mapping: | ||||||
|             default_body = self.webhook_mapping.evaluate( |             default_body = sanitize_item( | ||||||
|  |                 self.webhook_mapping.evaluate( | ||||||
|                     user=notification.user, |                     user=notification.user, | ||||||
|                     request=None, |                     request=None, | ||||||
|                     notification=notification, |                     notification=notification, | ||||||
|                 ) |                 ) | ||||||
|  |             ) | ||||||
|         try: |         try: | ||||||
|             response = get_http_session().post( |             response = get_http_session().post( | ||||||
|                 self.webhook_url, |                 self.webhook_url, | ||||||
| @ -406,7 +414,7 @@ class NotificationTransport(SerializerModel): | |||||||
|                     "title": notification.body, |                     "title": notification.body, | ||||||
|                     "color": "#fd4b2d", |                     "color": "#fd4b2d", | ||||||
|                     "fields": fields, |                     "fields": fields, | ||||||
|                     "footer": f"authentik v{__version__}", |                     "footer": f"authentik {get_full_version()}", | ||||||
|                 } |                 } | ||||||
|             ], |             ], | ||||||
|         } |         } | ||||||
|  | |||||||
| @ -134,11 +134,12 @@ class MonitoredTask(Task): | |||||||
|  |  | ||||||
|     # pylint: disable=too-many-arguments |     # pylint: disable=too-many-arguments | ||||||
|     def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo): |     def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo): | ||||||
|         if self._result: |         super().after_return(status, retval, task_id, args, kwargs, einfo=einfo) | ||||||
|  |         if not self._result: | ||||||
|  |             return | ||||||
|         if not self._result.uid: |         if not self._result.uid: | ||||||
|             self._result.uid = self._uid |             self._result.uid = self._uid | ||||||
|             if self.save_on_success: |         info = TaskInfo( | ||||||
|                 TaskInfo( |  | ||||||
|             task_name=self.__name__, |             task_name=self.__name__, | ||||||
|             task_description=self.__doc__, |             task_description=self.__doc__, | ||||||
|             start_timestamp=self.start, |             start_timestamp=self.start, | ||||||
| @ -149,11 +150,15 @@ class MonitoredTask(Task): | |||||||
|             task_call_func=self.__name__, |             task_call_func=self.__name__, | ||||||
|             task_call_args=args, |             task_call_args=args, | ||||||
|             task_call_kwargs=kwargs, |             task_call_kwargs=kwargs, | ||||||
|                 ).save(self.result_timeout_hours) |         ) | ||||||
|         return super().after_return(status, retval, task_id, args, kwargs, einfo=einfo) |         if self._result.status == TaskResultStatus.SUCCESSFUL and not self.save_on_success: | ||||||
|  |             info.delete() | ||||||
|  |             return | ||||||
|  |         info.save(self.result_timeout_hours) | ||||||
|  |  | ||||||
|     # pylint: disable=too-many-arguments |     # pylint: disable=too-many-arguments | ||||||
|     def on_failure(self, exc, task_id, args, kwargs, einfo): |     def on_failure(self, exc, task_id, args, kwargs, einfo): | ||||||
|  |         super().on_failure(exc, task_id, args, kwargs, einfo=einfo) | ||||||
|         if not self._result: |         if not self._result: | ||||||
|             self._result = TaskResult(status=TaskResultStatus.ERROR, messages=[str(exc)]) |             self._result = TaskResult(status=TaskResultStatus.ERROR, messages=[str(exc)]) | ||||||
|         if not self._result.uid: |         if not self._result.uid: | ||||||
| @ -174,7 +179,6 @@ class MonitoredTask(Task): | |||||||
|             EventAction.SYSTEM_TASK_EXCEPTION, |             EventAction.SYSTEM_TASK_EXCEPTION, | ||||||
|             message=(f"Task {self.__name__} encountered an error: {exception_to_string(exc)}"), |             message=(f"Task {self.__name__} encountered an error: {exception_to_string(exc)}"), | ||||||
|         ).save() |         ).save() | ||||||
|         return super().on_failure(exc, task_id, args, kwargs, einfo=einfo) |  | ||||||
|  |  | ||||||
|     def run(self, *args, **kwargs): |     def run(self, *args, **kwargs): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  | |||||||
| @ -31,8 +31,8 @@ class TestEventsNotifications(TestCase): | |||||||
|  |  | ||||||
|     def test_trigger_empty(self): |     def test_trigger_empty(self): | ||||||
|         """Test trigger without any policies attached""" |         """Test trigger without any policies attached""" | ||||||
|         transport = NotificationTransport.objects.create(name="transport") |         transport = NotificationTransport.objects.create(name=generate_id()) | ||||||
|         trigger = NotificationRule.objects.create(name="trigger", group=self.group) |         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||||
|         trigger.transports.add(transport) |         trigger.transports.add(transport) | ||||||
|         trigger.save() |         trigger.save() | ||||||
|  |  | ||||||
| @ -43,8 +43,8 @@ class TestEventsNotifications(TestCase): | |||||||
|  |  | ||||||
|     def test_trigger_single(self): |     def test_trigger_single(self): | ||||||
|         """Test simple transport triggering""" |         """Test simple transport triggering""" | ||||||
|         transport = NotificationTransport.objects.create(name="transport") |         transport = NotificationTransport.objects.create(name=generate_id()) | ||||||
|         trigger = NotificationRule.objects.create(name="trigger", group=self.group) |         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||||
|         trigger.transports.add(transport) |         trigger.transports.add(transport) | ||||||
|         trigger.save() |         trigger.save() | ||||||
|         matcher = EventMatcherPolicy.objects.create( |         matcher = EventMatcherPolicy.objects.create( | ||||||
| @ -59,7 +59,7 @@ class TestEventsNotifications(TestCase): | |||||||
|  |  | ||||||
|     def test_trigger_no_group(self): |     def test_trigger_no_group(self): | ||||||
|         """Test trigger without group""" |         """Test trigger without group""" | ||||||
|         trigger = NotificationRule.objects.create(name="trigger") |         trigger = NotificationRule.objects.create(name=generate_id()) | ||||||
|         matcher = EventMatcherPolicy.objects.create( |         matcher = EventMatcherPolicy.objects.create( | ||||||
|             name="matcher", action=EventAction.CUSTOM_PREFIX |             name="matcher", action=EventAction.CUSTOM_PREFIX | ||||||
|         ) |         ) | ||||||
| @ -72,9 +72,9 @@ class TestEventsNotifications(TestCase): | |||||||
|  |  | ||||||
|     def test_policy_error_recursive(self): |     def test_policy_error_recursive(self): | ||||||
|         """Test Policy error which would cause recursion""" |         """Test Policy error which would cause recursion""" | ||||||
|         transport = NotificationTransport.objects.create(name="transport") |         transport = NotificationTransport.objects.create(name=generate_id()) | ||||||
|         NotificationRule.objects.filter(name__startswith="default").delete() |         NotificationRule.objects.filter(name__startswith="default").delete() | ||||||
|         trigger = NotificationRule.objects.create(name="trigger", group=self.group) |         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||||
|         trigger.transports.add(transport) |         trigger.transports.add(transport) | ||||||
|         trigger.save() |         trigger.save() | ||||||
|         matcher = EventMatcherPolicy.objects.create( |         matcher = EventMatcherPolicy.objects.create( | ||||||
| @ -95,9 +95,9 @@ class TestEventsNotifications(TestCase): | |||||||
|         self.group.users.add(user2) |         self.group.users.add(user2) | ||||||
|         self.group.save() |         self.group.save() | ||||||
|  |  | ||||||
|         transport = NotificationTransport.objects.create(name="transport", send_once=True) |         transport = NotificationTransport.objects.create(name=generate_id(), send_once=True) | ||||||
|         NotificationRule.objects.filter(name__startswith="default").delete() |         NotificationRule.objects.filter(name__startswith="default").delete() | ||||||
|         trigger = NotificationRule.objects.create(name="trigger", group=self.group) |         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||||
|         trigger.transports.add(transport) |         trigger.transports.add(transport) | ||||||
|         trigger.save() |         trigger.save() | ||||||
|         matcher = EventMatcherPolicy.objects.create( |         matcher = EventMatcherPolicy.objects.create( | ||||||
| @ -118,10 +118,10 @@ class TestEventsNotifications(TestCase): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         transport = NotificationTransport.objects.create( |         transport = NotificationTransport.objects.create( | ||||||
|             name="transport", webhook_mapping=mapping, mode=TransportMode.LOCAL |             name=generate_id(), webhook_mapping=mapping, mode=TransportMode.LOCAL | ||||||
|         ) |         ) | ||||||
|         NotificationRule.objects.filter(name__startswith="default").delete() |         NotificationRule.objects.filter(name__startswith="default").delete() | ||||||
|         trigger = NotificationRule.objects.create(name="trigger", group=self.group) |         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||||
|         trigger.transports.add(transport) |         trigger.transports.add(transport) | ||||||
|         matcher = EventMatcherPolicy.objects.create( |         matcher = EventMatcherPolicy.objects.create( | ||||||
|             name="matcher", action=EventAction.CUSTOM_PREFIX |             name="matcher", action=EventAction.CUSTOM_PREFIX | ||||||
|  | |||||||
							
								
								
									
										131
									
								
								authentik/events/tests/test_transports.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										131
									
								
								authentik/events/tests/test_transports.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,131 @@ | |||||||
|  | """transport tests""" | ||||||
|  | from unittest.mock import PropertyMock, patch | ||||||
|  |  | ||||||
|  | from django.core import mail | ||||||
|  | from django.core.mail.backends.locmem import EmailBackend | ||||||
|  | from django.test import TestCase | ||||||
|  | from requests_mock import Mocker | ||||||
|  |  | ||||||
|  | from authentik import get_full_version | ||||||
|  | from authentik.core.tests.utils import create_test_admin_user | ||||||
|  | from authentik.events.models import ( | ||||||
|  |     Event, | ||||||
|  |     Notification, | ||||||
|  |     NotificationSeverity, | ||||||
|  |     NotificationTransport, | ||||||
|  |     NotificationWebhookMapping, | ||||||
|  |     TransportMode, | ||||||
|  | ) | ||||||
|  | from authentik.lib.generators import generate_id | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TestEventTransports(TestCase): | ||||||
|  |     """Test Event Transports""" | ||||||
|  |  | ||||||
|  |     def setUp(self) -> None: | ||||||
|  |         self.user = create_test_admin_user() | ||||||
|  |         self.event = Event.new("foo", "testing", foo="bar,").set_user(self.user) | ||||||
|  |         self.event.save() | ||||||
|  |         self.notification = Notification.objects.create( | ||||||
|  |             severity=NotificationSeverity.ALERT, | ||||||
|  |             body="foo", | ||||||
|  |             event=self.event, | ||||||
|  |             user=self.user, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_transport_webhook(self): | ||||||
|  |         """Test webhook transport""" | ||||||
|  |         transport: NotificationTransport = NotificationTransport.objects.create( | ||||||
|  |             name=generate_id(), | ||||||
|  |             mode=TransportMode.WEBHOOK, | ||||||
|  |             webhook_url="http://localhost:1234/test", | ||||||
|  |         ) | ||||||
|  |         with Mocker() as mocker: | ||||||
|  |             mocker.post("http://localhost:1234/test") | ||||||
|  |             transport.send(self.notification) | ||||||
|  |             self.assertEqual(mocker.call_count, 1) | ||||||
|  |             self.assertEqual(mocker.request_history[0].method, "POST") | ||||||
|  |             self.assertJSONEqual( | ||||||
|  |                 mocker.request_history[0].body.decode(), | ||||||
|  |                 { | ||||||
|  |                     "body": "foo", | ||||||
|  |                     "severity": "alert", | ||||||
|  |                     "user_email": self.user.email, | ||||||
|  |                     "user_username": self.user.username, | ||||||
|  |                 }, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def test_transport_webhook_mapping(self): | ||||||
|  |         """Test webhook transport with custom mapping""" | ||||||
|  |         mapping = NotificationWebhookMapping.objects.create( | ||||||
|  |             name=generate_id(), expression="return request.user" | ||||||
|  |         ) | ||||||
|  |         transport: NotificationTransport = NotificationTransport.objects.create( | ||||||
|  |             name=generate_id(), | ||||||
|  |             mode=TransportMode.WEBHOOK, | ||||||
|  |             webhook_url="http://localhost:1234/test", | ||||||
|  |             webhook_mapping=mapping, | ||||||
|  |         ) | ||||||
|  |         with Mocker() as mocker: | ||||||
|  |             mocker.post("http://localhost:1234/test") | ||||||
|  |             transport.send(self.notification) | ||||||
|  |             self.assertEqual(mocker.call_count, 1) | ||||||
|  |             self.assertEqual(mocker.request_history[0].method, "POST") | ||||||
|  |             self.assertJSONEqual( | ||||||
|  |                 mocker.request_history[0].body.decode(), | ||||||
|  |                 {"email": self.user.email, "pk": self.user.pk, "username": self.user.username}, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def test_transport_webhook_slack(self): | ||||||
|  |         """Test webhook transport (slack)""" | ||||||
|  |         transport: NotificationTransport = NotificationTransport.objects.create( | ||||||
|  |             name=generate_id(), | ||||||
|  |             mode=TransportMode.WEBHOOK_SLACK, | ||||||
|  |             webhook_url="http://localhost:1234/test", | ||||||
|  |         ) | ||||||
|  |         with Mocker() as mocker: | ||||||
|  |             mocker.post("http://localhost:1234/test") | ||||||
|  |             transport.send(self.notification) | ||||||
|  |             self.assertEqual(mocker.call_count, 1) | ||||||
|  |             self.assertEqual(mocker.request_history[0].method, "POST") | ||||||
|  |             self.assertJSONEqual( | ||||||
|  |                 mocker.request_history[0].body.decode(), | ||||||
|  |                 { | ||||||
|  |                     "username": "authentik", | ||||||
|  |                     "icon_url": "https://goauthentik.io/img/icon.png", | ||||||
|  |                     "attachments": [ | ||||||
|  |                         { | ||||||
|  |                             "author_name": "authentik", | ||||||
|  |                             "author_link": "https://goauthentik.io", | ||||||
|  |                             "author_icon": "https://goauthentik.io/img/icon.png", | ||||||
|  |                             "title": "custom_foo", | ||||||
|  |                             "color": "#fd4b2d", | ||||||
|  |                             "fields": [ | ||||||
|  |                                 {"title": "Severity", "value": "alert", "short": True}, | ||||||
|  |                                 { | ||||||
|  |                                     "title": "Dispatched for user", | ||||||
|  |                                     "value": self.user.username, | ||||||
|  |                                     "short": True, | ||||||
|  |                                 }, | ||||||
|  |                                 {"title": "foo", "value": "bar,"}, | ||||||
|  |                             ], | ||||||
|  |                             "footer": f"authentik {get_full_version()}", | ||||||
|  |                         } | ||||||
|  |                     ], | ||||||
|  |                 }, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def test_transport_email(self): | ||||||
|  |         """Test email transport""" | ||||||
|  |         transport: NotificationTransport = NotificationTransport.objects.create( | ||||||
|  |             name=generate_id(), | ||||||
|  |             mode=TransportMode.EMAIL, | ||||||
|  |         ) | ||||||
|  |         with patch( | ||||||
|  |             "authentik.stages.email.models.EmailStage.backend_class", | ||||||
|  |             PropertyMock(return_value=EmailBackend), | ||||||
|  |         ): | ||||||
|  |             transport.send(self.notification) | ||||||
|  |             self.assertEqual(len(mail.outbox), 1) | ||||||
|  |             self.assertEqual(mail.outbox[0].subject, "authentik Notification: custom_foo") | ||||||
|  |             self.assertIn(self.notification.body, mail.outbox[0].alternatives[0][0]) | ||||||
| @ -23,21 +23,31 @@ from authentik.policies.types import PolicyRequest | |||||||
| ALLOWED_SPECIAL_KEYS = re.compile("passing", flags=re.I) | ALLOWED_SPECIAL_KEYS = re.compile("passing", flags=re.I) | ||||||
|  |  | ||||||
|  |  | ||||||
| def cleanse_dict(source: dict[Any, Any]) -> dict[Any, Any]: | def cleanse_item(key: str, value: Any) -> Any: | ||||||
|     """Cleanse a dictionary, recursively""" |     """Cleanse a single item""" | ||||||
|     final_dict = {} |     if isinstance(value, dict): | ||||||
|     for key, value in source.items(): |         return cleanse_dict(value) | ||||||
|  |     if isinstance(value, list): | ||||||
|  |         for idx, item in enumerate(value): | ||||||
|  |             value[idx] = cleanse_item(key, item) | ||||||
|  |         return value | ||||||
|     try: |     try: | ||||||
|         if SafeExceptionReporterFilter.hidden_settings.search( |         if SafeExceptionReporterFilter.hidden_settings.search( | ||||||
|             key |             key | ||||||
|         ) and not ALLOWED_SPECIAL_KEYS.search(key): |         ) and not ALLOWED_SPECIAL_KEYS.search(key): | ||||||
|                 final_dict[key] = SafeExceptionReporterFilter.cleansed_substitute |             return SafeExceptionReporterFilter.cleansed_substitute | ||||||
|             else: |  | ||||||
|                 final_dict[key] = value |  | ||||||
|     except TypeError:  # pragma: no cover |     except TypeError:  # pragma: no cover | ||||||
|             final_dict[key] = value |         return value | ||||||
|         if isinstance(value, dict): |     return value | ||||||
|             final_dict[key] = cleanse_dict(value) |  | ||||||
|  |  | ||||||
|  | def cleanse_dict(source: dict[Any, Any]) -> dict[Any, Any]: | ||||||
|  |     """Cleanse a dictionary, recursively""" | ||||||
|  |     final_dict = {} | ||||||
|  |     for key, value in source.items(): | ||||||
|  |         new_value = cleanse_item(key, value) | ||||||
|  |         if new_value is not ...: | ||||||
|  |             final_dict[key] = new_value | ||||||
|     return final_dict |     return final_dict | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -70,6 +80,45 @@ def get_user(user: User, original_user: Optional[User] = None) -> dict[str, Any] | |||||||
|     return user_data |     return user_data | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # pylint: disable=too-many-return-statements | ||||||
|  | def sanitize_item(value: Any) -> Any: | ||||||
|  |     """Sanitize a single item, ensure it is JSON parsable""" | ||||||
|  |     if is_dataclass(value): | ||||||
|  |         # Because asdict calls `copy.deepcopy(obj)` on everything that's not tuple/dict, | ||||||
|  |         # and deepcopy doesn't work with HttpRequests (neither django nor rest_framework). | ||||||
|  |         # Currently, the only dataclass that actually holds an http request is a PolicyRequest | ||||||
|  |         if isinstance(value, PolicyRequest): | ||||||
|  |             value.http_request = None | ||||||
|  |         value = asdict(value) | ||||||
|  |     if isinstance(value, dict): | ||||||
|  |         return sanitize_dict(value) | ||||||
|  |     if isinstance(value, list): | ||||||
|  |         new_values = [] | ||||||
|  |         for item in value: | ||||||
|  |             new_value = sanitize_item(item) | ||||||
|  |             if new_value: | ||||||
|  |                 new_values.append(new_value) | ||||||
|  |         return new_values | ||||||
|  |     if isinstance(value, (User, AnonymousUser)): | ||||||
|  |         return sanitize_dict(get_user(value)) | ||||||
|  |     if isinstance(value, models.Model): | ||||||
|  |         return sanitize_dict(model_to_dict(value)) | ||||||
|  |     if isinstance(value, UUID): | ||||||
|  |         return value.hex | ||||||
|  |     if isinstance(value, (HttpRequest, WSGIRequest)): | ||||||
|  |         return ... | ||||||
|  |     if isinstance(value, City): | ||||||
|  |         return GEOIP_READER.city_to_dict(value) | ||||||
|  |     if isinstance(value, Path): | ||||||
|  |         return str(value) | ||||||
|  |     if isinstance(value, type): | ||||||
|  |         return { | ||||||
|  |             "type": value.__name__, | ||||||
|  |             "module": value.__module__, | ||||||
|  |         } | ||||||
|  |     return value | ||||||
|  |  | ||||||
|  |  | ||||||
| def sanitize_dict(source: dict[Any, Any]) -> dict[Any, Any]: | def sanitize_dict(source: dict[Any, Any]) -> dict[Any, Any]: | ||||||
|     """clean source of all Models that would interfere with the JSONField. |     """clean source of all Models that would interfere with the JSONField. | ||||||
|     Models are replaced with a dictionary of { |     Models are replaced with a dictionary of { | ||||||
| @ -79,32 +128,7 @@ def sanitize_dict(source: dict[Any, Any]) -> dict[Any, Any]: | |||||||
|     }""" |     }""" | ||||||
|     final_dict = {} |     final_dict = {} | ||||||
|     for key, value in source.items(): |     for key, value in source.items(): | ||||||
|         if is_dataclass(value): |         new_value = sanitize_item(value) | ||||||
|             # Because asdict calls `copy.deepcopy(obj)` on everything that's not tuple/dict, |         if new_value is not ...: | ||||||
|             # and deepcopy doesn't work with HttpRequests (neither django nor rest_framework). |             final_dict[key] = new_value | ||||||
|             # Currently, the only dataclass that actually holds an http request is a PolicyRequest |  | ||||||
|             if isinstance(value, PolicyRequest): |  | ||||||
|                 value.http_request = None |  | ||||||
|             value = asdict(value) |  | ||||||
|         if isinstance(value, dict): |  | ||||||
|             final_dict[key] = sanitize_dict(value) |  | ||||||
|         elif isinstance(value, (User, AnonymousUser)): |  | ||||||
|             final_dict[key] = sanitize_dict(get_user(value)) |  | ||||||
|         elif isinstance(value, models.Model): |  | ||||||
|             final_dict[key] = sanitize_dict(model_to_dict(value)) |  | ||||||
|         elif isinstance(value, UUID): |  | ||||||
|             final_dict[key] = value.hex |  | ||||||
|         elif isinstance(value, (HttpRequest, WSGIRequest)): |  | ||||||
|             continue |  | ||||||
|         elif isinstance(value, City): |  | ||||||
|             final_dict[key] = GEOIP_READER.city_to_dict(value) |  | ||||||
|         elif isinstance(value, Path): |  | ||||||
|             final_dict[key] = str(value) |  | ||||||
|         elif isinstance(value, type): |  | ||||||
|             final_dict[key] = { |  | ||||||
|                 "type": value.__name__, |  | ||||||
|                 "module": value.__module__, |  | ||||||
|             } |  | ||||||
|         else: |  | ||||||
|             final_dict[key] = value |  | ||||||
|     return final_dict |     return final_dict | ||||||
|  | |||||||
| @ -1,26 +1,22 @@ | |||||||
| """Flow API Views""" | """Flow API Views""" | ||||||
| from dataclasses import dataclass |  | ||||||
|  |  | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
| from django.db.models import Model |  | ||||||
| from django.http import HttpResponse | from django.http import HttpResponse | ||||||
| from django.http.response import HttpResponseBadRequest | from django.http.response import HttpResponseBadRequest | ||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
| from django.utils.translation import gettext as _ | from django.utils.translation import gettext as _ | ||||||
| from drf_spectacular.types import OpenApiTypes | from drf_spectacular.types import OpenApiTypes | ||||||
| from drf_spectacular.utils import OpenApiResponse, extend_schema | from drf_spectacular.utils import OpenApiResponse, extend_schema | ||||||
| from guardian.shortcuts import get_objects_for_user |  | ||||||
| from rest_framework.decorators import action | from rest_framework.decorators import action | ||||||
| from rest_framework.fields import ReadOnlyField | from rest_framework.fields import ReadOnlyField | ||||||
| from rest_framework.parsers import MultiPartParser | from rest_framework.parsers import MultiPartParser | ||||||
| from rest_framework.request import Request | from rest_framework.request import Request | ||||||
| from rest_framework.response import Response | from rest_framework.response import Response | ||||||
| from rest_framework.serializers import CharField, ModelSerializer, Serializer, SerializerMethodField | from rest_framework.serializers import ModelSerializer, SerializerMethodField | ||||||
| from rest_framework.viewsets import ModelViewSet | from rest_framework.viewsets import ModelViewSet | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.api.decorators import permission_required | from authentik.api.decorators import permission_required | ||||||
| from authentik.blueprints.v1.exporter import Exporter | from authentik.blueprints.v1.exporter import FlowExporter | ||||||
| from authentik.blueprints.v1.importer import Importer | from authentik.blueprints.v1.importer import Importer | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.core.api.utils import ( | from authentik.core.api.utils import ( | ||||||
| @ -29,6 +25,7 @@ from authentik.core.api.utils import ( | |||||||
|     FileUploadSerializer, |     FileUploadSerializer, | ||||||
|     LinkSerializer, |     LinkSerializer, | ||||||
| ) | ) | ||||||
|  | from authentik.flows.api.flows_diagram import FlowDiagram, FlowDiagramSerializer | ||||||
| from authentik.flows.exceptions import FlowNonApplicableException | from authentik.flows.exceptions import FlowNonApplicableException | ||||||
| from authentik.flows.models import Flow | from authentik.flows.models import Flow | ||||||
| from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner, cache_key | from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner, cache_key | ||||||
| @ -80,30 +77,6 @@ class FlowSerializer(ModelSerializer): | |||||||
|         } |         } | ||||||
|  |  | ||||||
|  |  | ||||||
| class FlowDiagramSerializer(Serializer): |  | ||||||
|     """response of the flow's diagram action""" |  | ||||||
|  |  | ||||||
|     diagram = CharField(read_only=True) |  | ||||||
|  |  | ||||||
|     def create(self, validated_data: dict) -> Model: |  | ||||||
|         raise NotImplementedError |  | ||||||
|  |  | ||||||
|     def update(self, instance: Model, validated_data: dict) -> Model: |  | ||||||
|         raise NotImplementedError |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @dataclass |  | ||||||
| class DiagramElement: |  | ||||||
|     """Single element used in a diagram""" |  | ||||||
|  |  | ||||||
|     identifier: str |  | ||||||
|     type: str |  | ||||||
|     rest: str |  | ||||||
|  |  | ||||||
|     def __str__(self) -> str: |  | ||||||
|         return f"{self.identifier}=>{self.type}: {self.rest}" |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class FlowViewSet(UsedByMixin, ModelViewSet): | class FlowViewSet(UsedByMixin, ModelViewSet): | ||||||
|     """Flow Viewset""" |     """Flow Viewset""" | ||||||
|  |  | ||||||
| @ -198,7 +171,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet): | |||||||
|     def export(self, request: Request, slug: str) -> Response: |     def export(self, request: Request, slug: str) -> Response: | ||||||
|         """Export flow to .yaml file""" |         """Export flow to .yaml file""" | ||||||
|         flow = self.get_object() |         flow = self.get_object() | ||||||
|         exporter = Exporter(flow) |         exporter = FlowExporter(flow) | ||||||
|         response = HttpResponse(content=exporter.export_to_string()) |         response = HttpResponse(content=exporter.export_to_string()) | ||||||
|         response["Content-Disposition"] = f'attachment; filename="{flow.slug}.yaml"' |         response["Content-Disposition"] = f'attachment; filename="{flow.slug}.yaml"' | ||||||
|         return response |         return response | ||||||
| @ -208,84 +181,9 @@ class FlowViewSet(UsedByMixin, ModelViewSet): | |||||||
|     # pylint: disable=unused-argument |     # pylint: disable=unused-argument | ||||||
|     def diagram(self, request: Request, slug: str) -> Response: |     def diagram(self, request: Request, slug: str) -> Response: | ||||||
|         """Return diagram for flow with slug `slug`, in the format used by flowchart.js""" |         """Return diagram for flow with slug `slug`, in the format used by flowchart.js""" | ||||||
|         flow = self.get_object() |         diagram = FlowDiagram(self.get_object(), request.user) | ||||||
|         header = [ |         output = diagram.build() | ||||||
|             DiagramElement("st", "start", "Start"), |         return Response({"diagram": output}) | ||||||
|         ] |  | ||||||
|         body: list[DiagramElement] = [] |  | ||||||
|         footer = [] |  | ||||||
|         # Collect all elements we need |  | ||||||
|         # First, policies bound to the flow itself |  | ||||||
|         for p_index, policy_binding in enumerate( |  | ||||||
|             get_objects_for_user(request.user, "authentik_policies.view_policybinding") |  | ||||||
|             .filter(target=flow) |  | ||||||
|             .exclude(policy__isnull=True) |  | ||||||
|             .order_by("order") |  | ||||||
|         ): |  | ||||||
|             body.append( |  | ||||||
|                 DiagramElement( |  | ||||||
|                     f"flow_policy_{p_index}", |  | ||||||
|                     "condition", |  | ||||||
|                     _("Policy (%(type)s)" % {"type": policy_binding.policy._meta.verbose_name}) |  | ||||||
|                     + "\n" |  | ||||||
|                     + policy_binding.policy.name, |  | ||||||
|                 ) |  | ||||||
|             ) |  | ||||||
|         # Collect all stages |  | ||||||
|         for s_index, stage_binding in enumerate( |  | ||||||
|             get_objects_for_user(request.user, "authentik_flows.view_flowstagebinding") |  | ||||||
|             .filter(target=flow) |  | ||||||
|             .order_by("order") |  | ||||||
|         ): |  | ||||||
|             # First all policies bound to stages since they execute before stages |  | ||||||
|             for p_index, policy_binding in enumerate( |  | ||||||
|                 get_objects_for_user(request.user, "authentik_policies.view_policybinding") |  | ||||||
|                 .filter(target=stage_binding) |  | ||||||
|                 .exclude(policy__isnull=True) |  | ||||||
|                 .order_by("order") |  | ||||||
|             ): |  | ||||||
|                 body.append( |  | ||||||
|                     DiagramElement( |  | ||||||
|                         f"stage_{s_index}_policy_{p_index}", |  | ||||||
|                         "condition", |  | ||||||
|                         _("Policy (%(type)s)" % {"type": policy_binding.policy._meta.verbose_name}) |  | ||||||
|                         + "\n" |  | ||||||
|                         + policy_binding.policy.name, |  | ||||||
|                     ) |  | ||||||
|                 ) |  | ||||||
|             body.append( |  | ||||||
|                 DiagramElement( |  | ||||||
|                     f"stage_{s_index}", |  | ||||||
|                     "operation", |  | ||||||
|                     _("Stage (%(type)s)" % {"type": stage_binding.stage._meta.verbose_name}) |  | ||||||
|                     + "\n" |  | ||||||
|                     + stage_binding.stage.name, |  | ||||||
|                 ) |  | ||||||
|             ) |  | ||||||
|         # If the 2nd last element is a policy, we need to have an item to point to |  | ||||||
|         # for a negative case |  | ||||||
|         body.append( |  | ||||||
|             DiagramElement("e", "end", "End|future"), |  | ||||||
|         ) |  | ||||||
|         if len(body) == 1: |  | ||||||
|             footer.append("st(right)->e") |  | ||||||
|         else: |  | ||||||
|             # Actual diagram flow |  | ||||||
|             footer.append(f"st(right)->{body[0].identifier}") |  | ||||||
|             for index in range(len(body) - 1): |  | ||||||
|                 element: DiagramElement = body[index] |  | ||||||
|                 if element.type == "condition": |  | ||||||
|                     # Policy passes, link policy yes to next stage |  | ||||||
|                     footer.append(f"{element.identifier}(yes, right)->{body[index + 1].identifier}") |  | ||||||
|                     # Policy doesn't pass, go to stage after next stage |  | ||||||
|                     no_element = body[index + 1] |  | ||||||
|                     if no_element.type != "end": |  | ||||||
|                         no_element = body[index + 2] |  | ||||||
|                     footer.append(f"{element.identifier}(no, bottom)->{no_element.identifier}") |  | ||||||
|                 elif element.type == "operation": |  | ||||||
|                     footer.append(f"{element.identifier}(bottom)->{body[index + 1].identifier}") |  | ||||||
|         diagram = "\n".join([str(x) for x in header + body + footer]) |  | ||||||
|         return Response({"diagram": diagram}) |  | ||||||
|  |  | ||||||
|     @permission_required("authentik_flows.change_flow") |     @permission_required("authentik_flows.change_flow") | ||||||
|     @extend_schema( |     @extend_schema( | ||||||
|  | |||||||
							
								
								
									
										206
									
								
								authentik/flows/api/flows_diagram.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										206
									
								
								authentik/flows/api/flows_diagram.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,206 @@ | |||||||
|  | """Flows Diagram API""" | ||||||
|  | from dataclasses import dataclass, field | ||||||
|  | from typing import Optional | ||||||
|  |  | ||||||
|  | from django.utils.translation import gettext as _ | ||||||
|  | from guardian.shortcuts import get_objects_for_user | ||||||
|  | from rest_framework.serializers import CharField | ||||||
|  |  | ||||||
|  | from authentik.core.api.utils import PassiveSerializer | ||||||
|  | from authentik.core.models import User | ||||||
|  | from authentik.flows.models import Flow, FlowStageBinding | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @dataclass | ||||||
|  | class DiagramElement: | ||||||
|  |     """Single element used in a diagram""" | ||||||
|  |  | ||||||
|  |     identifier: str | ||||||
|  |     description: str | ||||||
|  |     action: Optional[str] = None | ||||||
|  |     source: Optional[list["DiagramElement"]] = None | ||||||
|  |  | ||||||
|  |     style: list[str] = field(default_factory=lambda: ["[", "]"]) | ||||||
|  |  | ||||||
|  |     def __str__(self) -> str: | ||||||
|  |         element = f'{self.identifier}{self.style[0]}"{self.description}"{self.style[1]}' | ||||||
|  |         if self.action is not None: | ||||||
|  |             if self.action != "": | ||||||
|  |                 element = f"--{self.action}--> {element}" | ||||||
|  |             else: | ||||||
|  |                 element = f"--> {element}" | ||||||
|  |         if self.source: | ||||||
|  |             source_element = [] | ||||||
|  |             for source in self.source: | ||||||
|  |                 source_element.append(f"{source.identifier} {element}") | ||||||
|  |             return "\n".join(source_element) | ||||||
|  |         return element | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class FlowDiagramSerializer(PassiveSerializer): | ||||||
|  |     """response of the flow's diagram action""" | ||||||
|  |  | ||||||
|  |     diagram = CharField(read_only=True) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class FlowDiagram: | ||||||
|  |     """Generate flow chart fow a flow""" | ||||||
|  |  | ||||||
|  |     flow: Flow | ||||||
|  |     user: User | ||||||
|  |  | ||||||
|  |     def __init__(self, flow: Flow, user: User) -> None: | ||||||
|  |         self.flow = flow | ||||||
|  |         self.user = user | ||||||
|  |  | ||||||
|  |     def get_flow_policies(self, parent_elements: list[DiagramElement]) -> list[DiagramElement]: | ||||||
|  |         """Collect all policies bound to the flow""" | ||||||
|  |         elements = [] | ||||||
|  |         for p_index, policy_binding in enumerate( | ||||||
|  |             get_objects_for_user(self.user, "authentik_policies.view_policybinding") | ||||||
|  |             .filter(target=self.flow) | ||||||
|  |             .exclude(policy__isnull=True) | ||||||
|  |             .order_by("order") | ||||||
|  |         ): | ||||||
|  |             element = DiagramElement( | ||||||
|  |                 f"flow_policy_{p_index}", | ||||||
|  |                 _("Policy (%(type)s)" % {"type": policy_binding.policy._meta.verbose_name}) | ||||||
|  |                 + "\n" | ||||||
|  |                 + policy_binding.policy.name, | ||||||
|  |                 _("Binding %(order)d" % {"order": policy_binding.order}), | ||||||
|  |                 parent_elements, | ||||||
|  |                 style=["{{", "}}"], | ||||||
|  |             ) | ||||||
|  |             elements.append(element) | ||||||
|  |         return elements | ||||||
|  |  | ||||||
|  |     def get_stage_policies( | ||||||
|  |         self, | ||||||
|  |         stage_index: int, | ||||||
|  |         stage_binding: FlowStageBinding, | ||||||
|  |         parent_elements: list[DiagramElement], | ||||||
|  |     ) -> list[DiagramElement]: | ||||||
|  |         """First all policies bound to stages since they execute before stages""" | ||||||
|  |         elements = [] | ||||||
|  |         for p_index, policy_binding in enumerate( | ||||||
|  |             get_objects_for_user(self.user, "authentik_policies.view_policybinding") | ||||||
|  |             .filter(target=stage_binding) | ||||||
|  |             .exclude(policy__isnull=True) | ||||||
|  |             .order_by("order") | ||||||
|  |         ): | ||||||
|  |             element = DiagramElement( | ||||||
|  |                 f"stage_{stage_index}_policy_{p_index}", | ||||||
|  |                 _("Policy (%(type)s)" % {"type": policy_binding.policy._meta.verbose_name}) | ||||||
|  |                 + "\n" | ||||||
|  |                 + policy_binding.policy.name, | ||||||
|  |                 "", | ||||||
|  |                 parent_elements, | ||||||
|  |                 style=["{{", "}}"], | ||||||
|  |             ) | ||||||
|  |             elements.append(element) | ||||||
|  |         return elements | ||||||
|  |  | ||||||
|  |     def get_stages(self, parent_elements: list[DiagramElement]) -> list[str | DiagramElement]: | ||||||
|  |         """Collect all stages""" | ||||||
|  |         elements = [] | ||||||
|  |         stages = [] | ||||||
|  |         for s_index, stage_binding in enumerate( | ||||||
|  |             get_objects_for_user(self.user, "authentik_flows.view_flowstagebinding") | ||||||
|  |             .filter(target=self.flow) | ||||||
|  |             .order_by("order") | ||||||
|  |         ): | ||||||
|  |             stage_policies = self.get_stage_policies(s_index, stage_binding, parent_elements) | ||||||
|  |             elements.extend(stage_policies) | ||||||
|  |  | ||||||
|  |             action = "" | ||||||
|  |             if len(stage_policies) > 0: | ||||||
|  |                 action = _("Policy passed") | ||||||
|  |  | ||||||
|  |             element = DiagramElement( | ||||||
|  |                 f"stage_{s_index}", | ||||||
|  |                 _("Stage (%(type)s)" % {"type": stage_binding.stage._meta.verbose_name}) | ||||||
|  |                 + "\n" | ||||||
|  |                 + stage_binding.stage.name, | ||||||
|  |                 action, | ||||||
|  |                 stage_policies, | ||||||
|  |                 style=["([", "])"], | ||||||
|  |             ) | ||||||
|  |             stages.append(element) | ||||||
|  |  | ||||||
|  |             parent_elements = [element] | ||||||
|  |  | ||||||
|  |             # This adds connections for policy denies, but retroactively, as we can't really | ||||||
|  |             # look ahead | ||||||
|  |             # Check if we have a stage behind us and if it has any sources | ||||||
|  |             if s_index > 0: | ||||||
|  |                 last_stage: DiagramElement = stages[s_index - 1] | ||||||
|  |                 if last_stage.source and len(last_stage.source) > 0: | ||||||
|  |                     # If it has any sources, add a connection from each of that stage's sources | ||||||
|  |                     # to this stage | ||||||
|  |                     for source in last_stage.source: | ||||||
|  |                         elements.append( | ||||||
|  |                             DiagramElement( | ||||||
|  |                                 element.identifier, | ||||||
|  |                                 element.description, | ||||||
|  |                                 _("Policy denied"), | ||||||
|  |                                 [source], | ||||||
|  |                                 style=element.style, | ||||||
|  |                             ) | ||||||
|  |                         ) | ||||||
|  |  | ||||||
|  |         if len(stages) > 0: | ||||||
|  |             elements.append( | ||||||
|  |                 DiagramElement( | ||||||
|  |                     "done", | ||||||
|  |                     _("End of the flow"), | ||||||
|  |                     "", | ||||||
|  |                     [stages[-1]], | ||||||
|  |                     style=["[[", "]]"], | ||||||
|  |                 ), | ||||||
|  |             ) | ||||||
|  |         return stages + elements | ||||||
|  |  | ||||||
|  |     def build(self) -> str: | ||||||
|  |         """Build flowchart""" | ||||||
|  |         all_elements = [ | ||||||
|  |             "graph TD", | ||||||
|  |         ] | ||||||
|  |  | ||||||
|  |         pre_flow_policies_element = DiagramElement( | ||||||
|  |             "flow_pre", _("Pre-flow policies"), style=["[[", "]]"] | ||||||
|  |         ) | ||||||
|  |         flow_policies = self.get_flow_policies([pre_flow_policies_element]) | ||||||
|  |         if len(flow_policies) > 0: | ||||||
|  |             all_elements.append(pre_flow_policies_element) | ||||||
|  |             all_elements.extend(flow_policies) | ||||||
|  |             all_elements.append( | ||||||
|  |                 DiagramElement( | ||||||
|  |                     "done", | ||||||
|  |                     _("End of the flow"), | ||||||
|  |                     _("Policy denied"), | ||||||
|  |                     flow_policies, | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         flow_element = DiagramElement( | ||||||
|  |             "flow_start", | ||||||
|  |             _("Flow") + "\n" + self.flow.name, | ||||||
|  |             "" if len(flow_policies) > 0 else None, | ||||||
|  |             source=flow_policies, | ||||||
|  |             style=["[[", "]]"], | ||||||
|  |         ) | ||||||
|  |         all_elements.append(flow_element) | ||||||
|  |  | ||||||
|  |         stages = self.get_stages([flow_element]) | ||||||
|  |         all_elements.extend(stages) | ||||||
|  |         if len(stages) < 1: | ||||||
|  |             all_elements.append( | ||||||
|  |                 DiagramElement( | ||||||
|  |                     "done", | ||||||
|  |                     _("End of the flow"), | ||||||
|  |                     "", | ||||||
|  |                     [flow_element], | ||||||
|  |                     style=["[[", "]]"], | ||||||
|  |                 ), | ||||||
|  |             ) | ||||||
|  |         return "\n".join([str(x) for x in all_elements]) | ||||||
| @ -1,7 +1,7 @@ | |||||||
| """authentik flows app config""" | """authentik flows app config""" | ||||||
| from prometheus_client import Gauge, Histogram | from prometheus_client import Gauge, Histogram | ||||||
|  |  | ||||||
| from authentik.blueprints.manager import ManagedAppConfig | from authentik.blueprints.apps import ManagedAppConfig | ||||||
| from authentik.lib.utils.reflection import all_subclasses | from authentik.lib.utils.reflection import all_subclasses | ||||||
|  |  | ||||||
| GAUGE_FLOWS_CACHED = Gauge( | GAUGE_FLOWS_CACHED = Gauge( | ||||||
| @ -28,7 +28,7 @@ class AuthentikFlowsConfig(ManagedAppConfig): | |||||||
|         """Load flows signals""" |         """Load flows signals""" | ||||||
|         self.import_module("authentik.flows.signals") |         self.import_module("authentik.flows.signals") | ||||||
|  |  | ||||||
|     def reconcile_stages_loaded(self): |     def reconcile_load_stages(self): | ||||||
|         """Ensure all stages are loaded""" |         """Ensure all stages are loaded""" | ||||||
|         from authentik.flows.models import Stage |         from authentik.flows.models import Stage | ||||||
|  |  | ||||||
|  | |||||||
| @ -1,14 +1,14 @@ | |||||||
| """Challenge helpers""" | """Challenge helpers""" | ||||||
| from dataclasses import asdict, is_dataclass | from dataclasses import asdict, is_dataclass | ||||||
| from enum import Enum | from enum import Enum | ||||||
|  | from traceback import format_tb | ||||||
| from typing import TYPE_CHECKING, Optional, TypedDict | from typing import TYPE_CHECKING, Optional, TypedDict | ||||||
| from uuid import UUID | from uuid import UUID | ||||||
|  |  | ||||||
| from django.core.serializers.json import DjangoJSONEncoder | from django.core.serializers.json import DjangoJSONEncoder | ||||||
| from django.db import models | from django.db import models | ||||||
| from django.http import JsonResponse | from django.http import JsonResponse | ||||||
| from rest_framework.fields import ChoiceField, DictField | from rest_framework.fields import CharField, ChoiceField, DictField | ||||||
| from rest_framework.serializers import CharField |  | ||||||
|  |  | ||||||
| from authentik.core.api.utils import PassiveSerializer | from authentik.core.api.utils import PassiveSerializer | ||||||
|  |  | ||||||
| @ -90,6 +90,34 @@ class WithUserInfoChallenge(Challenge): | |||||||
|     pending_user_avatar = CharField() |     pending_user_avatar = CharField() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class FlowErrorChallenge(WithUserInfoChallenge): | ||||||
|  |     """Challenge class when an unhandled error occurs during a stage. Normal users | ||||||
|  |     are shown an error message, superusers are shown a full stacktrace.""" | ||||||
|  |  | ||||||
|  |     component = CharField(default="xak-flow-error") | ||||||
|  |  | ||||||
|  |     request_id = CharField() | ||||||
|  |  | ||||||
|  |     error = CharField(required=False) | ||||||
|  |     traceback = CharField(required=False) | ||||||
|  |  | ||||||
|  |     def __init__(self, *args, **kwargs): | ||||||
|  |         request = kwargs.pop("request", None) | ||||||
|  |         error = kwargs.pop("error", None) | ||||||
|  |         super().__init__(*args, **kwargs) | ||||||
|  |         if not request or not error: | ||||||
|  |             return | ||||||
|  |         self.request_id = request.request_id | ||||||
|  |         from authentik.core.models import USER_ATTRIBUTE_DEBUG | ||||||
|  |  | ||||||
|  |         if request.user and request.user.is_authenticated: | ||||||
|  |             if request.user.is_superuser or request.user.group_attributes(request).get( | ||||||
|  |                 USER_ATTRIBUTE_DEBUG, False | ||||||
|  |             ): | ||||||
|  |                 self.error = error | ||||||
|  |                 self.traceback = "".join(format_tb(self.error.__traceback__)) | ||||||
|  |  | ||||||
|  |  | ||||||
| class AccessDeniedChallenge(WithUserInfoChallenge): | class AccessDeniedChallenge(WithUserInfoChallenge): | ||||||
|     """Challenge when a flow's active stage calls `stage_invalid()`.""" |     """Challenge when a flow's active stage calls `stage_invalid()`.""" | ||||||
|  |  | ||||||
|  | |||||||
| @ -32,7 +32,7 @@ class FlowPlanProcess(PROCESS_CLASS):  # pragma: no cover | |||||||
|  |  | ||||||
|     def run(self): |     def run(self): | ||||||
|         """Execute 1000 flow plans""" |         """Execute 1000 flow plans""" | ||||||
|         print(f"Proc {self.index} Running") |         LOGGER.info(f"Proc {self.index} Running") | ||||||
|  |  | ||||||
|         def test_inner(): |         def test_inner(): | ||||||
|             planner = FlowPlanner(self.flow) |             planner = FlowPlanner(self.flow) | ||||||
|  | |||||||
| @ -19,25 +19,6 @@ def update_flow_designation(apps: Apps, schema_editor: BaseDatabaseSchemaEditor) | |||||||
|             flow.save() |             flow.save() | ||||||
|  |  | ||||||
|  |  | ||||||
| # First stage for default-source-enrollment flow (prompt stage) |  | ||||||
| # needs to have its policy re-evaluated |  | ||||||
| def update_default_source_enrollment_flow_binding( |  | ||||||
|     apps: Apps, schema_editor: BaseDatabaseSchemaEditor |  | ||||||
| ): |  | ||||||
|     Flow = apps.get_model("authentik_flows", "Flow") |  | ||||||
|     FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding") |  | ||||||
|     db_alias = schema_editor.connection.alias |  | ||||||
|  |  | ||||||
|     flows = Flow.objects.using(db_alias).filter(slug="default-source-enrollment") |  | ||||||
|     if not flows.exists(): |  | ||||||
|         return |  | ||||||
|     flow = flows.first() |  | ||||||
|  |  | ||||||
|     binding = FlowStageBinding.objects.get(target=flow, order=0) |  | ||||||
|     binding.re_evaluate_policies = True |  | ||||||
|     binding.save() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Migration(migrations.Migration): | class Migration(migrations.Migration): | ||||||
|  |  | ||||||
|     replaces = [ |     replaces = [ | ||||||
| @ -101,9 +82,6 @@ class Migration(migrations.Migration): | |||||||
|                 help_text="When this option is enabled, the planner will re-evaluate policies bound to this binding.", |                 help_text="When this option is enabled, the planner will re-evaluate policies bound to this binding.", | ||||||
|             ), |             ), | ||||||
|         ), |         ), | ||||||
|         migrations.RunPython( |  | ||||||
|             code=update_default_source_enrollment_flow_binding, |  | ||||||
|         ), |  | ||||||
|         migrations.AlterField( |         migrations.AlterField( | ||||||
|             model_name="flowstagebinding", |             model_name="flowstagebinding", | ||||||
|             name="re_evaluate_policies", |             name="re_evaluate_policies", | ||||||
|  | |||||||
| @ -1,22 +0,0 @@ | |||||||
| {% load i18n %} |  | ||||||
|  |  | ||||||
| <style> |  | ||||||
|     .ak-exception { |  | ||||||
|         font-family: monospace; |  | ||||||
|         overflow-x: scroll; |  | ||||||
|     } |  | ||||||
| </style> |  | ||||||
|  |  | ||||||
| <header class="pf-c-login__main-header"> |  | ||||||
|     <h1 class="pf-c-title pf-m-3xl"> |  | ||||||
|         {% trans 'Whoops!' %} |  | ||||||
|     </h1> |  | ||||||
| </header> |  | ||||||
| <div class="pf-c-login__main-body"> |  | ||||||
|     <h3> |  | ||||||
|         {% trans 'Something went wrong! Please try again later.' %} |  | ||||||
|     </h3> |  | ||||||
|     {% if debug %} |  | ||||||
|     <pre class="ak-exception">{{ tb }}{{ error }}</pre> |  | ||||||
|     {% endif %} |  | ||||||
| </div> |  | ||||||
| @ -9,22 +9,20 @@ from authentik.policies.dummy.models import DummyPolicy | |||||||
| from authentik.policies.models import PolicyBinding | from authentik.policies.models import PolicyBinding | ||||||
| from authentik.stages.dummy.models import DummyStage | from authentik.stages.dummy.models import DummyStage | ||||||
|  |  | ||||||
| DIAGRAM_EXPECTED = """st=>start: Start | DIAGRAM_EXPECTED = """graph TD | ||||||
| stage_0=>operation: Stage (Dummy Stage) | flow_start[["Flow | ||||||
| dummy1 | test-default-context"]] | ||||||
| stage_1_policy_0=>condition: Policy (Dummy Policy) | --> stage_0(["Stage (Dummy Stage) | ||||||
| test | dummy1"]) | ||||||
| stage_1=>operation: Stage (Dummy Stage) | stage_1_policy_0 --Policy passed--> stage_1(["Stage (Dummy Stage) | ||||||
| dummy2 | dummy2"]) | ||||||
| e=>end: End|future | stage_0 --> stage_1_policy_0{{"Policy (Dummy Policy) | ||||||
| st(right)->stage_0 | dummy2-policy"}} | ||||||
| stage_0(bottom)->stage_1_policy_0 | stage_1 --> done[["End of the flow"]]""" | ||||||
| stage_1_policy_0(yes, right)->stage_1 | DIAGRAM_SHORT_EXPECTED = """graph TD | ||||||
| stage_1_policy_0(no, bottom)->e | flow_start[["Flow | ||||||
| stage_1(bottom)->e""" | test-default-context"]] | ||||||
| DIAGRAM_SHORT_EXPECTED = """st=>start: Start | flow_start --> done[["End of the flow"]]""" | ||||||
| e=>end: End|future |  | ||||||
| st(right)->e""" |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestFlowsAPI(APITestCase): | class TestFlowsAPI(APITestCase): | ||||||
| @ -55,7 +53,9 @@ class TestFlowsAPI(APITestCase): | |||||||
|             slug="test-default-context", |             slug="test-default-context", | ||||||
|             designation=FlowDesignation.AUTHENTICATION, |             designation=FlowDesignation.AUTHENTICATION, | ||||||
|         ) |         ) | ||||||
|         false_policy = DummyPolicy.objects.create(name="test", result=False, wait_min=1, wait_max=2) |         false_policy = DummyPolicy.objects.create( | ||||||
|  |             name="dummy2-policy", result=False, wait_min=1, wait_max=2 | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         FlowStageBinding.objects.create( |         FlowStageBinding.objects.create( | ||||||
|             target=flow, stage=DummyStage.objects.create(name="dummy1"), order=0 |             target=flow, stage=DummyStage.objects.create(name="dummy1"), order=0 | ||||||
|  | |||||||
| @ -6,10 +6,9 @@ from django.test.client import RequestFactory | |||||||
| from django.urls.base import reverse | from django.urls.base import reverse | ||||||
| from rest_framework.test import APITestCase | from rest_framework.test import APITestCase | ||||||
|  |  | ||||||
| from authentik.core.tests.utils import create_test_admin_user | from authentik.core.tests.utils import create_test_admin_user, create_test_flow | ||||||
| from authentik.flows.challenge import ChallengeTypes | from authentik.flows.challenge import ChallengeTypes | ||||||
| from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding, InvalidResponseAction | from authentik.flows.models import FlowDesignation, FlowStageBinding, InvalidResponseAction | ||||||
| from authentik.lib.generators import generate_id |  | ||||||
| from authentik.stages.dummy.models import DummyStage | from authentik.stages.dummy.models import DummyStage | ||||||
| from authentik.stages.identification.models import IdentificationStage, UserFields | from authentik.stages.identification.models import IdentificationStage, UserFields | ||||||
|  |  | ||||||
| @ -24,11 +23,7 @@ class TestFlowInspector(APITestCase): | |||||||
|  |  | ||||||
|     def test(self): |     def test(self): | ||||||
|         """test inspector""" |         """test inspector""" | ||||||
|         flow = Flow.objects.create( |         flow = create_test_flow(FlowDesignation.AUTHENTICATION) | ||||||
|             name=generate_id(), |  | ||||||
|             slug=generate_id(), |  | ||||||
|             designation=FlowDesignation.AUTHENTICATION, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         # Stage 1 is an identification stage |         # Stage 1 is an identification stage | ||||||
|         ident_stage = IdentificationStage.objects.create( |         ident_stage = IdentificationStage.objects.create( | ||||||
| @ -55,7 +50,7 @@ class TestFlowInspector(APITestCase): | |||||||
|                 "flow_info": { |                 "flow_info": { | ||||||
|                     "background": flow.background_url, |                     "background": flow.background_url, | ||||||
|                     "cancel_url": reverse("authentik_flows:cancel"), |                     "cancel_url": reverse("authentik_flows:cancel"), | ||||||
|                     "title": "", |                     "title": flow.title, | ||||||
|                     "layout": "stacked", |                     "layout": "stacked", | ||||||
|                 }, |                 }, | ||||||
|                 "type": ChallengeTypes.NATIVE.value, |                 "type": ChallengeTypes.NATIVE.value, | ||||||
|  | |||||||
| @ -8,9 +8,10 @@ from django.urls import reverse | |||||||
| from guardian.shortcuts import get_anonymous_user | from guardian.shortcuts import get_anonymous_user | ||||||
|  |  | ||||||
| from authentik.core.models import User | from authentik.core.models import User | ||||||
|  | from authentik.core.tests.utils import create_test_flow | ||||||
| from authentik.flows.exceptions import EmptyFlowException, FlowNonApplicableException | from authentik.flows.exceptions import EmptyFlowException, FlowNonApplicableException | ||||||
| from authentik.flows.markers import ReevaluateMarker, StageMarker | from authentik.flows.markers import ReevaluateMarker, StageMarker | ||||||
| from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding | from authentik.flows.models import FlowDesignation, FlowStageBinding | ||||||
| from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner, cache_key | from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner, cache_key | ||||||
| from authentik.lib.tests.utils import dummy_get_response | from authentik.lib.tests.utils import dummy_get_response | ||||||
| from authentik.policies.dummy.models import DummyPolicy | from authentik.policies.dummy.models import DummyPolicy | ||||||
| @ -32,11 +33,7 @@ class TestFlowPlanner(TestCase): | |||||||
|  |  | ||||||
|     def test_empty_plan(self): |     def test_empty_plan(self): | ||||||
|         """Test that empty plan raises exception""" |         """Test that empty plan raises exception""" | ||||||
|         flow = Flow.objects.create( |         flow = create_test_flow() | ||||||
|             name="test-empty", |  | ||||||
|             slug="test-empty", |  | ||||||
|             designation=FlowDesignation.AUTHENTICATION, |  | ||||||
|         ) |  | ||||||
|         request = self.request_factory.get( |         request = self.request_factory.get( | ||||||
|             reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), |             reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), | ||||||
|         ) |         ) | ||||||
| @ -52,11 +49,7 @@ class TestFlowPlanner(TestCase): | |||||||
|     ) |     ) | ||||||
|     def test_non_applicable_plan(self): |     def test_non_applicable_plan(self): | ||||||
|         """Test that empty plan raises exception""" |         """Test that empty plan raises exception""" | ||||||
|         flow = Flow.objects.create( |         flow = create_test_flow() | ||||||
|             name="test-empty", |  | ||||||
|             slug="test-empty", |  | ||||||
|             designation=FlowDesignation.AUTHENTICATION, |  | ||||||
|         ) |  | ||||||
|         request = self.request_factory.get( |         request = self.request_factory.get( | ||||||
|             reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), |             reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), | ||||||
|         ) |         ) | ||||||
| @ -69,11 +62,7 @@ class TestFlowPlanner(TestCase): | |||||||
|     @patch("authentik.flows.planner.cache", CACHE_MOCK) |     @patch("authentik.flows.planner.cache", CACHE_MOCK) | ||||||
|     def test_planner_cache(self): |     def test_planner_cache(self): | ||||||
|         """Test planner cache""" |         """Test planner cache""" | ||||||
|         flow = Flow.objects.create( |         flow = create_test_flow(FlowDesignation.AUTHENTICATION) | ||||||
|             name="test-cache", |  | ||||||
|             slug="test-cache", |  | ||||||
|             designation=FlowDesignation.AUTHENTICATION, |  | ||||||
|         ) |  | ||||||
|         FlowStageBinding.objects.create( |         FlowStageBinding.objects.create( | ||||||
|             target=flow, stage=DummyStage.objects.create(name="dummy"), order=0 |             target=flow, stage=DummyStage.objects.create(name="dummy"), order=0 | ||||||
|         ) |         ) | ||||||
| @ -92,11 +81,7 @@ class TestFlowPlanner(TestCase): | |||||||
|  |  | ||||||
|     def test_planner_default_context(self): |     def test_planner_default_context(self): | ||||||
|         """Test planner with default_context""" |         """Test planner with default_context""" | ||||||
|         flow = Flow.objects.create( |         flow = create_test_flow() | ||||||
|             name="test-default-context", |  | ||||||
|             slug="test-default-context", |  | ||||||
|             designation=FlowDesignation.AUTHENTICATION, |  | ||||||
|         ) |  | ||||||
|         FlowStageBinding.objects.create( |         FlowStageBinding.objects.create( | ||||||
|             target=flow, stage=DummyStage.objects.create(name="dummy"), order=0 |             target=flow, stage=DummyStage.objects.create(name="dummy"), order=0 | ||||||
|         ) |         ) | ||||||
| @ -113,11 +98,7 @@ class TestFlowPlanner(TestCase): | |||||||
|  |  | ||||||
|     def test_planner_marker_reevaluate(self): |     def test_planner_marker_reevaluate(self): | ||||||
|         """Test that the planner creates the proper marker""" |         """Test that the planner creates the proper marker""" | ||||||
|         flow = Flow.objects.create( |         flow = create_test_flow() | ||||||
|             name="test-default-context", |  | ||||||
|             slug="test-default-context", |  | ||||||
|             designation=FlowDesignation.AUTHENTICATION, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         FlowStageBinding.objects.create( |         FlowStageBinding.objects.create( | ||||||
|             target=flow, |             target=flow, | ||||||
| @ -138,11 +119,7 @@ class TestFlowPlanner(TestCase): | |||||||
|  |  | ||||||
|     def test_planner_reevaluate_actual(self): |     def test_planner_reevaluate_actual(self): | ||||||
|         """Test planner with re-evaluate""" |         """Test planner with re-evaluate""" | ||||||
|         flow = Flow.objects.create( |         flow = create_test_flow() | ||||||
|             name="test-default-context", |  | ||||||
|             slug="test-default-context", |  | ||||||
|             designation=FlowDesignation.AUTHENTICATION, |  | ||||||
|         ) |  | ||||||
|         false_policy = DummyPolicy.objects.create(result=False, wait_min=1, wait_max=2) |         false_policy = DummyPolicy.objects.create(result=False, wait_min=1, wait_max=2) | ||||||
|  |  | ||||||
|         binding = FlowStageBinding.objects.create( |         binding = FlowStageBinding.objects.create( | ||||||
|  | |||||||
| @ -1,7 +1,6 @@ | |||||||
| """authentik multi-stage authentication engine""" | """authentik multi-stage authentication engine""" | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from traceback import format_tb | from typing import Optional | ||||||
| from typing import Any, Optional |  | ||||||
|  |  | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
| from django.contrib.auth.mixins import LoginRequiredMixin | from django.contrib.auth.mixins import LoginRequiredMixin | ||||||
| @ -23,12 +22,12 @@ from sentry_sdk.api import set_tag | |||||||
| from sentry_sdk.hub import Hub | from sentry_sdk.hub import Hub | ||||||
| from structlog.stdlib import BoundLogger, get_logger | from structlog.stdlib import BoundLogger, get_logger | ||||||
|  |  | ||||||
| from authentik.core.models import USER_ATTRIBUTE_DEBUG |  | ||||||
| from authentik.events.models import Event, EventAction, cleanse_dict | from authentik.events.models import Event, EventAction, cleanse_dict | ||||||
| from authentik.flows.challenge import ( | from authentik.flows.challenge import ( | ||||||
|     Challenge, |     Challenge, | ||||||
|     ChallengeResponse, |     ChallengeResponse, | ||||||
|     ChallengeTypes, |     ChallengeTypes, | ||||||
|  |     FlowErrorChallenge, | ||||||
|     HttpChallengeResponse, |     HttpChallengeResponse, | ||||||
|     RedirectChallenge, |     RedirectChallenge, | ||||||
|     ShellChallenge, |     ShellChallenge, | ||||||
| @ -153,6 +152,7 @@ class FlowExecutorView(APIView): | |||||||
|         token: Optional[FlowToken] = FlowToken.filter_not_expired(key=key).first() |         token: Optional[FlowToken] = FlowToken.filter_not_expired(key=key).first() | ||||||
|         if not token: |         if not token: | ||||||
|             return None |             return None | ||||||
|  |         plan = None | ||||||
|         try: |         try: | ||||||
|             plan = token.plan |             plan = token.plan | ||||||
|         except (AttributeError, EOFError, ImportError, IndexError) as exc: |         except (AttributeError, EOFError, ImportError, IndexError) as exc: | ||||||
| @ -253,7 +253,9 @@ class FlowExecutorView(APIView): | |||||||
|             action=EventAction.SYSTEM_EXCEPTION, |             action=EventAction.SYSTEM_EXCEPTION, | ||||||
|             message=exception_to_string(exc), |             message=exception_to_string(exc), | ||||||
|         ).from_http(self.request) |         ).from_http(self.request) | ||||||
|         return to_stage_response(self.request, FlowErrorResponse(self.request, exc)) |         return to_stage_response( | ||||||
|  |             self.request, HttpChallengeResponse(FlowErrorChallenge(self.request, exc)) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     @extend_schema( |     @extend_schema( | ||||||
|         responses={ |         responses={ | ||||||
| @ -440,30 +442,6 @@ class FlowExecutorView(APIView): | |||||||
|                 del self.request.session[key] |                 del self.request.session[key] | ||||||
|  |  | ||||||
|  |  | ||||||
| class FlowErrorResponse(TemplateResponse): |  | ||||||
|     """Response class when an unhandled error occurs during a stage. Normal users |  | ||||||
|     are shown an error message, superusers are shown a full stacktrace.""" |  | ||||||
|  |  | ||||||
|     error: Exception |  | ||||||
|  |  | ||||||
|     def __init__(self, request: HttpRequest, error: Exception) -> None: |  | ||||||
|         # For some reason pyright complains about keyword argument usage here |  | ||||||
|         # pyright: reportGeneralTypeIssues=false |  | ||||||
|         super().__init__(request=request, template="flows/error.html") |  | ||||||
|         self.error = error |  | ||||||
|  |  | ||||||
|     def resolve_context(self, context: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]: |  | ||||||
|         if not context: |  | ||||||
|             context = {} |  | ||||||
|         context["error"] = self.error |  | ||||||
|         if self._request.user and self._request.user.is_authenticated: |  | ||||||
|             if self._request.user.is_superuser or self._request.user.group_attributes( |  | ||||||
|                 self._request |  | ||||||
|             ).get(USER_ATTRIBUTE_DEBUG, False): |  | ||||||
|                 context["tb"] = "".join(format_tb(self.error.__traceback__)) |  | ||||||
|         return context |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class CancelView(View): | class CancelView(View): | ||||||
|     """View which canels the currently active plan""" |     """View which canels the currently active plan""" | ||||||
|  |  | ||||||
|  | |||||||
| @ -20,7 +20,7 @@ ENV_PREFIX = "AUTHENTIK" | |||||||
| ENVIRONMENT = os.getenv(f"{ENV_PREFIX}_ENV", "local") | ENVIRONMENT = os.getenv(f"{ENV_PREFIX}_ENV", "local") | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_path_from_dict(root: dict, path: str, sep=".", default=None): | def get_path_from_dict(root: dict, path: str, sep=".", default=None) -> Any: | ||||||
|     """Recursively walk through `root`, checking each part of `path` split by `sep`. |     """Recursively walk through `root`, checking each part of `path` split by `sep`. | ||||||
|     If at any point a dict does not exist, return default""" |     If at any point a dict does not exist, return default""" | ||||||
|     for comp in path.split(sep): |     for comp in path.split(sep): | ||||||
| @ -62,7 +62,7 @@ class ConfigLoader: | |||||||
|                         self.update_from_file(env_file) |                         self.update_from_file(env_file) | ||||||
|         self.update_from_env() |         self.update_from_env() | ||||||
|  |  | ||||||
|     def _log(self, level: str, message: str, **kwargs): |     def log(self, level: str, message: str, **kwargs): | ||||||
|         """Custom Log method, we want to ensure ConfigLoader always logs JSON even when |         """Custom Log method, we want to ensure ConfigLoader always logs JSON even when | ||||||
|         'structlog' or 'logging' hasn't been configured yet.""" |         'structlog' or 'logging' hasn't been configured yet.""" | ||||||
|         output = { |         output = { | ||||||
| @ -95,7 +95,7 @@ class ConfigLoader: | |||||||
|                 with open(url.path, "r", encoding="utf8") as _file: |                 with open(url.path, "r", encoding="utf8") as _file: | ||||||
|                     value = _file.read() |                     value = _file.read() | ||||||
|             except OSError as exc: |             except OSError as exc: | ||||||
|                 self._log("error", f"Failed to read config value from {url.path}: {exc}") |                 self.log("error", f"Failed to read config value from {url.path}: {exc}") | ||||||
|                 value = url.query |                 value = url.query | ||||||
|         return value |         return value | ||||||
|  |  | ||||||
| @ -105,12 +105,12 @@ class ConfigLoader: | |||||||
|             with open(path, encoding="utf8") as file: |             with open(path, encoding="utf8") as file: | ||||||
|                 try: |                 try: | ||||||
|                     self.update(self.__config, yaml.safe_load(file)) |                     self.update(self.__config, yaml.safe_load(file)) | ||||||
|                     self._log("debug", "Loaded config", file=path) |                     self.log("debug", "Loaded config", file=path) | ||||||
|                     self.loaded_file.append(path) |                     self.loaded_file.append(path) | ||||||
|                 except yaml.YAMLError as exc: |                 except yaml.YAMLError as exc: | ||||||
|                     raise ImproperlyConfigured from exc |                     raise ImproperlyConfigured from exc | ||||||
|         except PermissionError as exc: |         except PermissionError as exc: | ||||||
|             self._log( |             self.log( | ||||||
|                 "warning", |                 "warning", | ||||||
|                 "Permission denied while reading file", |                 "Permission denied while reading file", | ||||||
|                 path=path, |                 path=path, | ||||||
| @ -144,7 +144,7 @@ class ConfigLoader: | |||||||
|             current_obj[dot_parts[-1]] = value |             current_obj[dot_parts[-1]] = value | ||||||
|             idx += 1 |             idx += 1 | ||||||
|         if idx > 0: |         if idx > 0: | ||||||
|             self._log("debug", "Loaded environment variables", count=idx) |             self.log("debug", "Loaded environment variables", count=idx) | ||||||
|             self.update(self.__config, outer) |             self.update(self.__config, outer) | ||||||
|  |  | ||||||
|     @contextmanager |     @contextmanager | ||||||
| @ -152,7 +152,9 @@ class ConfigLoader: | |||||||
|         """Context manager for unittests to patch a value""" |         """Context manager for unittests to patch a value""" | ||||||
|         original_value = self.y(path) |         original_value = self.y(path) | ||||||
|         self.y_set(path, value) |         self.y_set(path, value) | ||||||
|  |         try: | ||||||
|             yield |             yield | ||||||
|  |         finally: | ||||||
|             self.y_set(path, original_value) |             self.y_set(path, original_value) | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
| @ -178,7 +180,7 @@ class ConfigLoader: | |||||||
|             # pyright: reportGeneralTypeIssues=false |             # pyright: reportGeneralTypeIssues=false | ||||||
|             if comp not in root: |             if comp not in root: | ||||||
|                 root[comp] = {} |                 root[comp] = {} | ||||||
|             root = root.get(comp) |             root = root.get(comp, {}) | ||||||
|         root[path_parts[-1]] = value |         root[path_parts[-1]] = value | ||||||
|  |  | ||||||
|     def y_bool(self, path: str, default=False) -> bool: |     def y_bool(self, path: str, default=False) -> bool: | ||||||
|  | |||||||
| @ -36,7 +36,7 @@ error_reporting: | |||||||
|   enabled: false |   enabled: false | ||||||
|   environment: customer |   environment: customer | ||||||
|   send_pii: false |   send_pii: false | ||||||
|   sample_rate: 0.3 |   sample_rate: 0.1 | ||||||
|  |  | ||||||
| # Global email settings | # Global email settings | ||||||
| email: | email: | ||||||
| @ -80,3 +80,8 @@ default_token_length: 128 | |||||||
| impersonation: true | impersonation: true | ||||||
|  |  | ||||||
| blueprints_dir: /blueprints | blueprints_dir: /blueprints | ||||||
|  |  | ||||||
|  | web: | ||||||
|  |   # No default here as it's set dynamically | ||||||
|  |   # workers: 2 | ||||||
|  |   threads: 4 | ||||||
|  | |||||||
| @ -1,16 +1,20 @@ | |||||||
| """authentik expression policy evaluator""" | """authentik expression policy evaluator""" | ||||||
| import re | import re | ||||||
|  | from ipaddress import ip_address, ip_network | ||||||
| from textwrap import indent | from textwrap import indent | ||||||
| from typing import Any, Iterable, Optional | from typing import Any, Iterable, Optional | ||||||
|  |  | ||||||
| from django.core.exceptions import FieldError | from django.core.exceptions import FieldError | ||||||
|  | from django_otp import devices_for_user | ||||||
| from rest_framework.serializers import ValidationError | from rest_framework.serializers import ValidationError | ||||||
| from sentry_sdk.hub import Hub | from sentry_sdk.hub import Hub | ||||||
| from sentry_sdk.tracing import Span | from sentry_sdk.tracing import Span | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.core.models import User | from authentik.core.models import User | ||||||
|  | from authentik.events.models import Event | ||||||
| from authentik.lib.utils.http import get_http_session | from authentik.lib.utils.http import get_http_session | ||||||
|  | from authentik.policies.types import PolicyRequest | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
| @ -26,7 +30,8 @@ class BaseEvaluator: | |||||||
|     # Filename used for exec |     # Filename used for exec | ||||||
|     _filename: str |     _filename: str | ||||||
|  |  | ||||||
|     def __init__(self): |     def __init__(self, filename: Optional[str] = None): | ||||||
|  |         self._filename = filename if filename else "BaseEvaluator" | ||||||
|         # update website/docs/expressions/_objects.md |         # update website/docs/expressions/_objects.md | ||||||
|         # update website/docs/expressions/_functions.md |         # update website/docs/expressions/_functions.md | ||||||
|         self._globals = { |         self._globals = { | ||||||
| @ -35,11 +40,14 @@ class BaseEvaluator: | |||||||
|             "list_flatten": BaseEvaluator.expr_flatten, |             "list_flatten": BaseEvaluator.expr_flatten, | ||||||
|             "ak_is_group_member": BaseEvaluator.expr_is_group_member, |             "ak_is_group_member": BaseEvaluator.expr_is_group_member, | ||||||
|             "ak_user_by": BaseEvaluator.expr_user_by, |             "ak_user_by": BaseEvaluator.expr_user_by, | ||||||
|             "ak_logger": get_logger(), |             "ak_user_has_authenticator": BaseEvaluator.expr_func_user_has_authenticator, | ||||||
|  |             "ak_create_event": self.expr_event_create, | ||||||
|  |             "ak_logger": get_logger(self._filename), | ||||||
|             "requests": get_http_session(), |             "requests": get_http_session(), | ||||||
|  |             "ip_address": ip_address, | ||||||
|  |             "ip_network": ip_network, | ||||||
|         } |         } | ||||||
|         self._context = {} |         self._context = {} | ||||||
|         self._filename = "BaseEvalautor" |  | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def expr_flatten(value: list[Any] | Any) -> Optional[Any]: |     def expr_flatten(value: list[Any] | Any) -> Optional[Any]: | ||||||
| @ -60,6 +68,11 @@ class BaseEvaluator: | |||||||
|         """Expression Filter to run re.sub""" |         """Expression Filter to run re.sub""" | ||||||
|         return re.sub(regex, repl, value) |         return re.sub(regex, repl, value) | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def expr_is_group_member(user: User, **group_filters) -> bool: | ||||||
|  |         """Check if `user` is member of group with name `group_name`""" | ||||||
|  |         return user.ak_groups.filter(**group_filters).exists() | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def expr_user_by(**filters) -> Optional[User]: |     def expr_user_by(**filters) -> Optional[User]: | ||||||
|         """Get user by filters""" |         """Get user by filters""" | ||||||
| @ -72,15 +85,37 @@ class BaseEvaluator: | |||||||
|             return None |             return None | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def expr_is_group_member(user: User, **group_filters) -> bool: |     def expr_func_user_has_authenticator(user: User, device_type: Optional[str] = None) -> bool: | ||||||
|         """Check if `user` is member of group with name `group_name`""" |         """Check if a user has any authenticator devices, optionally matching *device_type*""" | ||||||
|         return user.ak_groups.filter(**group_filters).exists() |         user_devices = devices_for_user(user) | ||||||
|  |         if device_type: | ||||||
|  |             for device in user_devices: | ||||||
|  |                 device_class = device.__class__.__name__.lower().replace("device", "") | ||||||
|  |                 if device_class == device_type: | ||||||
|  |                     return True | ||||||
|  |             return False | ||||||
|  |         return len(list(user_devices)) > 0 | ||||||
|  |  | ||||||
|  |     def expr_event_create(self, action: str, **kwargs): | ||||||
|  |         """Create event with supplied data and try to extract as much relevant data | ||||||
|  |         from the context""" | ||||||
|  |         kwargs["context"] = self._context | ||||||
|  |         event = Event.new( | ||||||
|  |             action, | ||||||
|  |             app=self._filename, | ||||||
|  |             **kwargs, | ||||||
|  |         ) | ||||||
|  |         if "request" in self._context and isinstance(PolicyRequest, self._context["request"]): | ||||||
|  |             policy_request: PolicyRequest = self._context["request"] | ||||||
|  |             if policy_request.http_request: | ||||||
|  |                 event.from_http(policy_request) | ||||||
|  |                 return | ||||||
|  |         event.save() | ||||||
|  |  | ||||||
|     def wrap_expression(self, expression: str, params: Iterable[str]) -> str: |     def wrap_expression(self, expression: str, params: Iterable[str]) -> str: | ||||||
|         """Wrap expression in a function, call it, and save the result as `result`""" |         """Wrap expression in a function, call it, and save the result as `result`""" | ||||||
|         handler_signature = ",".join(params) |         handler_signature = ",".join(params) | ||||||
|         full_expression = "" |         full_expression = "" | ||||||
|         full_expression += "from ipaddress import ip_address, ip_network\n" |  | ||||||
|         full_expression += f"def handler({handler_signature}):\n" |         full_expression += f"def handler({handler_signature}):\n" | ||||||
|         full_expression += indent(expression, "    ") |         full_expression += indent(expression, "    ") | ||||||
|         full_expression += f"\nresult = handler({handler_signature})" |         full_expression += f"\nresult = handler({handler_signature})" | ||||||
|  | |||||||
| @ -95,7 +95,7 @@ def traces_sampler(sampling_context: dict) -> float: | |||||||
|     # Ignore all healthcheck routes |     # Ignore all healthcheck routes | ||||||
|     if path.startswith("/-/health") or path.startswith("/-/metrics"): |     if path.startswith("/-/health") or path.startswith("/-/metrics"): | ||||||
|         return 0 |         return 0 | ||||||
|     return float(CONFIG.y("error_reporting.sample_rate", 0.5)) |     return float(CONFIG.y("error_reporting.sample_rate", 0.1)) | ||||||
|  |  | ||||||
|  |  | ||||||
| def before_send(event: dict, hint: dict) -> Optional[dict]: | def before_send(event: dict, hint: dict) -> Optional[dict]: | ||||||
|  | |||||||
| @ -20,8 +20,7 @@ def model_tester_factory(test_model: type[Stage]) -> Callable: | |||||||
|         try: |         try: | ||||||
|             model_class = None |             model_class = None | ||||||
|             if test_model._meta.abstract: |             if test_model._meta.abstract: | ||||||
|                 model_class = test_model.__bases__[0]() |                 return | ||||||
|             else: |  | ||||||
|             model_class = test_model() |             model_class = test_model() | ||||||
|             self.assertTrue(issubclass(model_class.serializer, BaseSerializer)) |             self.assertTrue(issubclass(model_class.serializer, BaseSerializer)) | ||||||
|         except NotImplementedError: |         except NotImplementedError: | ||||||
|  | |||||||
| @ -12,5 +12,4 @@ class TestReflectionUtils(TestCase): | |||||||
|  |  | ||||||
|     def test_path_to_class(self): |     def test_path_to_class(self): | ||||||
|         """Test path_to_class""" |         """Test path_to_class""" | ||||||
|         self.assertIsNone(path_to_class(None)) |  | ||||||
|         self.assertEqual(path_to_class("datetime.datetime"), datetime) |         self.assertEqual(path_to_class("datetime.datetime"), datetime) | ||||||
|  | |||||||
| @ -29,10 +29,8 @@ def class_to_path(cls: type) -> str: | |||||||
|     return f"{cls.__module__}.{cls.__name__}" |     return f"{cls.__module__}.{cls.__name__}" | ||||||
|  |  | ||||||
|  |  | ||||||
| def path_to_class(path: str | None) -> type | None: | def path_to_class(path: str = "") -> type: | ||||||
|     """Import module and return class""" |     """Import module and return class""" | ||||||
|     if not path: |  | ||||||
|         return None |  | ||||||
|     parts = path.split(".") |     parts = path.split(".") | ||||||
|     package = ".".join(parts[:-1]) |     package = ".".join(parts[:-1]) | ||||||
|     _class = getattr(import_module(package), parts[-1]) |     _class = getattr(import_module(package), parts[-1]) | ||||||
|  | |||||||
| @ -8,7 +8,7 @@ def bad_request_message( | |||||||
|     request: HttpRequest, |     request: HttpRequest, | ||||||
|     message: str, |     message: str, | ||||||
|     title="Bad Request", |     title="Bad Request", | ||||||
|     template="error/generic.html", |     template="if/error.html", | ||||||
| ) -> TemplateResponse: | ) -> TemplateResponse: | ||||||
|     """Return generic error page with message, with status code set to 400""" |     """Return generic error page with message, with status code set to 400""" | ||||||
|     return TemplateResponse( |     return TemplateResponse( | ||||||
|  | |||||||
| @ -2,7 +2,7 @@ | |||||||
| from prometheus_client import Gauge | from prometheus_client import Gauge | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.blueprints.manager import ManagedAppConfig | from authentik.blueprints.apps import ManagedAppConfig | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  | |||||||
| @ -5,7 +5,7 @@ from enum import IntEnum | |||||||
| from typing import Any, Optional | from typing import Any, Optional | ||||||
|  |  | ||||||
| from channels.exceptions import DenyConnection | from channels.exceptions import DenyConnection | ||||||
| from dacite import from_dict | from dacite.core import from_dict | ||||||
| from dacite.data import Data | from dacite.data import Data | ||||||
| from guardian.shortcuts import get_objects_for_user | from guardian.shortcuts import get_objects_for_user | ||||||
| from structlog.stdlib import BoundLogger, get_logger | from structlog.stdlib import BoundLogger, get_logger | ||||||
|  | |||||||
| @ -2,7 +2,7 @@ | |||||||
| from dataclasses import asdict, dataclass, field | from dataclasses import asdict, dataclass, field | ||||||
| from typing import TYPE_CHECKING | from typing import TYPE_CHECKING | ||||||
|  |  | ||||||
| from dacite import from_dict | from dacite.core import from_dict | ||||||
| from kubernetes.client import ApiextensionsV1Api, CustomObjectsApi | from kubernetes.client import ApiextensionsV1Api, CustomObjectsApi | ||||||
|  |  | ||||||
| from authentik.outposts.controllers.base import FIELD_MANAGER | from authentik.outposts.controllers.base import FIELD_MANAGER | ||||||
|  | |||||||
| @ -4,7 +4,7 @@ from datetime import datetime | |||||||
| from typing import Iterable, Optional | from typing import Iterable, Optional | ||||||
| from uuid import uuid4 | from uuid import uuid4 | ||||||
|  |  | ||||||
| from dacite import from_dict | from dacite.core import from_dict | ||||||
| from django.contrib.auth.models import Permission | from django.contrib.auth.models import Permission | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
| from django.db import IntegrityError, models, transaction | from django.db import IntegrityError, models, transaction | ||||||
| @ -74,7 +74,7 @@ class OutpostConfig: | |||||||
|     kubernetes_ingress_secret_name: str = field(default="authentik-outpost-tls") |     kubernetes_ingress_secret_name: str = field(default="authentik-outpost-tls") | ||||||
|     kubernetes_service_type: str = field(default="ClusterIP") |     kubernetes_service_type: str = field(default="ClusterIP") | ||||||
|     kubernetes_disabled_components: list[str] = field(default_factory=list) |     kubernetes_disabled_components: list[str] = field(default_factory=list) | ||||||
|     kubernetes_image_pull_secrets: Optional[list[str]] = field(default_factory=list) |     kubernetes_image_pull_secrets: list[str] = field(default_factory=list) | ||||||
|  |  | ||||||
|  |  | ||||||
| class OutpostModel(Model): | class OutpostModel(Model): | ||||||
|  | |||||||
| @ -1,6 +1,5 @@ | |||||||
| """outpost tasks""" | """outpost tasks""" | ||||||
| from os import R_OK, access | from os import R_OK, access | ||||||
| from os.path import expanduser |  | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from socket import gethostname | from socket import gethostname | ||||||
| from typing import Any, Optional | from typing import Any, Optional | ||||||
| @ -75,10 +74,14 @@ def outpost_service_connection_state(connection_pk: Any): | |||||||
|     ) |     ) | ||||||
|     if not connection: |     if not connection: | ||||||
|         return |         return | ||||||
|  |     cls = None | ||||||
|     if isinstance(connection, DockerServiceConnection): |     if isinstance(connection, DockerServiceConnection): | ||||||
|         cls = DockerClient |         cls = DockerClient | ||||||
|     if isinstance(connection, KubernetesServiceConnection): |     if isinstance(connection, KubernetesServiceConnection): | ||||||
|         cls = KubernetesClient |         cls = KubernetesClient | ||||||
|  |     if not cls: | ||||||
|  |         LOGGER.warning("No class found for service connection", connection=connection) | ||||||
|  |         return | ||||||
|     try: |     try: | ||||||
|         with cls(connection) as client: |         with cls(connection) as client: | ||||||
|             state = client.fetch_state() |             state = client.fetch_state() | ||||||
| @ -240,25 +243,25 @@ def _outpost_single_update(outpost: Outpost, layer=None): | |||||||
| def outpost_local_connection(): | def outpost_local_connection(): | ||||||
|     """Checks the local environment and create Service connections.""" |     """Checks the local environment and create Service connections.""" | ||||||
|     if not CONFIG.y_bool("outposts.discover"): |     if not CONFIG.y_bool("outposts.discover"): | ||||||
|         LOGGER.debug("Outpost integration discovery is disabled") |         LOGGER.info("Outpost integration discovery is disabled") | ||||||
|         return |         return | ||||||
|     # Explicitly check against token filename, as that's |     # Explicitly check against token filename, as that's | ||||||
|     # only present when the integration is enabled |     # only present when the integration is enabled | ||||||
|     if Path(SERVICE_TOKEN_FILENAME).exists(): |     if Path(SERVICE_TOKEN_FILENAME).exists(): | ||||||
|         LOGGER.debug("Detected in-cluster Kubernetes Config") |         LOGGER.info("Detected in-cluster Kubernetes Config") | ||||||
|         if not KubernetesServiceConnection.objects.filter(local=True).exists(): |         if not KubernetesServiceConnection.objects.filter(local=True).exists(): | ||||||
|             LOGGER.debug("Created Service Connection for in-cluster") |             LOGGER.debug("Created Service Connection for in-cluster") | ||||||
|             KubernetesServiceConnection.objects.create( |             KubernetesServiceConnection.objects.create( | ||||||
|                 name="Local Kubernetes Cluster", local=True, kubeconfig={} |                 name="Local Kubernetes Cluster", local=True, kubeconfig={} | ||||||
|             ) |             ) | ||||||
|     # For development, check for the existence of a kubeconfig file |     # For development, check for the existence of a kubeconfig file | ||||||
|     kubeconfig_path = expanduser(KUBE_CONFIG_DEFAULT_LOCATION) |     kubeconfig_path = Path(KUBE_CONFIG_DEFAULT_LOCATION).expanduser() | ||||||
|     if Path(kubeconfig_path).exists(): |     if kubeconfig_path.exists(): | ||||||
|         LOGGER.debug("Detected kubeconfig") |         LOGGER.info("Detected kubeconfig") | ||||||
|         kubeconfig_local_name = f"k8s-{gethostname()}" |         kubeconfig_local_name = f"k8s-{gethostname()}" | ||||||
|         if not KubernetesServiceConnection.objects.filter(name=kubeconfig_local_name).exists(): |         if not KubernetesServiceConnection.objects.filter(name=kubeconfig_local_name).exists(): | ||||||
|             LOGGER.debug("Creating kubeconfig Service Connection") |             LOGGER.debug("Creating kubeconfig Service Connection") | ||||||
|             with open(kubeconfig_path, "r", encoding="utf8") as _kubeconfig: |             with kubeconfig_path.open("r", encoding="utf8") as _kubeconfig: | ||||||
|                 KubernetesServiceConnection.objects.create( |                 KubernetesServiceConnection.objects.create( | ||||||
|                     name=kubeconfig_local_name, |                     name=kubeconfig_local_name, | ||||||
|                     kubeconfig=yaml.safe_load(_kubeconfig), |                     kubeconfig=yaml.safe_load(_kubeconfig), | ||||||
| @ -266,7 +269,7 @@ def outpost_local_connection(): | |||||||
|     unix_socket_path = urlparse(DEFAULT_UNIX_SOCKET).path |     unix_socket_path = urlparse(DEFAULT_UNIX_SOCKET).path | ||||||
|     socket = Path(unix_socket_path) |     socket = Path(unix_socket_path) | ||||||
|     if socket.exists() and access(socket, R_OK): |     if socket.exists() and access(socket, R_OK): | ||||||
|         LOGGER.debug("Detected local docker socket") |         LOGGER.info("Detected local docker socket") | ||||||
|         if len(DockerServiceConnection.objects.filter(local=True)) == 0: |         if len(DockerServiceConnection.objects.filter(local=True)) == 0: | ||||||
|             LOGGER.debug("Created Service Connection for docker") |             LOGGER.debug("Created Service Connection for docker") | ||||||
|             DockerServiceConnection.objects.create( |             DockerServiceConnection.objects.create( | ||||||
|  | |||||||
| @ -6,7 +6,7 @@ from channels.testing import WebsocketCommunicator | |||||||
| from django.test import TransactionTestCase | from django.test import TransactionTestCase | ||||||
|  |  | ||||||
| from authentik import __version__ | from authentik import __version__ | ||||||
| from authentik.flows.models import Flow, FlowDesignation | from authentik.core.tests.utils import create_test_flow | ||||||
| from authentik.outposts.channels import WebsocketMessage, WebsocketMessageInstruction | from authentik.outposts.channels import WebsocketMessage, WebsocketMessageInstruction | ||||||
| from authentik.outposts.models import Outpost, OutpostType | from authentik.outposts.models import Outpost, OutpostType | ||||||
| from authentik.providers.proxy.models import ProxyProvider | from authentik.providers.proxy.models import ProxyProvider | ||||||
| @ -21,9 +21,7 @@ class TestOutpostWS(TransactionTestCase): | |||||||
|             name="test", |             name="test", | ||||||
|             internal_host="http://localhost", |             internal_host="http://localhost", | ||||||
|             external_host="http://localhost", |             external_host="http://localhost", | ||||||
|             authorization_flow=Flow.objects.create( |             authorization_flow=create_test_flow(), | ||||||
|                 name="foo", slug="foo", designation=FlowDesignation.AUTHORIZATION |  | ||||||
|             ), |  | ||||||
|         ) |         ) | ||||||
|         self.outpost: Outpost = Outpost.objects.create( |         self.outpost: Outpost = Outpost.objects.create( | ||||||
|             name="test", |             name="test", | ||||||
|  | |||||||
| @ -1,7 +1,7 @@ | |||||||
| """authentik policies app config""" | """authentik policies app config""" | ||||||
| from prometheus_client import Gauge, Histogram | from prometheus_client import Gauge, Histogram | ||||||
|  |  | ||||||
| from authentik.blueprints.manager import ManagedAppConfig | from authentik.blueprints.apps import ManagedAppConfig | ||||||
|  |  | ||||||
| GAUGE_POLICIES_CACHED = Gauge( | GAUGE_POLICIES_CACHED = Gauge( | ||||||
|     "authentik_policies_cached", |     "authentik_policies_cached", | ||||||
|  | |||||||
| @ -1,12 +1,10 @@ | |||||||
| """authentik expression policy evaluator""" | """authentik expression policy evaluator""" | ||||||
| from ipaddress import ip_address, ip_network | from ipaddress import ip_address | ||||||
| from typing import TYPE_CHECKING, Optional | from typing import TYPE_CHECKING, Optional | ||||||
|  |  | ||||||
| from django.http import HttpRequest | from django.http import HttpRequest | ||||||
| from django_otp import devices_for_user |  | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.core.models import User |  | ||||||
| from authentik.flows.planner import PLAN_CONTEXT_SSO | from authentik.flows.planner import PLAN_CONTEXT_SSO | ||||||
| from authentik.lib.expression.evaluator import BaseEvaluator | from authentik.lib.expression.evaluator import BaseEvaluator | ||||||
| from authentik.lib.utils.http import get_client_ip | from authentik.lib.utils.http import get_client_ip | ||||||
| @ -27,16 +25,14 @@ class PolicyEvaluator(BaseEvaluator): | |||||||
|  |  | ||||||
|     policy: Optional["ExpressionPolicy"] = None |     policy: Optional["ExpressionPolicy"] = None | ||||||
|  |  | ||||||
|     def __init__(self, policy_name: str): |     def __init__(self, policy_name: Optional[str] = None): | ||||||
|         super().__init__() |         super().__init__(policy_name or "PolicyEvaluator") | ||||||
|         self._messages = [] |         self._messages = [] | ||||||
|         self._context["ak_logger"] = get_logger(policy_name) |         # update website/docs/expressions/_objects.md | ||||||
|  |         # update website/docs/expressions/_functions.md | ||||||
|         self._context["ak_message"] = self.expr_func_message |         self._context["ak_message"] = self.expr_func_message | ||||||
|         self._context["ak_user_has_authenticator"] = self.expr_func_user_has_authenticator |         self._context["ak_user_has_authenticator"] = self.expr_func_user_has_authenticator | ||||||
|         self._context["ak_call_policy"] = self.expr_func_call_policy |         self._context["ak_call_policy"] = self.expr_func_call_policy | ||||||
|         self._context["ip_address"] = ip_address |  | ||||||
|         self._context["ip_network"] = ip_network |  | ||||||
|         self._filename = policy_name or "PolicyEvaluator" |  | ||||||
|  |  | ||||||
|     def expr_func_message(self, message: str): |     def expr_func_message(self, message: str): | ||||||
|         """Wrapper to append to messages list, which is returned with PolicyResult""" |         """Wrapper to append to messages list, which is returned with PolicyResult""" | ||||||
| @ -52,19 +48,6 @@ class PolicyEvaluator(BaseEvaluator): | |||||||
|         proc = PolicyProcess(PolicyBinding(policy=policy), request=req, connection=None) |         proc = PolicyProcess(PolicyBinding(policy=policy), request=req, connection=None) | ||||||
|         return proc.profiling_wrapper() |         return proc.profiling_wrapper() | ||||||
|  |  | ||||||
|     def expr_func_user_has_authenticator( |  | ||||||
|         self, user: User, device_type: Optional[str] = None |  | ||||||
|     ) -> bool: |  | ||||||
|         """Check if a user has any authenticator devices, optionally matching *device_type*""" |  | ||||||
|         user_devices = devices_for_user(user) |  | ||||||
|         if device_type: |  | ||||||
|             for device in user_devices: |  | ||||||
|                 device_class = device.__class__.__name__.lower().replace("device", "") |  | ||||||
|                 if device_class == device_type: |  | ||||||
|                     return True |  | ||||||
|             return False |  | ||||||
|         return len(list(user_devices)) > 0 |  | ||||||
|  |  | ||||||
|     def set_policy_request(self, request: PolicyRequest): |     def set_policy_request(self, request: PolicyRequest): | ||||||
|         """Update context based on policy request (if http request is given, update that too)""" |         """Update context based on policy request (if http request is given, update that too)""" | ||||||
|         # update website/docs/expressions/_objects.md |         # update website/docs/expressions/_objects.md | ||||||
|  | |||||||
| @ -1,5 +1,5 @@ | |||||||
| """Authentik reputation_policy app config""" | """Authentik reputation_policy app config""" | ||||||
| from authentik.blueprints.manager import ManagedAppConfig | from authentik.blueprints.apps import ManagedAppConfig | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuthentikPolicyReputationConfig(ManagedAppConfig): | class AuthentikPolicyReputationConfig(ManagedAppConfig): | ||||||
|  | |||||||
| @ -70,7 +70,6 @@ class PolicyAccessView(AccessMixin, View): | |||||||
|         # Check if user is unauthenticated, so we pass the application |         # Check if user is unauthenticated, so we pass the application | ||||||
|         # for the identification stage |         # for the identification stage | ||||||
|         if not request.user.is_authenticated: |         if not request.user.is_authenticated: | ||||||
|             LOGGER.warning("user not authenticated") |  | ||||||
|             return self.handle_no_permission() |             return self.handle_no_permission() | ||||||
|         # Check permissions |         # Check permissions | ||||||
|         result = self.user_has_access() |         result = self.user_has_access() | ||||||
|  | |||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	