Compare commits
	
		
			247 Commits
		
	
	
		
			version/20
			...
			version/20
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 2cfba36cb7 | |||
| 73bfe19f0e | |||
| 64e211c3a9 | |||
| 3c354db858 | |||
| 99df72d944 | |||
| 374b10b6e5 | |||
| 4c606fb0ba | |||
| f8502edd2b | |||
| aa48f6dd9d | |||
| 49b6aabb02 | |||
| f9d9b2716d | |||
| 7932b390dc | |||
| 241e36b2a6 | |||
| 81e820b6e6 | |||
| b16a3d5697 | |||
| 1583d53e54 | |||
| 909a7772dc | |||
| cd42b013ca | |||
| 22a2e65d30 | |||
| f3fe2c1b4a | |||
| 133accd033 | |||
| 47daaf969a | |||
| 9fb5092fdc | |||
| 05ccff4651 | |||
| 2d965afc5f | |||
| 522abfd2fd | |||
| 2d1bcf1aa7 | |||
| 24dbfcea3f | |||
| f008c03524 | |||
| 9ab066a863 | |||
| daa0417c38 | |||
| 067166d420 | |||
| 2bd10dbdee | |||
| 09795fa6fb | |||
| be64296494 | |||
| 02e2c117ac | |||
| 75434cc23f | |||
| 778e316690 | |||
| 3e0778fe31 | |||
| d2390eef89 | |||
| 5c542d5dc2 | |||
| 6b5b72ab4c | |||
| 93ae3c19b0 | |||
| a2ccdaca05 | |||
| 4a0e051c0b | |||
| 7ad2992fe7 | |||
| 223000804e | |||
| 7b0cb38c4b | |||
| 94d8465e92 | |||
| 4a91a7d2e2 | |||
| 369440652c | |||
| 493cdd5c0f | |||
| 9f5c019daa | |||
| ab28370f20 | |||
| 84c08dca41 | |||
| 6b8b596c92 | |||
| dc622a836f | |||
| 7b4faa0170 | |||
| 65f5f21de2 | |||
| d7e2b2e8a0 | |||
| c12c3877f6 | |||
| 9cc1b1213f | |||
| d0f5e77f77 | |||
| 68950ee8b7 | |||
| 53f224300b | |||
| 73019c0732 | |||
| 359da6db81 | |||
| 34928572db | |||
| 7f8afad528 | |||
| c1ad1e5c8b | |||
| b35b225453 | |||
| 0ff2ac7dc2 | |||
| 8b4a7666f0 | |||
| e477615b0f | |||
| e99b90912f | |||
| ae9dbf3014 | |||
| 7a50d5a4f8 | |||
| 4c4d87d3bd | |||
| a407334d3b | |||
| 5026cebf02 | |||
| 9770ba07c2 | |||
| 2e2ab55f9e | |||
| 28835fbca7 | |||
| aabb8af486 | |||
| 371f1878a8 | |||
| 662bd8af96 | |||
| 1b57cc3bf0 | |||
| 7517d612d0 | |||
| a29fabac42 | |||
| a0ff4ac038 | |||
| 4f6e3516b9 | |||
| 220f123b29 | |||
| 3e70b6443a | |||
| 792531a968 | |||
| de17fdb3ff | |||
| 19fdcba308 | |||
| 62f93c83d4 | |||
| 03a3f1bd6f | |||
| c2cb804ace | |||
| 11334cf638 | |||
| 60266b3345 | |||
| 2a4679e390 | |||
| 34e71351a6 | |||
| f13af32102 | |||
| 832b5f1058 | |||
| df789265f9 | |||
| b45fbbf20f | |||
| b4c500ce15 | |||
| 114226dc22 | |||
| 897446e5ac | |||
| eed958b132 | |||
| d0a69557d4 | |||
| 712d2b2aee | |||
| 2293cecd24 | |||
| 4e8f0bfb66 | |||
| 578139811c | |||
| f806dea0f6 | |||
| 02ec42738c | |||
| 86468163f0 | |||
| 159e533e4a | |||
| 24ee2c6c05 | |||
| dd383d763f | |||
| df9d8e9d25 | |||
| 12c318f0b1 | |||
| fc6ed8e7f9 | |||
| f68ed3562e | |||
| 621245aece | |||
| f2f22719f8 | |||
| 242423cf3c | |||
| 8e7a456f74 | |||
| 3987f8e371 | |||
| 6fb531c482 | |||
| 159798a7d8 | |||
| 89f8962a23 | |||
| 3db0a5b3d1 | |||
| feb20c371c | |||
| 73081e4947 | |||
| d7c2d0c7f9 | |||
| d5c57ab251 | |||
| 5511458757 | |||
| d9775f2822 | |||
| 398eb23d31 | |||
| 14a7c9f967 | |||
| 3e11f0c0b3 | |||
| c055245a45 | |||
| b1ba8f60a1 | |||
| f8f37dc52c | |||
| 19c36d20b5 | |||
| abca435337 | |||
| b5ee81f4de | |||
| 13bb7682b2 | |||
| 99f0f556e1 | |||
| 54ba3e9616 | |||
| 58e3ca28be | |||
| c6bb41890e | |||
| a4556b3692 | |||
| d0b52812d5 | |||
| f7e689ff03 | |||
| 2c419bee09 | |||
| 3c4c1e9d65 | |||
| ef3be13fdb | |||
| d3466ceef8 | |||
| 5886688fae | |||
| c3c8cbf7ef | |||
| 251ab71c3a | |||
| a0c546023f | |||
| 83eaac375d | |||
| 2868331976 | |||
| 5c0986c538 | |||
| ac21dcc44f | |||
| e9e7e25b27 | |||
| 7dc7d51cfa | |||
| 3eb3a9eab9 | |||
| 30db3b543b | |||
| 11aee98e0b | |||
| a099b21671 | |||
| b9294fd9ad | |||
| 13a302cdad | |||
| e994a01e80 | |||
| d49431cfc7 | |||
| ce2ce38b59 | |||
| 2af4f28239 | |||
| 1419910b29 | |||
| 649835cc61 | |||
| 917c4ae835 | |||
| ca2fce8be2 | |||
| e70fcd1a6f | |||
| fd461d9a00 | |||
| 8ea45b4dfe | |||
| 514d8f4569 | |||
| 2a43326da5 | |||
| 06f4c0608f | |||
| dc8b8e8f13 | |||
| 9f172e7ad0 | |||
| 832d76ec2a | |||
| 2545d85e8e | |||
| 9264acd00e | |||
| 51dd9473ce | |||
| 15c34c6f1f | |||
| 0ab8f4eed7 | |||
| 070714abe4 | |||
| 810c04bacf | |||
| b624bf1cb7 | |||
| f56a4b00ce | |||
| 6ec0411930 | |||
| fb59969bce | |||
| eec9c46533 | |||
| a3afd47850 | |||
| 8ffae4505f | |||
| 0cc83c23c4 | |||
| 514c48a986 | |||
| fdb8fb4b4c | |||
| d8a68407f9 | |||
| 417156d659 | |||
| 9d58407e25 | |||
| f4441c9fcf | |||
| 0e9762072a | |||
| 0cfffa28ad | |||
| 1ad4c8fc29 | |||
| dd2facdc57 | |||
| 549dfa4c3a | |||
| 7f8ae24e8d | |||
| d05aeb91f2 | |||
| ab25216c1f | |||
| fb5eb7b868 | |||
| bda218f7fc | |||
| 198c940a80 | |||
| c900411d5a | |||
| 1adc6948b4 | |||
| e87236b285 | |||
| 846b63a17b | |||
| 48fb3c98ff | |||
| b488a6fec9 | |||
| f014bd5f30 | |||
| 03dd079e8c | |||
| 1281e842d1 | |||
| f7601d9571 | |||
| 4d9c9160e7 | |||
| ad1f913e54 | |||
| 3da0233c40 | |||
| ff788edd9b | |||
| aea0958f3f | |||
| 1f9e9f9ca0 | |||
| 98ffec87c0 | |||
| b52d5dccac | |||
| e96a4fa181 | |||
| c53b0830c4 | 
| @ -1,5 +1,5 @@ | ||||
| [bumpversion] | ||||
| current_version = 2022.8.1 | ||||
| current_version = 2022.9.0 | ||||
| tag = True | ||||
| commit = True | ||||
| parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+) | ||||
| @ -17,4 +17,4 @@ tag_name = version/{new_version} | ||||
|  | ||||
| [bumpversion:file:internal/constants/constants.go] | ||||
|  | ||||
| [bumpversion:file:web/src/constants.ts] | ||||
| [bumpversion:file:web/src/common/constants.ts] | ||||
|  | ||||
| @ -11,38 +11,7 @@ runs: | ||||
|   steps: | ||||
|     - name: Generate config | ||||
|       id: ev | ||||
|       shell: python | ||||
|       run: | | ||||
|         """Helper script to get the actual branch name, docker safe""" | ||||
|         import os | ||||
|         from time import time | ||||
|  | ||||
|         env_pr_branch = "GITHUB_HEAD_REF" | ||||
|         default_branch = "GITHUB_REF" | ||||
|         sha = "GITHUB_SHA" | ||||
|  | ||||
|         branch_name = os.environ[default_branch] | ||||
|         if os.environ.get(env_pr_branch, "") != "": | ||||
|             branch_name = os.environ[env_pr_branch] | ||||
|  | ||||
|         should_build = str(os.environ.get("DOCKER_USERNAME", "") != "").lower() | ||||
|  | ||||
|         print("##[set-output name=branchName]%s" % branch_name) | ||||
|         print( | ||||
|             "##[set-output name=branchNameContainer]%s" | ||||
|             % branch_name.replace("refs/heads/", "").replace("/", "-") | ||||
|         ) | ||||
|         print("##[set-output name=timestamp]%s" % int(time())) | ||||
|         print("##[set-output name=sha]%s" % os.environ[sha]) | ||||
|         print("##[set-output name=shouldBuild]%s" % should_build) | ||||
|  | ||||
|         import configparser | ||||
|         parser = configparser.ConfigParser() | ||||
|         parser.read(".bumpversion.cfg") | ||||
|         version = parser.get("bumpversion", "current_version") | ||||
|         version_family = ".".join(version.split(".")[:-1]) | ||||
|         print("##[set-output name=version]%s" % version) | ||||
|         print("##[set-output name=versionFamily]%s" % version_family) | ||||
|       uses: ./.github/actions/docker-push-variables | ||||
|     - name: Find Comment | ||||
|       uses: peter-evans/find-comment@v2 | ||||
|       id: fc | ||||
| @ -83,8 +52,6 @@ runs: | ||||
|             image: | ||||
|                 repository: ghcr.io/goauthentik/dev-server | ||||
|                 tag: ${{ inputs.tag }} | ||||
|                 # pullPolicy: Always to ensure you always get the latest version | ||||
|                 pullPolicy: Always | ||||
|             ``` | ||||
|  | ||||
|             Afterwards, run the upgrade commands from the latest release notes. | ||||
|  | ||||
							
								
								
									
										2
									
								
								.github/actions/setup/action.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/actions/setup/action.yml
									
									
									
									
										vendored
									
									
								
							| @ -27,7 +27,7 @@ runs: | ||||
|         docker-compose -f .github/actions/setup/docker-compose.yml up -d | ||||
|         poetry env use python3.10 | ||||
|         poetry install | ||||
|         npm install -g pyright@1.1.136 | ||||
|         cd web && npm ci | ||||
|     - name: Generate config | ||||
|       shell: poetry run python {0} | ||||
|       run: | | ||||
|  | ||||
							
								
								
									
										1
									
								
								.github/codespell-words.txt
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/codespell-words.txt
									
									
									
									
										vendored
									
									
								
							| @ -1,3 +1,4 @@ | ||||
| keypair | ||||
| keypairs | ||||
| hass | ||||
| warmup | ||||
|  | ||||
							
								
								
									
										14
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										14
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							| @ -96,6 +96,8 @@ jobs: | ||||
|           testspace [unittest]unittest.xml --link=codecov | ||||
|       - if: ${{ always() }} | ||||
|         uses: codecov/codecov-action@v3 | ||||
|         with: | ||||
|           flags: unit | ||||
|   test-integration: | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
| @ -117,6 +119,8 @@ jobs: | ||||
|           testspace [integration]unittest.xml --link=codecov | ||||
|       - if: ${{ always() }} | ||||
|         uses: codecov/codecov-action@v3 | ||||
|         with: | ||||
|           flags: integration | ||||
|   test-e2e-provider: | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
| @ -139,7 +143,7 @@ jobs: | ||||
|         working-directory: web | ||||
|         run: | | ||||
|           npm ci | ||||
|           make -C .. gen-client-web | ||||
|           make -C .. gen-client-ts | ||||
|           npm run build | ||||
|       - name: run e2e | ||||
|         run: | | ||||
| @ -151,6 +155,8 @@ jobs: | ||||
|           testspace [e2e-provider]unittest.xml --link=codecov | ||||
|       - if: ${{ always() }} | ||||
|         uses: codecov/codecov-action@v3 | ||||
|         with: | ||||
|           flags: e2e | ||||
|   test-e2e-rest: | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
| @ -173,7 +179,7 @@ jobs: | ||||
|         working-directory: web/ | ||||
|         run: | | ||||
|           npm ci | ||||
|           make -C .. gen-client-web | ||||
|           make -C .. gen-client-ts | ||||
|           npm run build | ||||
|       - name: run e2e | ||||
|         run: | | ||||
| @ -185,6 +191,8 @@ jobs: | ||||
|           testspace [e2e-rest]unittest.xml --link=codecov | ||||
|       - if: ${{ always() }} | ||||
|         uses: codecov/codecov-action@v3 | ||||
|         with: | ||||
|           flags: e2e | ||||
|   ci-core-mark: | ||||
|     needs: | ||||
|       - lint | ||||
| @ -240,4 +248,4 @@ jobs: | ||||
|         continue-on-error: true | ||||
|         uses: ./.github/actions/comment-pr-instructions | ||||
|         with: | ||||
|           tag: gh-${{ steps.ev.outputs.branchNameContainer }} | ||||
|           tag: gh-${{ steps.ev.outputs.branchNameContainer }}-${{ steps.ev.outputs.timestamp }}-${{ steps.ev.outputs.sha }} | ||||
|  | ||||
							
								
								
									
										8
									
								
								.github/workflows/ci-web.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/ci-web.yml
									
									
									
									
										vendored
									
									
								
							| @ -23,7 +23,7 @@ jobs: | ||||
|       - working-directory: web/ | ||||
|         run: npm ci | ||||
|       - name: Generate API | ||||
|         run: make gen-client-web | ||||
|         run: make gen-client-ts | ||||
|       - name: Eslint | ||||
|         working-directory: web/ | ||||
|         run: npm run lint | ||||
| @ -39,7 +39,7 @@ jobs: | ||||
|       - working-directory: web/ | ||||
|         run: npm ci | ||||
|       - name: Generate API | ||||
|         run: make gen-client-web | ||||
|         run: make gen-client-ts | ||||
|       - name: prettier | ||||
|         working-directory: web/ | ||||
|         run: npm run prettier-check | ||||
| @ -60,7 +60,7 @@ jobs: | ||||
|           cd node_modules/@goauthentik | ||||
|           ln -s ../../src/ web | ||||
|       - name: Generate API | ||||
|         run: make gen-client-web | ||||
|         run: make gen-client-ts | ||||
|       - name: lit-analyse | ||||
|         working-directory: web/ | ||||
|         run: npm run lit-analyse | ||||
| @ -86,7 +86,7 @@ jobs: | ||||
|       - working-directory: web/ | ||||
|         run: npm ci | ||||
|       - name: Generate API | ||||
|         run: make gen-client-web | ||||
|         run: make gen-client-ts | ||||
|       - name: build | ||||
|         working-directory: web/ | ||||
|         run: npm run build | ||||
|  | ||||
							
								
								
									
										3
									
								
								.github/workflows/release-publish.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.github/workflows/release-publish.yml
									
									
									
									
										vendored
									
									
								
							| @ -138,7 +138,7 @@ jobs: | ||||
|           docker-compose pull -q | ||||
|           docker-compose up --no-start | ||||
|           docker-compose start postgresql redis | ||||
|           docker-compose run -u root server test | ||||
|           docker-compose run -u root server test-all | ||||
|   sentry-release: | ||||
|     needs: | ||||
|       - build-server | ||||
| @ -157,6 +157,7 @@ jobs: | ||||
|           docker cp ${container}:web/ . | ||||
|       - name: Create a Sentry.io release | ||||
|         uses: getsentry/action-release@v1 | ||||
|         continue-on-error: true | ||||
|         if: ${{ github.event_name == 'release' }} | ||||
|         env: | ||||
|           SENTRY_AUTH_TOKEN: ${{ secrets.SENTRY_AUTH_TOKEN }} | ||||
|  | ||||
							
								
								
									
										3
									
								
								.github/workflows/web-api-publish.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.github/workflows/web-api-publish.yml
									
									
									
									
										vendored
									
									
								
							| @ -10,13 +10,12 @@ jobs: | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
|       - uses: actions/checkout@v3 | ||||
|       # Setup .npmrc file to publish to npm | ||||
|       - uses: actions/setup-node@v3.4.1 | ||||
|         with: | ||||
|           node-version: '16' | ||||
|           registry-url: 'https://registry.npmjs.org' | ||||
|       - name: Generate API Client | ||||
|         run: make gen-client-web | ||||
|         run: make gen-client-ts | ||||
|       - name: Publish package | ||||
|         working-directory: gen-ts-api/ | ||||
|         run: | | ||||
|  | ||||
| @ -19,7 +19,7 @@ WORKDIR /work/web | ||||
| RUN npm ci && npm run build | ||||
|  | ||||
| # Stage 3: Poetry to requirements.txt export | ||||
| FROM docker.io/python:3.10.6-slim-bullseye AS poetry-locker | ||||
| FROM docker.io/python:3.10.7-slim-bullseye AS poetry-locker | ||||
|  | ||||
| WORKDIR /work | ||||
| COPY ./pyproject.toml /work | ||||
| @ -30,7 +30,7 @@ RUN pip install --no-cache-dir poetry && \ | ||||
|     poetry export -f requirements.txt --dev --output requirements-dev.txt | ||||
|  | ||||
| # Stage 4: Build go proxy | ||||
| FROM docker.io/golang:1.19.0-bullseye AS go-builder | ||||
| FROM docker.io/golang:1.19.1-bullseye AS go-builder | ||||
|  | ||||
| WORKDIR /work | ||||
|  | ||||
| @ -46,7 +46,7 @@ COPY ./go.sum /work/go.sum | ||||
| RUN go build -o /work/authentik ./cmd/server/main.go | ||||
|  | ||||
| # Stage 5: Run | ||||
| FROM docker.io/python:3.10.6-slim-bullseye AS final-image | ||||
| FROM docker.io/python:3.10.7-slim-bullseye AS final-image | ||||
|  | ||||
| LABEL org.opencontainers.image.url https://goauthentik.io | ||||
| LABEL org.opencontainers.image.description goauthentik.io Main server image, see https://goauthentik.io for more info. | ||||
|  | ||||
							
								
								
									
										52
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										52
									
								
								Makefile
									
									
									
									
									
								
							| @ -28,7 +28,7 @@ test-docker: | ||||
| 	rm -f .env | ||||
|  | ||||
| test: | ||||
| 	coverage run manage.py test authentik | ||||
| 	coverage run manage.py test --keepdb authentik | ||||
| 	coverage html | ||||
| 	coverage report | ||||
|  | ||||
| @ -49,28 +49,50 @@ lint: | ||||
| 	bandit -r authentik tests lifecycle -x node_modules | ||||
| 	golangci-lint run -v | ||||
|  | ||||
| migrate: | ||||
| 	python -m lifecycle.migrate | ||||
|  | ||||
| run: | ||||
| 	go run -v cmd/server/main.go | ||||
|  | ||||
| i18n-extract: i18n-extract-core web-extract | ||||
|  | ||||
| i18n-extract-core: | ||||
| 	ak makemessages --ignore web --ignore internal --ignore web --ignore web-api --ignore website -l en | ||||
|  | ||||
| ######################### | ||||
| ## API Schema | ||||
| ######################### | ||||
|  | ||||
| gen-build: | ||||
| 	AUTHENTIK_DEBUG=true ak make_blueprint_schema > blueprints/schema.json | ||||
| 	AUTHENTIK_DEBUG=true ak spectacular --file schema.yml | ||||
|  | ||||
| gen-diff: | ||||
| 	git show $(shell git tag -l | tail -n 1):schema.yml > old_schema.yml | ||||
| 	docker run \ | ||||
| 		--rm -v ${PWD}:/local \ | ||||
| 		--user ${UID}:${GID} \ | ||||
| 		docker.io/openapitools/openapi-diff:2.0.1 \ | ||||
| 		--markdown /local/diff.md \ | ||||
| 		/local/old_schema.yml /local/schema.yml | ||||
| 	rm old_schema.yml | ||||
|  | ||||
| gen-clean: | ||||
| 	rm -rf web/api/src/ | ||||
| 	rm -rf api/ | ||||
|  | ||||
| gen-client-web: | ||||
| gen-client-ts: | ||||
| 	docker run \ | ||||
| 		--rm -v ${PWD}:/local \ | ||||
| 		--user ${UID}:${GID} \ | ||||
| 		openapitools/openapi-generator-cli:v6.0.0 generate \ | ||||
| 		docker.io/openapitools/openapi-generator-cli:v6.0.0 generate \ | ||||
| 		-i /local/schema.yml \ | ||||
| 		-g typescript-fetch \ | ||||
| 		-o /local/gen-ts-api \ | ||||
| 		--additional-properties=typescriptThreePlus=true,supportsES6=true,npmName=@goauthentik/api,npmVersion=${NPM_VERSION} | ||||
| 		--additional-properties=typescriptThreePlus=true,supportsES6=true,npmName=@goauthentik/api,npmVersion=${NPM_VERSION} \ | ||||
| 		--git-repo-id authentik \ | ||||
| 		--git-user-id goauthentik | ||||
| 	mkdir -p web/node_modules/@goauthentik/api | ||||
| 	\cp -fv scripts/web_api_readme.md gen-ts-api/README.md | ||||
| 	cd gen-ts-api && npm i | ||||
| @ -84,7 +106,7 @@ gen-client-go: | ||||
| 	docker run \ | ||||
| 		--rm -v ${PWD}:/local \ | ||||
| 		--user ${UID}:${GID} \ | ||||
| 		openapitools/openapi-generator-cli:v6.0.0 generate \ | ||||
| 		docker.io/openapitools/openapi-generator-cli:v6.0.0 generate \ | ||||
| 		-i /local/schema.yml \ | ||||
| 		-g go \ | ||||
| 		-o /local/gen-go-api \ | ||||
| @ -95,13 +117,7 @@ gen-client-go: | ||||
| gen-dev-config: | ||||
| 	python -m scripts.generate_config | ||||
|  | ||||
| gen: gen-build gen-clean gen-client-web | ||||
|  | ||||
| migrate: | ||||
| 	python -m lifecycle.migrate | ||||
|  | ||||
| run: | ||||
| 	go run -v cmd/server/main.go | ||||
| gen: gen-build gen-clean gen-client-ts | ||||
|  | ||||
| ######################### | ||||
| ## Web | ||||
| @ -148,25 +164,25 @@ website-watch: | ||||
|  | ||||
| # These targets are use by GitHub actions to allow usage of matrix | ||||
| # which makes the YAML File a lot smaller | ||||
|  | ||||
| PY_SOURCES=authentik tests lifecycle | ||||
| ci--meta-debug: | ||||
| 	python -V | ||||
| 	node --version | ||||
|  | ||||
| ci-pylint: ci--meta-debug | ||||
| 	pylint authentik tests lifecycle | ||||
| 	pylint $(PY_SOURCES) | ||||
|  | ||||
| ci-black: ci--meta-debug | ||||
| 	black --check authentik tests lifecycle | ||||
| 	black --check $(PY_SOURCES) | ||||
|  | ||||
| ci-isort: ci--meta-debug | ||||
| 	isort --check authentik tests lifecycle | ||||
| 	isort --check $(PY_SOURCES) | ||||
|  | ||||
| ci-bandit: ci--meta-debug | ||||
| 	bandit -r authentik tests lifecycle | ||||
| 	bandit -r $(PY_SOURCES) | ||||
|  | ||||
| ci-pyright: ci--meta-debug | ||||
| 	pyright e2e lifecycle | ||||
| 	./web/node_modules/.bin/pyright $(PY_SOURCES) | ||||
|  | ||||
| ci-pending-migrations: ci--meta-debug | ||||
| 	ak makemigrations --check | ||||
|  | ||||
| @ -2,7 +2,7 @@ | ||||
| from os import environ | ||||
| from typing import Optional | ||||
|  | ||||
| __version__ = "2022.8.1" | ||||
| __version__ = "2022.9.0" | ||||
| ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -18,13 +18,13 @@ class AppSerializer(PassiveSerializer): | ||||
|  | ||||
|  | ||||
| class AppsViewSet(ViewSet): | ||||
|     """Read-only view set list all installed apps""" | ||||
|     """Read-only view list all installed apps""" | ||||
|  | ||||
|     permission_classes = [IsAdminUser] | ||||
|  | ||||
|     @extend_schema(responses={200: AppSerializer(many=True)}) | ||||
|     def list(self, request: Request) -> Response: | ||||
|         """List current messages and pass into Serializer""" | ||||
|         """Read-only view list all installed apps""" | ||||
|         data = [] | ||||
|         for app in sorted(get_apps(), key=lambda app: app.name): | ||||
|             data.append({"name": app.name, "label": app.verbose_name}) | ||||
|  | ||||
| @ -5,7 +5,7 @@ from django.contrib import messages | ||||
| from django.http.response import Http404 | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from drf_spectacular.types import OpenApiTypes | ||||
| from drf_spectacular.utils import OpenApiResponse, extend_schema | ||||
| from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema | ||||
| from rest_framework.decorators import action | ||||
| from rest_framework.fields import CharField, ChoiceField, DateTimeField, ListField | ||||
| from rest_framework.permissions import IsAdminUser | ||||
| @ -58,7 +58,15 @@ class TaskViewSet(ViewSet): | ||||
|         responses={ | ||||
|             200: TaskSerializer(many=False), | ||||
|             404: OpenApiResponse(description="Task not found"), | ||||
|         } | ||||
|         }, | ||||
|         parameters=[ | ||||
|             OpenApiParameter( | ||||
|                 "id", | ||||
|                 type=OpenApiTypes.STR, | ||||
|                 location=OpenApiParameter.PATH, | ||||
|                 required=True, | ||||
|             ), | ||||
|         ], | ||||
|     ) | ||||
|     # pylint: disable=invalid-name | ||||
|     def retrieve(self, request: Request, pk=None) -> Response: | ||||
| @ -81,6 +89,14 @@ class TaskViewSet(ViewSet): | ||||
|             404: OpenApiResponse(description="Task not found"), | ||||
|             500: OpenApiResponse(description="Failed to retry task"), | ||||
|         }, | ||||
|         parameters=[ | ||||
|             OpenApiParameter( | ||||
|                 "id", | ||||
|                 type=OpenApiTypes.STR, | ||||
|                 location=OpenApiParameter.PATH, | ||||
|                 required=True, | ||||
|             ), | ||||
|         ], | ||||
|     ) | ||||
|     @action(detail=True, methods=["post"]) | ||||
|     # pylint: disable=invalid-name | ||||
|  | ||||
| @ -1,7 +1,7 @@ | ||||
| """authentik admin app config""" | ||||
| from prometheus_client import Gauge, Info | ||||
|  | ||||
| from authentik.blueprints.manager import ManagedAppConfig | ||||
| from authentik.blueprints.apps import ManagedAppConfig | ||||
|  | ||||
| PROM_INFO = Info("authentik_version", "Currently running authentik version") | ||||
| GAUGE_WORKERS = Gauge("authentik_admin_workers", "Currently connected workers") | ||||
|  | ||||
| @ -16,7 +16,7 @@ from authentik.providers.oauth2.models import RefreshToken | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| def validate_auth(header: bytes) -> str: | ||||
| def validate_auth(header: bytes) -> Optional[str]: | ||||
|     """Validate that the header is in a correct format, | ||||
|     returns type and credentials""" | ||||
|     auth_credentials = header.decode().strip() | ||||
|  | ||||
| @ -60,8 +60,28 @@ def postprocess_schema_responses(result, generator, **kwargs):  # noqa: W0613 | ||||
|  | ||||
|     for path in result["paths"].values(): | ||||
|         for method in path.values(): | ||||
|             method["responses"].setdefault("400", validation_error.ref) | ||||
|             method["responses"].setdefault("403", generic_error.ref) | ||||
|             method["responses"].setdefault( | ||||
|                 "400", | ||||
|                 { | ||||
|                     "content": { | ||||
|                         "application/json": { | ||||
|                             "schema": validation_error.ref, | ||||
|                         } | ||||
|                     }, | ||||
|                     "description": "", | ||||
|                 }, | ||||
|             ) | ||||
|             method["responses"].setdefault( | ||||
|                 "403", | ||||
|                 { | ||||
|                     "content": { | ||||
|                         "application/json": { | ||||
|                             "schema": generic_error.ref, | ||||
|                         } | ||||
|                     }, | ||||
|                     "description": "", | ||||
|                 }, | ||||
|             ) | ||||
|  | ||||
|     result["components"] = generator.registry.build(spectacular_settings.APPEND_COMPONENTS) | ||||
|  | ||||
|  | ||||
| @ -1,8 +1,7 @@ | ||||
| """Serializer mixin for managed models""" | ||||
| from dataclasses import asdict | ||||
|  | ||||
| from drf_spectacular.utils import extend_schema, inline_serializer | ||||
| from rest_framework.decorators import action | ||||
| from rest_framework.exceptions import ValidationError | ||||
| from rest_framework.fields import CharField, DateTimeField, JSONField | ||||
| from rest_framework.permissions import IsAdminUser | ||||
| from rest_framework.request import Request | ||||
| @ -11,11 +10,10 @@ from rest_framework.serializers import ListSerializer, ModelSerializer | ||||
| from rest_framework.viewsets import ModelViewSet | ||||
|  | ||||
| from authentik.api.decorators import permission_required | ||||
| from authentik.blueprints.models import BlueprintInstance | ||||
| from authentik.blueprints.v1.tasks import BlueprintFile, apply_blueprint, blueprints_find | ||||
| from authentik.blueprints.models import BlueprintInstance, BlueprintRetrievalFailed | ||||
| from authentik.blueprints.v1.tasks import apply_blueprint, blueprints_find_dict | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.core.api.utils import PassiveSerializer | ||||
| from authentik.events.utils import sanitize_dict | ||||
|  | ||||
|  | ||||
| class ManagedSerializer: | ||||
| @ -34,6 +32,14 @@ class MetadataSerializer(PassiveSerializer): | ||||
| class BlueprintInstanceSerializer(ModelSerializer): | ||||
|     """Info about a single blueprint instance file""" | ||||
|  | ||||
|     def validate_path(self, path: str) -> str: | ||||
|         """Ensure the path specified is retrievable""" | ||||
|         try: | ||||
|             BlueprintInstance(path=path).retrieve() | ||||
|         except BlueprintRetrievalFailed as exc: | ||||
|             raise ValidationError(exc) from exc | ||||
|         return path | ||||
|  | ||||
|     class Meta: | ||||
|  | ||||
|         model = BlueprintInstance | ||||
| @ -85,8 +91,8 @@ class BlueprintInstanceViewSet(UsedByMixin, ModelViewSet): | ||||
|     @action(detail=False, pagination_class=None, filter_backends=[]) | ||||
|     def available(self, request: Request) -> Response: | ||||
|         """Get blueprints""" | ||||
|         files: list[BlueprintFile] = blueprints_find.delay().get() | ||||
|         return Response([sanitize_dict(asdict(file)) for file in files]) | ||||
|         files: list[dict] = blueprints_find_dict.delay().get() | ||||
|         return Response(files) | ||||
|  | ||||
|     @permission_required("authentik_blueprints.view_blueprintinstance") | ||||
|     @extend_schema( | ||||
|  | ||||
| @ -1,6 +1,46 @@ | ||||
| """authentik Blueprints app""" | ||||
|  | ||||
| from authentik.blueprints.manager import ManagedAppConfig | ||||
| from importlib import import_module | ||||
| from inspect import ismethod | ||||
|  | ||||
| from django.apps import AppConfig | ||||
| from django.db import DatabaseError, InternalError, ProgrammingError | ||||
| from structlog.stdlib import BoundLogger, get_logger | ||||
|  | ||||
|  | ||||
| class ManagedAppConfig(AppConfig): | ||||
|     """Basic reconciliation logic for apps""" | ||||
|  | ||||
|     _logger: BoundLogger | ||||
|  | ||||
|     def __init__(self, app_name: str, *args, **kwargs) -> None: | ||||
|         super().__init__(app_name, *args, **kwargs) | ||||
|         self._logger = get_logger().bind(app_name=app_name) | ||||
|  | ||||
|     def ready(self) -> None: | ||||
|         self.reconcile() | ||||
|         return super().ready() | ||||
|  | ||||
|     def import_module(self, path: str): | ||||
|         """Load module""" | ||||
|         import_module(path) | ||||
|  | ||||
|     def reconcile(self) -> None: | ||||
|         """reconcile ourselves""" | ||||
|         prefix = "reconcile_" | ||||
|         for meth_name in dir(self): | ||||
|             meth = getattr(self, meth_name) | ||||
|             if not ismethod(meth): | ||||
|                 continue | ||||
|             if not meth_name.startswith(prefix): | ||||
|                 continue | ||||
|             name = meth_name.replace(prefix, "") | ||||
|             try: | ||||
|                 self._logger.debug("Starting reconciler", name=name) | ||||
|                 meth() | ||||
|                 self._logger.debug("Successfully reconciled", name=name) | ||||
|             except (DatabaseError, ProgrammingError, InternalError) as exc: | ||||
|                 self._logger.debug("Failed to run reconcile", name=name, exc=exc) | ||||
|  | ||||
|  | ||||
| class AuthentikBlueprintsConfig(ManagedAppConfig): | ||||
| @ -20,3 +60,7 @@ class AuthentikBlueprintsConfig(ManagedAppConfig): | ||||
|         from authentik.blueprints.v1.tasks import blueprints_discover | ||||
|  | ||||
|         blueprints_discover.delay() | ||||
|  | ||||
|     def import_models(self): | ||||
|         super().import_models() | ||||
|         self.import_module("authentik.blueprints.v1.meta.apply_blueprint") | ||||
|  | ||||
| @ -1,7 +1,10 @@ | ||||
| """Apply blueprint from commandline""" | ||||
| from sys import exit as sys_exit | ||||
|  | ||||
| from django.core.management.base import BaseCommand, no_translations | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.blueprints.models import BlueprintInstance | ||||
| from authentik.blueprints.v1.importer import Importer | ||||
|  | ||||
| LOGGER = get_logger() | ||||
| @ -14,14 +17,15 @@ class Command(BaseCommand): | ||||
|     def handle(self, *args, **options): | ||||
|         """Apply all blueprints in order, abort when one fails to import""" | ||||
|         for blueprint_path in options.get("blueprints", []): | ||||
|             with open(blueprint_path, "r", encoding="utf8") as blueprint_file: | ||||
|                 importer = Importer(blueprint_file.read()) | ||||
|                 valid, logs = importer.validate() | ||||
|                 if not valid: | ||||
|                     for log in logs: | ||||
|                         LOGGER.debug(**log) | ||||
|                     raise ValueError("blueprint invalid") | ||||
|                 importer.apply() | ||||
|             content = BlueprintInstance(path=blueprint_path).retrieve() | ||||
|             importer = Importer(content) | ||||
|             valid, logs = importer.validate() | ||||
|             if not valid: | ||||
|                 for log in logs: | ||||
|                     getattr(LOGGER, log.pop("log_level"))(**log) | ||||
|                 self.stderr.write("blueprint invalid") | ||||
|                 sys_exit(1) | ||||
|             importer.apply() | ||||
|  | ||||
|     def add_arguments(self, parser): | ||||
|         parser.add_argument("blueprints", nargs="+", type=str) | ||||
|  | ||||
							
								
								
									
										17
									
								
								authentik/blueprints/management/commands/export_blueprint.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								authentik/blueprints/management/commands/export_blueprint.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,17 @@ | ||||
| """Export blueprint of current authentik install""" | ||||
| from django.core.management.base import BaseCommand, no_translations | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.blueprints.v1.exporter import Exporter | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| class Command(BaseCommand): | ||||
|     """Export blueprint of current authentik install""" | ||||
|  | ||||
|     @no_translations | ||||
|     def handle(self, *args, **options): | ||||
|         """Export blueprint of current authentik install""" | ||||
|         exporter = Exporter() | ||||
|         self.stdout.write(exporter.export_to_string()) | ||||
| @ -2,11 +2,11 @@ | ||||
| from json import dumps, loads | ||||
| from pathlib import Path | ||||
|  | ||||
| from django.apps import apps | ||||
| from django.core.management.base import BaseCommand, no_translations | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.blueprints.v1.importer import is_model_allowed | ||||
| from authentik.blueprints.v1.meta.registry import registry | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
| @ -28,8 +28,9 @@ class Command(BaseCommand): | ||||
|     def set_model_allowed(self): | ||||
|         """Set model enum""" | ||||
|         model_names = [] | ||||
|         for model in apps.get_models(): | ||||
|         for model in registry.get_models(): | ||||
|             if not is_model_allowed(model): | ||||
|                 continue | ||||
|             model_names.append(f"{model._meta.app_label}.{model._meta.model_name}") | ||||
|         model_names.sort() | ||||
|         self.schema["properties"]["entries"]["items"]["properties"]["model"]["enum"] = model_names | ||||
|  | ||||
| @ -41,8 +41,7 @@ | ||||
|                 "$id": "#entry", | ||||
|                 "type": "object", | ||||
|                 "required": [ | ||||
|                     "model", | ||||
|                     "identifiers" | ||||
|                     "model" | ||||
|                 ], | ||||
|                 "properties": { | ||||
|                     "model": { | ||||
| @ -67,6 +66,7 @@ | ||||
|                     }, | ||||
|                     "identifiers": { | ||||
|                         "type": "object", | ||||
|                         "default": {}, | ||||
|                         "properties": { | ||||
|                             "pk": { | ||||
|                                 "description": "Commonly available field, may not exist on all models", | ||||
|  | ||||
| @ -1,44 +0,0 @@ | ||||
| """Managed objects manager""" | ||||
| from importlib import import_module | ||||
| from inspect import ismethod | ||||
|  | ||||
| from django.apps import AppConfig | ||||
| from django.db import DatabaseError, InternalError, ProgrammingError | ||||
| from structlog.stdlib import BoundLogger, get_logger | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| class ManagedAppConfig(AppConfig): | ||||
|     """Basic reconciliation logic for apps""" | ||||
|  | ||||
|     _logger: BoundLogger | ||||
|  | ||||
|     def __init__(self, app_name: str, *args, **kwargs) -> None: | ||||
|         super().__init__(app_name, *args, **kwargs) | ||||
|         self._logger = get_logger().bind(app_name=app_name) | ||||
|  | ||||
|     def ready(self) -> None: | ||||
|         self.reconcile() | ||||
|         return super().ready() | ||||
|  | ||||
|     def import_module(self, path: str): | ||||
|         """Load module""" | ||||
|         import_module(path) | ||||
|  | ||||
|     def reconcile(self) -> None: | ||||
|         """reconcile ourselves""" | ||||
|         prefix = "reconcile_" | ||||
|         for meth_name in dir(self): | ||||
|             meth = getattr(self, meth_name) | ||||
|             if not ismethod(meth): | ||||
|                 continue | ||||
|             if not meth_name.startswith(prefix): | ||||
|                 continue | ||||
|             name = meth_name.replace(prefix, "") | ||||
|             try: | ||||
|                 self._logger.debug("Starting reconciler", name=name) | ||||
|                 meth() | ||||
|                 self._logger.debug("Successfully reconciled", name=name) | ||||
|             except (DatabaseError, ProgrammingError, InternalError) as exc: | ||||
|                 self._logger.debug("Failed to run reconcile", name=name, exc=exc) | ||||
| @ -4,7 +4,7 @@ from glob import glob | ||||
| from pathlib import Path | ||||
|  | ||||
| import django.contrib.postgres.fields | ||||
| from dacite import from_dict | ||||
| from dacite.core import from_dict | ||||
| from django.apps.registry import Apps | ||||
| from django.conf import settings | ||||
| from django.db import migrations, models | ||||
| @ -113,7 +113,8 @@ class Migration(migrations.Migration): | ||||
|                             ("error", "Error"), | ||||
|                             ("orphaned", "Orphaned"), | ||||
|                             ("unknown", "Unknown"), | ||||
|                         ] | ||||
|                         ], | ||||
|                         default="unknown", | ||||
|                     ), | ||||
|                 ), | ||||
|                 ("enabled", models.BooleanField(default=True)), | ||||
|  | ||||
| @ -1,12 +1,36 @@ | ||||
| """Managed Object models""" | ||||
| """blueprint models""" | ||||
| from pathlib import Path | ||||
| from urllib.parse import urlparse | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from django.contrib.postgres.fields import ArrayField | ||||
| from django.db import models | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from opencontainers.distribution.reggie import ( | ||||
|     NewClient, | ||||
|     WithDebug, | ||||
|     WithDefaultName, | ||||
|     WithDigest, | ||||
|     WithReference, | ||||
|     WithUserAgent, | ||||
|     WithUsernamePassword, | ||||
| ) | ||||
| from requests.exceptions import RequestException | ||||
| from rest_framework.serializers import Serializer | ||||
| from structlog import get_logger | ||||
|  | ||||
| from authentik.lib.config import CONFIG | ||||
| from authentik.lib.models import CreatedUpdatedModel, SerializerModel | ||||
| from authentik.lib.sentry import SentryIgnoredException | ||||
| from authentik.lib.utils.http import authentik_user_agent | ||||
|  | ||||
| OCI_MEDIA_TYPE = "application/vnd.goauthentik.blueprint.v1+yaml" | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| class BlueprintRetrievalFailed(SentryIgnoredException): | ||||
|     """Error raised when we're unable to fetch the blueprint contents, whether it be HTTP files | ||||
|     not being accessible or local files not being readable""" | ||||
|  | ||||
|  | ||||
| class ManagedModel(models.Model): | ||||
| @ -54,10 +78,70 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel): | ||||
|     context = models.JSONField(default=dict) | ||||
|     last_applied = models.DateTimeField(auto_now=True) | ||||
|     last_applied_hash = models.TextField() | ||||
|     status = models.TextField(choices=BlueprintInstanceStatus.choices) | ||||
|     status = models.TextField( | ||||
|         choices=BlueprintInstanceStatus.choices, default=BlueprintInstanceStatus.UNKNOWN | ||||
|     ) | ||||
|     enabled = models.BooleanField(default=True) | ||||
|     managed_models = ArrayField(models.TextField(), default=list) | ||||
|  | ||||
|     def retrieve_oci(self) -> str: | ||||
|         """Get blueprint from an OCI registry""" | ||||
|         url = urlparse(self.path) | ||||
|         ref = "latest" | ||||
|         path = url.path[1:] | ||||
|         if ":" in url.path: | ||||
|             path, _, ref = path.partition(":") | ||||
|         client = NewClient( | ||||
|             f"{url.scheme}://{url.hostname}", | ||||
|             WithUserAgent(authentik_user_agent()), | ||||
|             WithUsernamePassword(url.username, url.password), | ||||
|             WithDefaultName(path), | ||||
|             WithDebug(True), | ||||
|         ) | ||||
|         LOGGER.info("Fetching OCI manifests for blueprint", instance=self) | ||||
|         manifest_request = client.NewRequest( | ||||
|             "GET", | ||||
|             "/v2/<name>/manifests/<reference>", | ||||
|             WithReference(ref), | ||||
|         ).SetHeader("Accept", "application/vnd.oci.image.manifest.v1+json") | ||||
|         try: | ||||
|             manifest_response = client.Do(manifest_request) | ||||
|             manifest_response.raise_for_status() | ||||
|         except RequestException as exc: | ||||
|             raise BlueprintRetrievalFailed(exc) from exc | ||||
|         manifest = manifest_response.json() | ||||
|         if "errors" in manifest: | ||||
|             raise BlueprintRetrievalFailed(manifest["errors"]) | ||||
|  | ||||
|         blob = None | ||||
|         for layer in manifest.get("layers", []): | ||||
|             if layer.get("mediaType", "") == OCI_MEDIA_TYPE: | ||||
|                 blob = layer.get("digest") | ||||
|                 LOGGER.debug("Found layer with matching media type", instance=self, blob=blob) | ||||
|         if not blob: | ||||
|             raise BlueprintRetrievalFailed("Blob not found") | ||||
|  | ||||
|         blob_request = client.NewRequest( | ||||
|             "GET", | ||||
|             "/v2/<name>/blobs/<digest>", | ||||
|             WithDigest(blob), | ||||
|         ) | ||||
|         try: | ||||
|             blob_response = client.Do(blob_request) | ||||
|             blob_response.raise_for_status() | ||||
|             return blob_response.text | ||||
|         except RequestException as exc: | ||||
|             raise BlueprintRetrievalFailed(exc) from exc | ||||
|  | ||||
|     def retrieve(self) -> str: | ||||
|         """Retrieve blueprint contents""" | ||||
|         full_path = Path(CONFIG.y("blueprints_dir")).joinpath(Path(self.path)) | ||||
|         if full_path.exists(): | ||||
|             LOGGER.debug("Blueprint path exists locally", instance=self) | ||||
|             with full_path.open("r", encoding="utf-8") as _file: | ||||
|                 return _file.read() | ||||
|         return self.retrieve_oci() | ||||
|  | ||||
|     @property | ||||
|     def serializer(self) -> Serializer: | ||||
|         from authentik.blueprints.api import BlueprintInstanceSerializer | ||||
|  | ||||
| @ -5,7 +5,8 @@ from typing import Callable | ||||
|  | ||||
| from django.apps import apps | ||||
|  | ||||
| from authentik.blueprints.manager import ManagedAppConfig | ||||
| from authentik.blueprints.apps import ManagedAppConfig | ||||
| from authentik.blueprints.models import BlueprintInstance | ||||
| from authentik.lib.config import CONFIG | ||||
|  | ||||
|  | ||||
| @ -19,11 +20,9 @@ def apply_blueprint(*files: str): | ||||
|  | ||||
|         @wraps(func) | ||||
|         def wrapper(*args, **kwargs): | ||||
|             base_path = Path(CONFIG.y("blueprints_dir")) | ||||
|             for file in files: | ||||
|                 full_path = Path(base_path, file) | ||||
|                 with full_path.open("r", encoding="utf-8") as _file: | ||||
|                     Importer(_file.read()).apply() | ||||
|                 content = BlueprintInstance(path=file).retrieve() | ||||
|                 Importer(content).apply() | ||||
|             return func(*args, **kwargs) | ||||
|  | ||||
|         return wrapper | ||||
|  | ||||
							
								
								
									
										97
									
								
								authentik/blueprints/tests/test_oci.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										97
									
								
								authentik/blueprints/tests/test_oci.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,97 @@ | ||||
| """Test blueprints OCI""" | ||||
| from django.test import TransactionTestCase | ||||
| from requests_mock import Mocker | ||||
|  | ||||
| from authentik.blueprints.models import OCI_MEDIA_TYPE, BlueprintInstance, BlueprintRetrievalFailed | ||||
|  | ||||
|  | ||||
| class TestBlueprintOCI(TransactionTestCase): | ||||
|     """Test Blueprints OCI Tasks""" | ||||
|  | ||||
|     def test_successful(self): | ||||
|         """Successful retrieval""" | ||||
|         with Mocker() as mocker: | ||||
|             mocker.get( | ||||
|                 "https://ghcr.io/v2/goauthentik/blueprints/test/manifests/latest", | ||||
|                 json={ | ||||
|                     "layers": [ | ||||
|                         { | ||||
|                             "mediaType": OCI_MEDIA_TYPE, | ||||
|                             "digest": "foo", | ||||
|                         } | ||||
|                     ] | ||||
|                 }, | ||||
|             ) | ||||
|             mocker.get("https://ghcr.io/v2/goauthentik/blueprints/test/blobs/foo", text="foo") | ||||
|  | ||||
|             self.assertEqual( | ||||
|                 BlueprintInstance( | ||||
|                     path="https://ghcr.io/goauthentik/blueprints/test:latest" | ||||
|                 ).retrieve_oci(), | ||||
|                 "foo", | ||||
|             ) | ||||
|  | ||||
|     def test_manifests_error(self): | ||||
|         """Test manifests request erroring""" | ||||
|         with Mocker() as mocker: | ||||
|             mocker.get( | ||||
|                 "https://ghcr.io/v2/goauthentik/blueprints/test/manifests/latest", status_code=401 | ||||
|             ) | ||||
|  | ||||
|             with self.assertRaises(BlueprintRetrievalFailed): | ||||
|                 BlueprintInstance( | ||||
|                     path="https://ghcr.io/goauthentik/blueprints/test:latest" | ||||
|                 ).retrieve_oci() | ||||
|  | ||||
|     def test_manifests_error_response(self): | ||||
|         """Test manifests request erroring""" | ||||
|         with Mocker() as mocker: | ||||
|             mocker.get( | ||||
|                 "https://ghcr.io/v2/goauthentik/blueprints/test/manifests/latest", | ||||
|                 json={"errors": ["foo"]}, | ||||
|             ) | ||||
|  | ||||
|             with self.assertRaises(BlueprintRetrievalFailed): | ||||
|                 BlueprintInstance( | ||||
|                     path="https://ghcr.io/goauthentik/blueprints/test:latest" | ||||
|                 ).retrieve_oci() | ||||
|  | ||||
|     def test_no_matching_blob(self): | ||||
|         """Successful retrieval""" | ||||
|         with Mocker() as mocker: | ||||
|             mocker.get( | ||||
|                 "https://ghcr.io/v2/goauthentik/blueprints/test/manifests/latest", | ||||
|                 json={ | ||||
|                     "layers": [ | ||||
|                         { | ||||
|                             "mediaType": OCI_MEDIA_TYPE + "foo", | ||||
|                             "digest": "foo", | ||||
|                         } | ||||
|                     ] | ||||
|                 }, | ||||
|             ) | ||||
|             with self.assertRaises(BlueprintRetrievalFailed): | ||||
|                 BlueprintInstance( | ||||
|                     path="https://ghcr.io/goauthentik/blueprints/test:latest" | ||||
|                 ).retrieve_oci() | ||||
|  | ||||
|     def test_blob_error(self): | ||||
|         """Successful retrieval""" | ||||
|         with Mocker() as mocker: | ||||
|             mocker.get( | ||||
|                 "https://ghcr.io/v2/goauthentik/blueprints/test/manifests/latest", | ||||
|                 json={ | ||||
|                     "layers": [ | ||||
|                         { | ||||
|                             "mediaType": OCI_MEDIA_TYPE, | ||||
|                             "digest": "foo", | ||||
|                         } | ||||
|                     ] | ||||
|                 }, | ||||
|             ) | ||||
|             mocker.get("https://ghcr.io/v2/goauthentik/blueprints/test/blobs/foo", status_code=401) | ||||
|  | ||||
|             with self.assertRaises(BlueprintRetrievalFailed): | ||||
|                 BlueprintInstance( | ||||
|                     path="https://ghcr.io/goauthentik/blueprints/test:latest" | ||||
|                 ).retrieve_oci() | ||||
| @ -1,17 +1,16 @@ | ||||
| """test packaged blueprints""" | ||||
| from glob import glob | ||||
| from pathlib import Path | ||||
| from typing import Callable | ||||
| 
 | ||||
| from django.test import TransactionTestCase | ||||
| from django.utils.text import slugify | ||||
| 
 | ||||
| from authentik.blueprints.models import BlueprintInstance | ||||
| from authentik.blueprints.tests import apply_blueprint | ||||
| from authentik.blueprints.v1.importer import Importer | ||||
| from authentik.tenants.models import Tenant | ||||
| 
 | ||||
| 
 | ||||
| class TestBundled(TransactionTestCase): | ||||
| class TestPackaged(TransactionTestCase): | ||||
|     """Empty class, test methods are added dynamically""" | ||||
| 
 | ||||
|     @apply_blueprint("default/90-default-tenant.yaml") | ||||
| @ -20,18 +19,20 @@ class TestBundled(TransactionTestCase): | ||||
|         self.assertTrue(Tenant.objects.filter(domain="authentik-default").exists()) | ||||
| 
 | ||||
| 
 | ||||
| def blueprint_tester(file_name: str) -> Callable: | ||||
| def blueprint_tester(file_name: Path) -> Callable: | ||||
|     """This is used instead of subTest for better visibility""" | ||||
| 
 | ||||
|     def tester(self: TestBundled): | ||||
|         with open(file_name, "r", encoding="utf8") as flow_yaml: | ||||
|             importer = Importer(flow_yaml.read()) | ||||
|     def tester(self: TestPackaged): | ||||
|         base = Path("blueprints/") | ||||
|         rel_path = Path(file_name).relative_to(base) | ||||
|         importer = Importer(BlueprintInstance(path=str(rel_path)).retrieve()) | ||||
|         self.assertTrue(importer.validate()[0]) | ||||
|         self.assertTrue(importer.apply()) | ||||
| 
 | ||||
|     return tester | ||||
| 
 | ||||
| 
 | ||||
| for flow_file in glob("blueprints/**/*.yaml", recursive=True): | ||||
|     method_name = slugify(Path(flow_file).stem).replace("-", "_").replace(".", "_") | ||||
|     setattr(TestBundled, f"test_flow_{method_name}", blueprint_tester(flow_file)) | ||||
| for blueprint_file in Path("blueprints/").glob("**/*.yaml"): | ||||
|     if "local" in str(blueprint_file): | ||||
|         continue | ||||
|     setattr(TestPackaged, f"test_blueprint_{blueprint_file}", blueprint_tester(blueprint_file)) | ||||
| @ -1,7 +1,7 @@ | ||||
| """Test blueprints v1""" | ||||
| from django.test import TransactionTestCase | ||||
|  | ||||
| from authentik.blueprints.v1.exporter import Exporter | ||||
| from authentik.blueprints.v1.exporter import FlowExporter | ||||
| from authentik.blueprints.v1.importer import Importer, transaction_rollback | ||||
| from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding | ||||
| from authentik.lib.generators import generate_id | ||||
| @ -70,7 +70,7 @@ class TestBlueprintsV1(TransactionTestCase): | ||||
|                 order=0, | ||||
|             ) | ||||
|  | ||||
|             exporter = Exporter(flow) | ||||
|             exporter = FlowExporter(flow) | ||||
|             export = exporter.export() | ||||
|             self.assertEqual(len(export.entries), 3) | ||||
|             export_yaml = exporter.export_to_string() | ||||
| @ -126,7 +126,7 @@ class TestBlueprintsV1(TransactionTestCase): | ||||
|             fsb = FlowStageBinding.objects.create(target=flow, stage=user_login, order=0) | ||||
|             PolicyBinding.objects.create(policy=flow_policy, target=fsb, order=0) | ||||
|  | ||||
|             exporter = Exporter(flow) | ||||
|             exporter = FlowExporter(flow) | ||||
|             export_yaml = exporter.export_to_string() | ||||
|  | ||||
|         importer = Importer(export_yaml) | ||||
| @ -169,7 +169,7 @@ class TestBlueprintsV1(TransactionTestCase): | ||||
|  | ||||
|             FlowStageBinding.objects.create(target=flow, stage=first_stage, order=0) | ||||
|  | ||||
|             exporter = Exporter(flow) | ||||
|             exporter = FlowExporter(flow) | ||||
|             export_yaml = exporter.export_to_string() | ||||
|  | ||||
|         importer = Importer(export_yaml) | ||||
|  | ||||
| @ -1,4 +1,5 @@ | ||||
| """Test blueprints v1 tasks""" | ||||
| from hashlib import sha512 | ||||
| from tempfile import NamedTemporaryFile, mkdtemp | ||||
|  | ||||
| from django.test import TransactionTestCase | ||||
| @ -36,25 +37,32 @@ class TestBlueprintsV1Tasks(TransactionTestCase): | ||||
|     @CONFIG.patch("blueprints_dir", TMP) | ||||
|     def test_valid(self): | ||||
|         """Test valid file""" | ||||
|         blueprint_id = generate_id() | ||||
|         with NamedTemporaryFile(mode="w+", suffix=".yaml", dir=TMP) as file: | ||||
|             file.write( | ||||
|                 dump( | ||||
|                     { | ||||
|                         "version": 1, | ||||
|                         "entries": [], | ||||
|                         "metadata": { | ||||
|                             "name": blueprint_id, | ||||
|                         }, | ||||
|                     } | ||||
|                 ) | ||||
|             ) | ||||
|             file.seek(0) | ||||
|             file_hash = sha512(file.read().encode()).hexdigest() | ||||
|             file.flush() | ||||
|             blueprints_discover()  # pylint: disable=no-value-for-parameter | ||||
|             instance = BlueprintInstance.objects.filter(name=blueprint_id).first() | ||||
|             self.assertEqual(instance.last_applied_hash, file_hash) | ||||
|             self.assertEqual( | ||||
|                 BlueprintInstance.objects.first().last_applied_hash, | ||||
|                 ( | ||||
|                     "e52bb445b03cd36057258dc9f0ce0fbed8278498ee1470e45315293e5f026d1b" | ||||
|                     "d1f9b3526871c0003f5c07be5c3316d9d4a08444bd8fed1b3f03294e51e44522" | ||||
|                 ), | ||||
|                 instance.metadata, | ||||
|                 { | ||||
|                     "name": blueprint_id, | ||||
|                     "labels": {}, | ||||
|                 }, | ||||
|             ) | ||||
|             self.assertEqual(BlueprintInstance.objects.first().metadata, {}) | ||||
|  | ||||
|     @CONFIG.patch("blueprints_dir", TMP) | ||||
|     def test_valid_updated(self): | ||||
|  | ||||
| @ -27,7 +27,7 @@ def get_attrs(obj: SerializerModel) -> dict[str, Any]: | ||||
|             continue | ||||
|         if _field.read_only: | ||||
|             data.pop(field_name, None) | ||||
|         if _field.default == data.get(field_name, None): | ||||
|         if _field.get_initial() == data.get(field_name, None): | ||||
|             data.pop(field_name, None) | ||||
|         if field_name.endswith("_set"): | ||||
|             data.pop(field_name, None) | ||||
| @ -35,21 +35,28 @@ def get_attrs(obj: SerializerModel) -> dict[str, Any]: | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class BlueprintEntry: | ||||
|     """Single entry of a bundle""" | ||||
| class BlueprintEntryState: | ||||
|     """State of a single instance""" | ||||
|  | ||||
|     instance: Optional[Model] = None | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class BlueprintEntry: | ||||
|     """Single entry of a blueprint""" | ||||
|  | ||||
|     identifiers: dict[str, Any] | ||||
|     model: str | ||||
|     identifiers: dict[str, Any] = field(default_factory=dict) | ||||
|     attrs: Optional[dict[str, Any]] = field(default_factory=dict) | ||||
|  | ||||
|     # pylint: disable=invalid-name | ||||
|     id: Optional[str] = None | ||||
|  | ||||
|     _instance: Optional[Model] = None | ||||
|     _state: BlueprintEntryState = field(default_factory=BlueprintEntryState) | ||||
|  | ||||
|     @staticmethod | ||||
|     def from_model(model: SerializerModel, *extra_identifier_names: str) -> "BlueprintEntry": | ||||
|         """Convert a SerializerModel instance to a Bundle Entry""" | ||||
|         """Convert a SerializerModel instance to a blueprint Entry""" | ||||
|         identifiers = { | ||||
|             "pk": model.pk, | ||||
|         } | ||||
| @ -98,9 +105,9 @@ class Blueprint: | ||||
|  | ||||
|     version: int = field(default=1) | ||||
|     entries: list[BlueprintEntry] = field(default_factory=list) | ||||
|     context: dict = field(default_factory=dict) | ||||
|  | ||||
|     metadata: Optional[BlueprintMetadata] = field(default=None) | ||||
|     context: Optional[dict] = field(default_factory=dict) | ||||
|  | ||||
|  | ||||
| class YAMLTag: | ||||
| @ -123,15 +130,15 @@ class KeyOf(YAMLTag): | ||||
|  | ||||
|     def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any: | ||||
|         for _entry in blueprint.entries: | ||||
|             if _entry.id == self.id_from and _entry._instance: | ||||
|             if _entry.id == self.id_from and _entry._state.instance: | ||||
|                 # Special handling for PolicyBindingModels, as they'll have a different PK | ||||
|                 # which is used when creating policy bindings | ||||
|                 if ( | ||||
|                     isinstance(_entry._instance, PolicyBindingModel) | ||||
|                     isinstance(_entry._state.instance, PolicyBindingModel) | ||||
|                     and entry.model.lower() == "authentik_policies.policybinding" | ||||
|                 ): | ||||
|                     return _entry._instance.pbm_uuid | ||||
|                 return _entry._instance.pk | ||||
|                     return _entry._state.instance.pbm_uuid | ||||
|                 return _entry._state.instance.pk | ||||
|         raise ValueError( | ||||
|             f"KeyOf: failed to find entry with `id` of `{self.id_from}` and a model instance" | ||||
|         ) | ||||
| @ -176,8 +183,6 @@ class Format(YAMLTag): | ||||
|  | ||||
|     def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any: | ||||
|         try: | ||||
|             print(self.format_string) | ||||
|             print(self.args) | ||||
|             return self.format_string % tuple(self.args) | ||||
|         except TypeError as exc: | ||||
|             raise EntryInvalidError(exc) | ||||
| @ -225,7 +230,13 @@ class BlueprintDumper(SafeDumper): | ||||
|  | ||||
|     def represent(self, data) -> None: | ||||
|         if is_dataclass(data): | ||||
|             data = asdict(data) | ||||
|  | ||||
|             def factory(items): | ||||
|                 final_dict = dict(items) | ||||
|                 final_dict.pop("_state", None) | ||||
|                 return final_dict | ||||
|  | ||||
|             data = asdict(data, dict_factory=factory) | ||||
|         return super().represent(data) | ||||
|  | ||||
|  | ||||
| @ -242,3 +253,9 @@ class BlueprintLoader(SafeLoader): | ||||
|  | ||||
| class EntryInvalidError(SentryIgnoredException): | ||||
|     """Error raised when an entry is invalid""" | ||||
|  | ||||
|     serializer_errors: Optional[dict] | ||||
|  | ||||
|     def __init__(self, *args: object, serializer_errors: Optional[dict] = None) -> None: | ||||
|         super().__init__(*args) | ||||
|         self.serializer_errors = serializer_errors | ||||
|  | ||||
| @ -1,11 +1,24 @@ | ||||
| """Flow exporter""" | ||||
| from typing import Iterator | ||||
| """Blueprint exporter""" | ||||
| from typing import Iterable | ||||
| from uuid import UUID | ||||
|  | ||||
| from django.db.models import Q | ||||
| from django.apps import apps | ||||
| from django.contrib.auth import get_user_model | ||||
| from django.db.models import Model, Q, QuerySet | ||||
| from django.utils.timezone import now | ||||
| from django.utils.translation import gettext as _ | ||||
| from guardian.shortcuts import get_anonymous_user | ||||
| from yaml import dump | ||||
|  | ||||
| from authentik.blueprints.v1.common import Blueprint, BlueprintDumper, BlueprintEntry | ||||
| from authentik.blueprints.v1.common import ( | ||||
|     Blueprint, | ||||
|     BlueprintDumper, | ||||
|     BlueprintEntry, | ||||
|     BlueprintMetadata, | ||||
| ) | ||||
| from authentik.blueprints.v1.importer import is_model_allowed | ||||
| from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_GENERATED | ||||
| from authentik.events.models import Event | ||||
| from authentik.flows.models import Flow, FlowStageBinding, Stage | ||||
| from authentik.policies.models import Policy, PolicyBinding | ||||
| from authentik.stages.prompt.models import PromptStage | ||||
| @ -14,6 +27,55 @@ from authentik.stages.prompt.models import PromptStage | ||||
| class Exporter: | ||||
|     """Export flow with attached stages into yaml""" | ||||
|  | ||||
|     excluded_models: list[type[Model]] = [] | ||||
|  | ||||
|     def __init__(self): | ||||
|         self.excluded_models = [ | ||||
|             Event, | ||||
|         ] | ||||
|  | ||||
|     def get_entries(self) -> Iterable[BlueprintEntry]: | ||||
|         """Get blueprint entries""" | ||||
|         for model in apps.get_models(): | ||||
|             if not is_model_allowed(model): | ||||
|                 continue | ||||
|             if model in self.excluded_models: | ||||
|                 continue | ||||
|             for obj in self.get_model_instances(model): | ||||
|                 yield BlueprintEntry.from_model(obj) | ||||
|  | ||||
|     def get_model_instances(self, model: type[Model]) -> QuerySet: | ||||
|         """Return a queryset for `model`. Can be used to filter some | ||||
|         objects on some models""" | ||||
|         if model == get_user_model(): | ||||
|             return model.objects.exclude(pk=get_anonymous_user().pk) | ||||
|         return model.objects.all() | ||||
|  | ||||
|     def _pre_export(self, blueprint: Blueprint): | ||||
|         """Hook to run anything pre-export""" | ||||
|  | ||||
|     def export(self) -> Blueprint: | ||||
|         """Create a list of all objects and create a blueprint""" | ||||
|         blueprint = Blueprint() | ||||
|         self._pre_export(blueprint) | ||||
|         blueprint.metadata = BlueprintMetadata( | ||||
|             name=_("authentik Export - %(date)s" % {"date": str(now())}), | ||||
|             labels={ | ||||
|                 LABEL_AUTHENTIK_GENERATED: "true", | ||||
|             }, | ||||
|         ) | ||||
|         blueprint.entries = list(self.get_entries()) | ||||
|         return blueprint | ||||
|  | ||||
|     def export_to_string(self) -> str: | ||||
|         """Call export and convert it to yaml""" | ||||
|         blueprint = self.export() | ||||
|         return dump(blueprint, Dumper=BlueprintDumper) | ||||
|  | ||||
|  | ||||
| class FlowExporter(Exporter): | ||||
|     """Exporter customised to only return objects related to `flow`""" | ||||
|  | ||||
|     flow: Flow | ||||
|     with_policies: bool | ||||
|     with_stage_prompts: bool | ||||
| @ -21,17 +83,20 @@ class Exporter: | ||||
|     pbm_uuids: list[UUID] | ||||
|  | ||||
|     def __init__(self, flow: Flow): | ||||
|         super().__init__() | ||||
|         self.flow = flow | ||||
|         self.with_policies = True | ||||
|         self.with_stage_prompts = True | ||||
|  | ||||
|     def _prepare_pbm(self): | ||||
|     def _pre_export(self, blueprint: Blueprint): | ||||
|         if not self.with_policies: | ||||
|             return | ||||
|         self.pbm_uuids = [self.flow.pbm_uuid] | ||||
|         self.pbm_uuids += FlowStageBinding.objects.filter(target=self.flow).values_list( | ||||
|             "pbm_uuid", flat=True | ||||
|         ) | ||||
|  | ||||
|     def walk_stages(self) -> Iterator[BlueprintEntry]: | ||||
|     def walk_stages(self) -> Iterable[BlueprintEntry]: | ||||
|         """Convert all stages attached to self.flow into BlueprintEntry objects""" | ||||
|         stages = Stage.objects.filter(flow=self.flow).select_related().select_subclasses() | ||||
|         for stage in stages: | ||||
| @ -39,13 +104,13 @@ class Exporter: | ||||
|                 pass | ||||
|             yield BlueprintEntry.from_model(stage, "name") | ||||
|  | ||||
|     def walk_stage_bindings(self) -> Iterator[BlueprintEntry]: | ||||
|     def walk_stage_bindings(self) -> Iterable[BlueprintEntry]: | ||||
|         """Convert all bindings attached to self.flow into BlueprintEntry objects""" | ||||
|         bindings = FlowStageBinding.objects.filter(target=self.flow).select_related() | ||||
|         for binding in bindings: | ||||
|             yield BlueprintEntry.from_model(binding, "target", "stage", "order") | ||||
|  | ||||
|     def walk_policies(self) -> Iterator[BlueprintEntry]: | ||||
|     def walk_policies(self) -> Iterable[BlueprintEntry]: | ||||
|         """Walk over all policies. This is done at the beginning of the export for stages that have | ||||
|         a direct foreign key to a policy.""" | ||||
|         # Special case for PromptStage as that has a direct M2M to policy, we have to ensure | ||||
| @ -56,37 +121,29 @@ class Exporter: | ||||
|         for policy in policies: | ||||
|             yield BlueprintEntry.from_model(policy) | ||||
|  | ||||
|     def walk_policy_bindings(self) -> Iterator[BlueprintEntry]: | ||||
|     def walk_policy_bindings(self) -> Iterable[BlueprintEntry]: | ||||
|         """Walk over all policybindings relative to us. This is run at the end of the export, as | ||||
|         we are sure all objects exist now.""" | ||||
|         bindings = PolicyBinding.objects.filter(target__in=self.pbm_uuids).select_related() | ||||
|         for binding in bindings: | ||||
|             yield BlueprintEntry.from_model(binding, "policy", "target", "order") | ||||
|  | ||||
|     def walk_stage_prompts(self) -> Iterator[BlueprintEntry]: | ||||
|     def walk_stage_prompts(self) -> Iterable[BlueprintEntry]: | ||||
|         """Walk over all prompts associated with any PromptStages""" | ||||
|         prompt_stages = PromptStage.objects.filter(flow=self.flow) | ||||
|         for stage in prompt_stages: | ||||
|             for prompt in stage.fields.all(): | ||||
|                 yield BlueprintEntry.from_model(prompt) | ||||
|  | ||||
|     def export(self) -> Blueprint: | ||||
|         """Create a list of all objects including the flow""" | ||||
|         if self.with_policies: | ||||
|             self._prepare_pbm() | ||||
|         bundle = Blueprint() | ||||
|         bundle.entries.append(BlueprintEntry.from_model(self.flow, "slug")) | ||||
|     def get_entries(self) -> Iterable[BlueprintEntry]: | ||||
|         entries = [] | ||||
|         entries.append(BlueprintEntry.from_model(self.flow, "slug")) | ||||
|         if self.with_stage_prompts: | ||||
|             bundle.entries.extend(self.walk_stage_prompts()) | ||||
|             entries.extend(self.walk_stage_prompts()) | ||||
|         if self.with_policies: | ||||
|             bundle.entries.extend(self.walk_policies()) | ||||
|         bundle.entries.extend(self.walk_stages()) | ||||
|         bundle.entries.extend(self.walk_stage_bindings()) | ||||
|             entries.extend(self.walk_policies()) | ||||
|         entries.extend(self.walk_stages()) | ||||
|         entries.extend(self.walk_stage_bindings()) | ||||
|         if self.with_policies: | ||||
|             bundle.entries.extend(self.walk_policy_bindings()) | ||||
|         return bundle | ||||
|  | ||||
|     def export_to_string(self) -> str: | ||||
|         """Call export and convert it to yaml""" | ||||
|         bundle = self.export() | ||||
|         return dump(bundle, Dumper=BlueprintDumper) | ||||
|             entries.extend(self.walk_policy_bindings()) | ||||
|         return entries | ||||
|  | ||||
| @ -3,10 +3,9 @@ from contextlib import contextmanager | ||||
| from copy import deepcopy | ||||
| from typing import Any, Optional | ||||
|  | ||||
| from dacite import from_dict | ||||
| from dacite.core import from_dict | ||||
| from dacite.exceptions import DaciteError | ||||
| from deepmerge import always_merger | ||||
| from django.apps import apps | ||||
| from django.db import transaction | ||||
| from django.db.models import Model | ||||
| from django.db.models.query_utils import Q | ||||
| @ -21,9 +20,11 @@ from yaml import load | ||||
| from authentik.blueprints.v1.common import ( | ||||
|     Blueprint, | ||||
|     BlueprintEntry, | ||||
|     BlueprintEntryState, | ||||
|     BlueprintLoader, | ||||
|     EntryInvalidError, | ||||
| ) | ||||
| from authentik.blueprints.v1.meta.registry import BaseMetaModel, registry | ||||
| from authentik.core.models import ( | ||||
|     AuthenticatedSession, | ||||
|     PropertyMapping, | ||||
| @ -58,7 +59,7 @@ def is_model_allowed(model: type[Model]) -> bool: | ||||
|         # Classes that have other dependencies | ||||
|         AuthenticatedSession, | ||||
|     ) | ||||
|     return model not in excluded_models | ||||
|     return model not in excluded_models and issubclass(model, (SerializerModel, BaseMetaModel)) | ||||
|  | ||||
|  | ||||
| @contextmanager | ||||
| @ -137,10 +138,20 @@ class Importer: | ||||
|     def _validate_single(self, entry: BlueprintEntry) -> BaseSerializer: | ||||
|         """Validate a single entry""" | ||||
|         model_app_label, model_name = entry.model.split(".") | ||||
|         model: type[SerializerModel] = apps.get_model(model_app_label, model_name) | ||||
|         model: type[SerializerModel] = registry.get_model(model_app_label, model_name) | ||||
|         # Don't use isinstance since we don't want to check for inheritance | ||||
|         if not is_model_allowed(model): | ||||
|             raise EntryInvalidError(f"Model {model} not allowed") | ||||
|         if issubclass(model, BaseMetaModel): | ||||
|             serializer_class: type[Serializer] = model.serializer() | ||||
|             serializer = serializer_class(data=entry.get_attrs(self.__import)) | ||||
|             try: | ||||
|                 serializer.is_valid(raise_exception=True) | ||||
|             except ValidationError as exc: | ||||
|                 raise EntryInvalidError( | ||||
|                     f"Serializer errors {serializer.errors}", serializer_errors=serializer.errors | ||||
|                 ) from exc | ||||
|             return serializer | ||||
|         if entry.identifiers == {}: | ||||
|             raise EntryInvalidError("No identifiers") | ||||
|  | ||||
| @ -157,7 +168,7 @@ class Importer: | ||||
|         existing_models = model.objects.filter(self.__query_from_identifier(updated_identifiers)) | ||||
|  | ||||
|         serializer_kwargs = {} | ||||
|         if existing_models.exists(): | ||||
|         if not isinstance(model(), BaseMetaModel) and existing_models.exists(): | ||||
|             model_instance = existing_models.first() | ||||
|             self.logger.debug( | ||||
|                 "initialise serializer with instance", | ||||
| @ -168,7 +179,9 @@ class Importer: | ||||
|             serializer_kwargs["instance"] = model_instance | ||||
|             serializer_kwargs["partial"] = True | ||||
|         else: | ||||
|             self.logger.debug("initialise new instance", model=model, **updated_identifiers) | ||||
|             self.logger.debug( | ||||
|                 "initialised new serializer instance", model=model, **updated_identifiers | ||||
|             ) | ||||
|             model_instance = model() | ||||
|             # pk needs to be set on the model instance otherwise a new one will be generated | ||||
|             if "pk" in updated_identifiers: | ||||
| @ -182,7 +195,9 @@ class Importer: | ||||
|         try: | ||||
|             serializer.is_valid(raise_exception=True) | ||||
|         except ValidationError as exc: | ||||
|             raise EntryInvalidError(f"Serializer errors {serializer.errors}") from exc | ||||
|             raise EntryInvalidError( | ||||
|                 f"Serializer errors {serializer.errors}", serializer_errors=serializer.errors | ||||
|             ) from exc | ||||
|         return serializer | ||||
|  | ||||
|     def apply(self) -> bool: | ||||
| @ -204,7 +219,7 @@ class Importer: | ||||
|         for entry in self.__import.entries: | ||||
|             model_app_label, model_name = entry.model.split(".") | ||||
|             try: | ||||
|                 model: SerializerModel = apps.get_model(model_app_label, model_name) | ||||
|                 model: type[SerializerModel] = registry.get_model(model_app_label, model_name) | ||||
|             except LookupError: | ||||
|                 self.logger.warning( | ||||
|                     "app or model does not exist", app=model_app_label, model=model_name | ||||
| @ -214,14 +229,14 @@ class Importer: | ||||
|             try: | ||||
|                 serializer = self._validate_single(entry) | ||||
|             except EntryInvalidError as exc: | ||||
|                 self.logger.warning("entry invalid", entry=entry, error=exc) | ||||
|                 self.logger.warning(f"entry invalid: {exc}", entry=entry, error=exc) | ||||
|                 return False | ||||
|  | ||||
|             model = serializer.save() | ||||
|             if "pk" in entry.identifiers: | ||||
|                 self.__pk_map[entry.identifiers["pk"]] = model.pk | ||||
|             entry._instance = model | ||||
|             self.logger.debug("updated model", model=model, pk=model.pk) | ||||
|             entry._state = BlueprintEntryState(model) | ||||
|             self.logger.debug("updated model", model=model) | ||||
|         return True | ||||
|  | ||||
|     def validate(self) -> tuple[bool, list[EventDict]]: | ||||
| @ -230,7 +245,7 @@ class Importer: | ||||
|         self.logger.debug("Starting blueprint import validation") | ||||
|         orig_import = deepcopy(self.__import) | ||||
|         if self.__import.version != 1: | ||||
|             self.logger.warning("Invalid bundle version") | ||||
|             self.logger.warning("Invalid blueprint version") | ||||
|             return False, [] | ||||
|         with ( | ||||
|             transaction_rollback(), | ||||
| @ -238,8 +253,8 @@ class Importer: | ||||
|         ): | ||||
|             successful = self._apply_models() | ||||
|             if not successful: | ||||
|                 self.logger.debug("blueprint validation failed") | ||||
|                 self.logger.debug("Blueprint validation failed") | ||||
|         for log in logs: | ||||
|             self.logger.debug(**log) | ||||
|             getattr(self.logger, log.get("log_level"))(**log) | ||||
|         self.__import = orig_import | ||||
|         return successful, logs | ||||
|  | ||||
| @ -2,3 +2,4 @@ | ||||
|  | ||||
| LABEL_AUTHENTIK_SYSTEM = "blueprints.goauthentik.io/system" | ||||
| LABEL_AUTHENTIK_INSTANTIATE = "blueprints.goauthentik.io/instantiate" | ||||
| LABEL_AUTHENTIK_GENERATED = "blueprints.goauthentik.io/generated" | ||||
|  | ||||
							
								
								
									
										0
									
								
								authentik/blueprints/v1/meta/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								authentik/blueprints/v1/meta/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										60
									
								
								authentik/blueprints/v1/meta/apply_blueprint.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								authentik/blueprints/v1/meta/apply_blueprint.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,60 @@ | ||||
| """Apply Blueprint meta model""" | ||||
| from typing import TYPE_CHECKING | ||||
|  | ||||
| from rest_framework.exceptions import ValidationError | ||||
| from rest_framework.fields import BooleanField, JSONField | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.blueprints.v1.meta.registry import BaseMetaModel, MetaResult, registry | ||||
| from authentik.core.api.utils import PassiveSerializer, is_dict | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from authentik.blueprints.models import BlueprintInstance | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| class ApplyBlueprintMetaSerializer(PassiveSerializer): | ||||
|     """Serializer for meta apply blueprint model""" | ||||
|  | ||||
|     identifiers = JSONField(validators=[is_dict]) | ||||
|     required = BooleanField(default=True) | ||||
|  | ||||
|     # We cannot override `instance` as that will confuse rest_framework | ||||
|     # and make it attempt to update the instance | ||||
|     blueprint_instance: "BlueprintInstance" | ||||
|  | ||||
|     def validate(self, attrs): | ||||
|         from authentik.blueprints.models import BlueprintInstance | ||||
|  | ||||
|         identifiers = attrs["identifiers"] | ||||
|         required = attrs["required"] | ||||
|         instance = BlueprintInstance.objects.filter(**identifiers).first() | ||||
|         if not instance and required: | ||||
|             raise ValidationError("Required blueprint does not exist") | ||||
|         self.blueprint_instance = instance | ||||
|         return super().validate(attrs) | ||||
|  | ||||
|     def create(self, validated_data: dict) -> MetaResult: | ||||
|         from authentik.blueprints.v1.tasks import apply_blueprint | ||||
|  | ||||
|         if not self.blueprint_instance: | ||||
|             LOGGER.info("Blueprint does not exist, but not required") | ||||
|             return MetaResult() | ||||
|         LOGGER.debug("Applying blueprint from meta model", blueprint=self.blueprint_instance) | ||||
|         # pylint: disable=no-value-for-parameter | ||||
|         apply_blueprint(str(self.blueprint_instance.pk)) | ||||
|         return MetaResult() | ||||
|  | ||||
|  | ||||
| @registry.register("metaapplyblueprint") | ||||
| class MetaApplyBlueprint(BaseMetaModel): | ||||
|     """Meta model to apply another blueprint""" | ||||
|  | ||||
|     @staticmethod | ||||
|     def serializer() -> ApplyBlueprintMetaSerializer: | ||||
|         return ApplyBlueprintMetaSerializer | ||||
|  | ||||
|     class Meta: | ||||
|  | ||||
|         abstract = True | ||||
							
								
								
									
										61
									
								
								authentik/blueprints/v1/meta/registry.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								authentik/blueprints/v1/meta/registry.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,61 @@ | ||||
| """Base models""" | ||||
| from django.apps import apps | ||||
| from django.db.models import Model | ||||
| from rest_framework.serializers import Serializer | ||||
|  | ||||
|  | ||||
| class BaseMetaModel(Model): | ||||
|     """Base models""" | ||||
|  | ||||
|     @staticmethod | ||||
|     def serializer() -> Serializer: | ||||
|         """Serializer similar to SerializerModel, but as a static method since | ||||
|         this is an abstract model""" | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     class Meta: | ||||
|  | ||||
|         abstract = True | ||||
|  | ||||
|  | ||||
| class MetaResult: | ||||
|     """Result returned by Meta Models' serializers. Empty class but we can't return none as | ||||
|     the framework doesn't allow that""" | ||||
|  | ||||
|  | ||||
| class MetaModelRegistry: | ||||
|     """Registry for pseudo meta models""" | ||||
|  | ||||
|     models: dict[str, BaseMetaModel] | ||||
|     virtual_prefix: str | ||||
|  | ||||
|     def __init__(self, prefix: str) -> None: | ||||
|         self.models = {} | ||||
|         self.virtual_prefix = prefix | ||||
|  | ||||
|     def register(self, model_id: str): | ||||
|         """Register model class under `model_id`""" | ||||
|  | ||||
|         def inner_wrapper(cls): | ||||
|             self.models[model_id] = cls | ||||
|             return cls | ||||
|  | ||||
|         return inner_wrapper | ||||
|  | ||||
|     def get_models(self): | ||||
|         """Wrapper for django's `get_models` to list all models""" | ||||
|         models = apps.get_models() | ||||
|         for _, value in self.models.items(): | ||||
|             models.append(value) | ||||
|         return models | ||||
|  | ||||
|     def get_model(self, app_label: str, model_id: str) -> type[Model]: | ||||
|         """Get model checks if any virtual models are registered, and falls back | ||||
|         to actual django models""" | ||||
|         if app_label.lower() == self.virtual_prefix: | ||||
|             if model_id.lower() in self.models: | ||||
|                 return self.models[model_id] | ||||
|         return apps.get_model(app_label, model_id) | ||||
|  | ||||
|  | ||||
| registry = MetaModelRegistry("authentik_blueprints") | ||||
| @ -4,15 +4,21 @@ from hashlib import sha512 | ||||
| from pathlib import Path | ||||
| from typing import Optional | ||||
|  | ||||
| from dacite import from_dict | ||||
| from dacite.core import from_dict | ||||
| from django.db import DatabaseError, InternalError, ProgrammingError | ||||
| from django.utils.text import slugify | ||||
| from django.utils.timezone import now | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from structlog.stdlib import get_logger | ||||
| from yaml import load | ||||
| from yaml.error import YAMLError | ||||
|  | ||||
| from authentik.blueprints.models import BlueprintInstance, BlueprintInstanceStatus | ||||
| from authentik.blueprints.v1.common import BlueprintLoader, BlueprintMetadata | ||||
| from authentik.blueprints.models import ( | ||||
|     BlueprintInstance, | ||||
|     BlueprintInstanceStatus, | ||||
|     BlueprintRetrievalFailed, | ||||
| ) | ||||
| from authentik.blueprints.v1.common import BlueprintLoader, BlueprintMetadata, EntryInvalidError | ||||
| from authentik.blueprints.v1.importer import Importer | ||||
| from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_INSTANTIATE | ||||
| from authentik.events.monitored_tasks import ( | ||||
| @ -21,9 +27,12 @@ from authentik.events.monitored_tasks import ( | ||||
|     TaskResultStatus, | ||||
|     prefill_task, | ||||
| ) | ||||
| from authentik.events.utils import sanitize_dict | ||||
| from authentik.lib.config import CONFIG | ||||
| from authentik.root.celery import CELERY_APP | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class BlueprintFile: | ||||
| @ -39,27 +48,45 @@ class BlueprintFile: | ||||
| @CELERY_APP.task( | ||||
|     throws=(DatabaseError, ProgrammingError, InternalError), | ||||
| ) | ||||
| def blueprints_find_dict(): | ||||
|     """Find blueprints as `blueprints_find` does, but return a safe dict""" | ||||
|     blueprints = [] | ||||
|     for blueprint in blueprints_find(): | ||||
|         blueprints.append(sanitize_dict(asdict(blueprint))) | ||||
|     return blueprints | ||||
|  | ||||
|  | ||||
| def blueprints_find(): | ||||
|     """Find blueprints and return valid ones""" | ||||
|     blueprints = [] | ||||
|     root = Path(CONFIG.y("blueprints_dir")) | ||||
|     for file in root.glob("**/*.yaml"): | ||||
|         path = Path(file) | ||||
|         LOGGER.debug("found blueprint", path=str(path)) | ||||
|         with open(path, "r", encoding="utf-8") as blueprint_file: | ||||
|             try: | ||||
|                 raw_blueprint = load(blueprint_file.read(), BlueprintLoader) | ||||
|             except YAMLError: | ||||
|             except YAMLError as exc: | ||||
|                 raw_blueprint = None | ||||
|                 LOGGER.warning("failed to parse blueprint", exc=exc, path=str(path)) | ||||
|             if not raw_blueprint: | ||||
|                 continue | ||||
|             metadata = raw_blueprint.get("metadata", None) | ||||
|             version = raw_blueprint.get("version", 1) | ||||
|             if version != 1: | ||||
|                 LOGGER.warning("invalid blueprint version", version=version, path=str(path)) | ||||
|                 continue | ||||
|         file_hash = sha512(path.read_bytes()).hexdigest() | ||||
|         blueprint = BlueprintFile(path.relative_to(root), version, file_hash, path.stat().st_mtime) | ||||
|         blueprint = BlueprintFile( | ||||
|             str(path.relative_to(root)), version, file_hash, int(path.stat().st_mtime) | ||||
|         ) | ||||
|         blueprint.meta = from_dict(BlueprintMetadata, metadata) if metadata else None | ||||
|         blueprints.append(blueprint) | ||||
|         LOGGER.info( | ||||
|             "parsed & loaded blueprint", | ||||
|             hash=file_hash, | ||||
|             path=str(path), | ||||
|         ) | ||||
|     return blueprints | ||||
|  | ||||
|  | ||||
| @ -101,9 +128,7 @@ def check_blueprint_v1_file(blueprint: BlueprintFile): | ||||
|         ) | ||||
|         instance.save() | ||||
|     if instance.last_applied_hash != blueprint.hash: | ||||
|         instance.metadata = asdict(blueprint.meta) if blueprint.meta else {} | ||||
|         instance.save() | ||||
|         apply_blueprint.delay(instance.pk.hex) | ||||
|         apply_blueprint.delay(str(instance.pk)) | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task( | ||||
| @ -112,16 +137,18 @@ def check_blueprint_v1_file(blueprint: BlueprintFile): | ||||
| ) | ||||
| def apply_blueprint(self: MonitoredTask, instance_pk: str): | ||||
|     """Apply single blueprint""" | ||||
|     self.set_uid(instance_pk) | ||||
|     self.save_on_success = False | ||||
|     instance: Optional[BlueprintInstance] = None | ||||
|     try: | ||||
|         instance: BlueprintInstance = BlueprintInstance.objects.filter(pk=instance_pk).first() | ||||
|         self.set_uid(slugify(instance.name)) | ||||
|         if not instance or not instance.enabled: | ||||
|             return | ||||
|         full_path = Path(CONFIG.y("blueprints_dir")).joinpath(Path(instance.path)) | ||||
|         file_hash = sha512(full_path.read_bytes()).hexdigest() | ||||
|         with open(full_path, "r", encoding="utf-8") as blueprint_file: | ||||
|             importer = Importer(blueprint_file.read(), instance.context) | ||||
|         blueprint_content = instance.retrieve() | ||||
|         file_hash = sha512(blueprint_content.encode()).hexdigest() | ||||
|         importer = Importer(blueprint_content, instance.context) | ||||
|         if importer.blueprint.metadata: | ||||
|             instance.metadata = asdict(importer.blueprint.metadata) | ||||
|         valid, logs = importer.validate() | ||||
|         if not valid: | ||||
|             instance.status = BlueprintInstanceStatus.ERROR | ||||
| @ -137,9 +164,18 @@ def apply_blueprint(self: MonitoredTask, instance_pk: str): | ||||
|         instance.status = BlueprintInstanceStatus.SUCCESSFUL | ||||
|         instance.last_applied_hash = file_hash | ||||
|         instance.last_applied = now() | ||||
|         instance.save() | ||||
|         self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL)) | ||||
|     except (DatabaseError, ProgrammingError, InternalError, IOError) as exc: | ||||
|         instance.status = BlueprintInstanceStatus.ERROR | ||||
|         instance.save() | ||||
|     except ( | ||||
|         DatabaseError, | ||||
|         ProgrammingError, | ||||
|         InternalError, | ||||
|         IOError, | ||||
|         BlueprintRetrievalFailed, | ||||
|         EntryInvalidError, | ||||
|     ) as exc: | ||||
|         if instance: | ||||
|             instance.status = BlueprintInstanceStatus.ERROR | ||||
|         self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc)) | ||||
|     finally: | ||||
|         if instance: | ||||
|             instance.save() | ||||
|  | ||||
| @ -50,7 +50,9 @@ class ApplicationSerializer(ModelSerializer): | ||||
|  | ||||
|     def get_launch_url(self, app: Application) -> Optional[str]: | ||||
|         """Allow formatting of launch URL""" | ||||
|         user = self.context["request"].user | ||||
|         user = None | ||||
|         if "request" in self.context: | ||||
|             user = self.context["request"].user | ||||
|         return app.get_launch_url(user) | ||||
|  | ||||
|     class Meta: | ||||
|  | ||||
| @ -17,7 +17,7 @@ from authentik.api.decorators import permission_required | ||||
| from authentik.blueprints.api import ManagedSerializer | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.core.api.utils import MetaNameSerializer, PassiveSerializer, TypeCreateSerializer | ||||
| from authentik.core.expression import PropertyMappingEvaluator | ||||
| from authentik.core.expression.evaluator import PropertyMappingEvaluator | ||||
| from authentik.core.models import PropertyMapping | ||||
| from authentik.lib.utils.reflection import all_subclasses | ||||
| from authentik.policies.api.exec import PolicyTestSerializer | ||||
| @ -41,7 +41,9 @@ class PropertyMappingSerializer(ManagedSerializer, ModelSerializer, MetaNameSeri | ||||
|  | ||||
|     def validate_expression(self, expression: str) -> str: | ||||
|         """Test Syntax""" | ||||
|         evaluator = PropertyMappingEvaluator() | ||||
|         evaluator = PropertyMappingEvaluator( | ||||
|             self.instance, | ||||
|         ) | ||||
|         evaluator.validate(expression) | ||||
|         return expression | ||||
|  | ||||
|  | ||||
| @ -1,7 +1,7 @@ | ||||
| """authentik core app config""" | ||||
| from django.conf import settings | ||||
|  | ||||
| from authentik.blueprints.manager import ManagedAppConfig | ||||
| from authentik.blueprints.apps import ManagedAppConfig | ||||
|  | ||||
|  | ||||
| class AuthentikCoreConfig(ManagedAppConfig): | ||||
|  | ||||
							
								
								
									
										0
									
								
								authentik/core/expression/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								authentik/core/expression/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @ -2,28 +2,33 @@ | ||||
| from traceback import format_tb | ||||
| from typing import Optional | ||||
| 
 | ||||
| from django.db.models import Model | ||||
| from django.http import HttpRequest | ||||
| from guardian.utils import get_anonymous_user | ||||
| 
 | ||||
| from authentik.core.models import PropertyMapping, User | ||||
| from authentik.core.models import User | ||||
| from authentik.events.models import Event, EventAction | ||||
| from authentik.lib.expression.evaluator import BaseEvaluator | ||||
| from authentik.policies.types import PolicyRequest | ||||
| 
 | ||||
| 
 | ||||
| class PropertyMappingEvaluator(BaseEvaluator): | ||||
|     """Custom Evalautor that adds some different context variables.""" | ||||
|     """Custom Evaluator that adds some different context variables.""" | ||||
| 
 | ||||
|     def set_context( | ||||
|     def __init__( | ||||
|         self, | ||||
|         user: Optional[User], | ||||
|         request: Optional[HttpRequest], | ||||
|         mapping: PropertyMapping, | ||||
|         model: Model, | ||||
|         user: Optional[User] = None, | ||||
|         request: Optional[HttpRequest] = None, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         """Update context with context from PropertyMapping's evaluate""" | ||||
|         if hasattr(model, "name"): | ||||
|             _filename = model.name | ||||
|         else: | ||||
|             _filename = str(model) | ||||
|         super().__init__(filename=_filename) | ||||
|         req = PolicyRequest(user=get_anonymous_user()) | ||||
|         req.obj = mapping | ||||
|         req.obj = model | ||||
|         if user: | ||||
|             req.user = user | ||||
|             self._context["user"] = user | ||||
| @ -7,9 +7,9 @@ from django.core.management.base import BaseCommand | ||||
| from django.db.models import Model | ||||
| from django.db.models.signals import post_save, pre_delete | ||||
|  | ||||
| from authentik import __version__ | ||||
| from authentik import get_full_version | ||||
| from authentik.core.models import User | ||||
| from authentik.events.middleware import IGNORED_MODELS | ||||
| from authentik.events.middleware import should_log_model | ||||
| from authentik.events.models import Event, EventAction | ||||
| from authentik.events.utils import model_to_dict | ||||
|  | ||||
| @ -18,7 +18,7 @@ BANNER_TEXT = """### authentik shell ({authentik}) | ||||
|     node=platform.node(), | ||||
|     python=platform.python_version(), | ||||
|     arch=platform.machine(), | ||||
|     authentik=__version__, | ||||
|     authentik=get_full_version(), | ||||
| ) | ||||
|  | ||||
|  | ||||
| @ -50,7 +50,7 @@ class Command(BaseCommand): | ||||
|     # pylint: disable=unused-argument | ||||
|     def post_save_handler(sender, instance: Model, created: bool, **_): | ||||
|         """Signal handler for all object's post_save""" | ||||
|         if isinstance(instance, IGNORED_MODELS): | ||||
|         if not should_log_model(instance): | ||||
|             return | ||||
|  | ||||
|         action = EventAction.MODEL_CREATED if created else EventAction.MODEL_UPDATED | ||||
| @ -66,7 +66,7 @@ class Command(BaseCommand): | ||||
|     # pylint: disable=unused-argument | ||||
|     def pre_delete_handler(sender, instance: Model, **_): | ||||
|         """Signal handler for all object's pre_delete""" | ||||
|         if isinstance(instance, IGNORED_MODELS):  # pragma: no cover | ||||
|         if not should_log_model(instance):  # pragma: no cover | ||||
|             return | ||||
|  | ||||
|         Event.new(EventAction.MODEL_DELETED, model=model_to_dict(instance)).set_user( | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| """authentik admin Middleware to impersonate users""" | ||||
| from contextvars import ContextVar | ||||
| from typing import Callable | ||||
| from typing import Callable, Optional | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from django.http import HttpRequest, HttpResponse | ||||
| @ -13,9 +13,9 @@ RESPONSE_HEADER_ID = "X-authentik-id" | ||||
| KEY_AUTH_VIA = "auth_via" | ||||
| KEY_USER = "user" | ||||
|  | ||||
| CTX_REQUEST_ID = ContextVar(STRUCTLOG_KEY_PREFIX + "request_id", default=None) | ||||
| CTX_HOST = ContextVar(STRUCTLOG_KEY_PREFIX + "host", default=None) | ||||
| CTX_AUTH_VIA = ContextVar(STRUCTLOG_KEY_PREFIX + KEY_AUTH_VIA, default=None) | ||||
| CTX_REQUEST_ID = ContextVar[Optional[str]](STRUCTLOG_KEY_PREFIX + "request_id", default=None) | ||||
| CTX_HOST = ContextVar[Optional[str]](STRUCTLOG_KEY_PREFIX + "host", default=None) | ||||
| CTX_AUTH_VIA = ContextVar[Optional[str]](STRUCTLOG_KEY_PREFIX + KEY_AUTH_VIA, default=None) | ||||
|  | ||||
|  | ||||
| class ImpersonateMiddleware: | ||||
|  | ||||
| @ -617,10 +617,9 @@ class PropertyMapping(SerializerModel, ManagedModel): | ||||
|  | ||||
|     def evaluate(self, user: Optional[User], request: Optional[HttpRequest], **kwargs) -> Any: | ||||
|         """Evaluate `self.expression` using `**kwargs` as Context.""" | ||||
|         from authentik.core.expression import PropertyMappingEvaluator | ||||
|         from authentik.core.expression.evaluator import PropertyMappingEvaluator | ||||
|  | ||||
|         evaluator = PropertyMappingEvaluator() | ||||
|         evaluator.set_context(user, request, self, **kwargs) | ||||
|         evaluator = PropertyMappingEvaluator(self, user, request, **kwargs) | ||||
|         try: | ||||
|             return evaluator.evaluate(self.expression) | ||||
|         except Exception as exc: | ||||
|  | ||||
| @ -1,31 +0,0 @@ | ||||
| {% extends 'base/skeleton.html' %} | ||||
|  | ||||
| {% load i18n %} | ||||
|  | ||||
| {% block head %} | ||||
| {{ block.super }} | ||||
| <style> | ||||
|     .pf-c-empty-state { | ||||
|         height: 100vh; | ||||
|     } | ||||
| </style> | ||||
| {% endblock %} | ||||
|  | ||||
| {% block body %} | ||||
| <section class="ak-static-page pf-c-page__main-section pf-m-no-padding-mobile pf-m-xl"> | ||||
|     <div class="pf-c-empty-state"> | ||||
|         <div class="pf-c-empty-state__content"> | ||||
|             <i class="fas fa-exclamation-circle pf-c-empty-state__icon" aria-hidden="true"></i> | ||||
|             <h1 class="pf-c-title pf-m-lg"> | ||||
|                 {% trans title %} | ||||
|             </h1> | ||||
|             <div class="pf-c-empty-state__body"> | ||||
|                 {% if message %} | ||||
|                 <h3>{% trans message %}</h3> | ||||
|                 {% endif %} | ||||
|             </div> | ||||
|             <a href="/" class="pf-c-button pf-m-primary pf-m-block">{% trans 'Go to home' %}</a> | ||||
|         </div> | ||||
|     </div> | ||||
| </section> | ||||
| {% endblock %} | ||||
| @ -4,14 +4,16 @@ | ||||
| {% load i18n %} | ||||
|  | ||||
| {% block head %} | ||||
| <script src="{% static 'dist/admin/AdminInterface.js' %}" type="module"></script> | ||||
| <script src="{% static 'dist/admin/AdminInterface.js' %}?version={{ version }}" type="module"></script> | ||||
| <meta name="theme-color" content="#18191a" media="(prefers-color-scheme: dark)"> | ||||
| <meta name="theme-color" content="#ffffff" media="(prefers-color-scheme: light)"> | ||||
| <link rel="icon" href="{{ tenant.branding_favicon }}"> | ||||
| <link rel="shortcut icon" href="{{ tenant.branding_favicon }}"> | ||||
| <script> | ||||
| window.authentik = {}; | ||||
| window.authentik.locale = "{{ tenant.default_locale }}"; | ||||
| window.authentik.config = JSON.parse('{{ config_json|safe }}'); | ||||
| window.authentik.tenant = JSON.parse('{{ tenant_json|safe }}'); | ||||
| window.authentik.config = JSON.parse('{{ config_json|escapejs }}'); | ||||
| window.authentik.tenant = JSON.parse('{{ tenant_json|escapejs }}'); | ||||
| </script> | ||||
| {% endblock %} | ||||
|  | ||||
|  | ||||
							
								
								
									
										21
									
								
								authentik/core/templates/if/error.html
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								authentik/core/templates/if/error.html
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,21 @@ | ||||
| {% extends 'login/base_full.html' %} | ||||
|  | ||||
| {% load static %} | ||||
| {% load i18n %} | ||||
|  | ||||
| {% block title %} | ||||
| {% trans 'End session' %} - {{ tenant.branding_title }} | ||||
| {% endblock %} | ||||
|  | ||||
| {% block card_title %} | ||||
| {% trans title %} | ||||
| {% endblock %} | ||||
|  | ||||
| {% block card %} | ||||
| <form method="POST" class="pf-c-form"> | ||||
|     <p>{% trans message %}</p> | ||||
|     <a id="ak-back-home" href="{% url 'authentik_core:root-redirect' %}" class="pf-c-button pf-m-primary"> | ||||
|         {% trans 'Go home' %} | ||||
|     </a> | ||||
| </form> | ||||
| {% endblock %} | ||||
| @ -6,14 +6,16 @@ | ||||
| {% block head_before %} | ||||
| {{ block.super }} | ||||
| <link rel="prefetch" href="{{ flow.background_url }}" /> | ||||
| <link rel="icon" href="{{ tenant.branding_favicon }}"> | ||||
| <link rel="shortcut icon" href="{{ tenant.branding_favicon }}"> | ||||
| {% if flow.compatibility_mode and not inspector %} | ||||
| <script>ShadyDOM = { force: !navigator.webdriver };</script> | ||||
| {% endif %} | ||||
| <script> | ||||
| window.authentik = {}; | ||||
| window.authentik.locale = "{{ tenant.default_locale }}"; | ||||
| window.authentik.config = JSON.parse( '{{ config_json|safe }}'); | ||||
| window.authentik.tenant = JSON.parse('{{ tenant_json|safe }}'); | ||||
| window.authentik.config = JSON.parse('{{ config_json|escapejs }}'); | ||||
| window.authentik.tenant = JSON.parse('{{ tenant_json|escapejs }}'); | ||||
| window.authentik.flow = { | ||||
|     "layout": "{{ flow.layout }}", | ||||
| }; | ||||
| @ -21,7 +23,7 @@ window.authentik.flow = { | ||||
| {% endblock %} | ||||
|  | ||||
| {% block head %} | ||||
| <script src="{% static 'dist/flow/FlowInterface.js' %}" type="module"></script> | ||||
| <script src="{% static 'dist/flow/FlowInterface.js' %}?version={{ version }}" type="module"></script> | ||||
| <style> | ||||
| :root { | ||||
|     --ak-flow-background: url("{{ flow.background_url }}"); | ||||
|  | ||||
| @ -4,14 +4,16 @@ | ||||
| {% load i18n %} | ||||
|  | ||||
| {% block head %} | ||||
| <script src="{% static 'dist/user/UserInterface.js' %}" type="module"></script> | ||||
| <script src="{% static 'dist/user/UserInterface.js' %}?version={{ version }}" type="module"></script> | ||||
| <meta name="theme-color" content="#151515" media="(prefers-color-scheme: light)"> | ||||
| <meta name="theme-color" content="#151515" media="(prefers-color-scheme: dark)"> | ||||
| <link rel="icon" href="{{ tenant.branding_favicon }}"> | ||||
| <link rel="shortcut icon" href="{{ tenant.branding_favicon }}"> | ||||
| <script> | ||||
| window.authentik = {}; | ||||
| window.authentik.locale = "{{ tenant.default_locale }}"; | ||||
| window.authentik.config = JSON.parse('{{ config_json|safe }}'); | ||||
| window.authentik.tenant = JSON.parse('{{ tenant_json|safe }}'); | ||||
| window.authentik.config = JSON.parse('{{ config_json|escapejs }}'); | ||||
| window.authentik.tenant = JSON.parse('{{ tenant_json|escapejs }}'); | ||||
| </script> | ||||
| {% endblock %} | ||||
|  | ||||
|  | ||||
| @ -5,8 +5,7 @@ from django.urls import reverse | ||||
| from rest_framework.test import APITestCase | ||||
|  | ||||
| from authentik.core.models import Application | ||||
| from authentik.core.tests.utils import create_test_admin_user | ||||
| from authentik.flows.models import Flow | ||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_flow | ||||
| from authentik.policies.dummy.models import DummyPolicy | ||||
| from authentik.policies.models import PolicyBinding | ||||
| from authentik.providers.oauth2.models import OAuth2Provider | ||||
| @ -20,10 +19,7 @@ class TestApplicationsAPI(APITestCase): | ||||
|         self.provider = OAuth2Provider.objects.create( | ||||
|             name="test", | ||||
|             redirect_uris="http://some-other-domain", | ||||
|             authorization_flow=Flow.objects.create( | ||||
|                 name="test", | ||||
|                 slug="test", | ||||
|             ), | ||||
|             authorization_flow=create_test_flow(), | ||||
|         ) | ||||
|         self.allowed = Application.objects.create( | ||||
|             name="allowed", | ||||
|  | ||||
| @ -4,8 +4,7 @@ from unittest.mock import MagicMock, patch | ||||
| from django.urls import reverse | ||||
|  | ||||
| from authentik.core.models import Application | ||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_tenant | ||||
| from authentik.flows.models import Flow, FlowDesignation | ||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_flow, create_test_tenant | ||||
| from authentik.flows.tests import FlowTestCase | ||||
| from authentik.tenants.models import Tenant | ||||
|  | ||||
| @ -21,11 +20,7 @@ class TestApplicationsViews(FlowTestCase): | ||||
|  | ||||
|     def test_check_redirect(self): | ||||
|         """Test redirect""" | ||||
|         empty_flow = Flow.objects.create( | ||||
|             name="foo", | ||||
|             slug="foo", | ||||
|             designation=FlowDesignation.AUTHENTICATION, | ||||
|         ) | ||||
|         empty_flow = create_test_flow() | ||||
|         tenant: Tenant = create_test_tenant() | ||||
|         tenant.flow_authentication = empty_flow | ||||
|         tenant.save() | ||||
| @ -49,11 +44,7 @@ class TestApplicationsViews(FlowTestCase): | ||||
|     def test_check_redirect_auth(self): | ||||
|         """Test redirect""" | ||||
|         self.client.force_login(self.user) | ||||
|         empty_flow = Flow.objects.create( | ||||
|             name="foo", | ||||
|             slug="foo", | ||||
|             designation=FlowDesignation.AUTHENTICATION, | ||||
|         ) | ||||
|         empty_flow = create_test_flow() | ||||
|         tenant: Tenant = create_test_tenant() | ||||
|         tenant.flow_authentication = empty_flow | ||||
|         tenant.save() | ||||
|  | ||||
| @ -6,7 +6,7 @@ from guardian.utils import get_anonymous_user | ||||
|  | ||||
| from authentik.core.models import SourceUserMatchingModes, User | ||||
| from authentik.core.sources.flow_manager import Action | ||||
| from authentik.flows.models import Flow, FlowDesignation | ||||
| from authentik.core.tests.utils import create_test_flow | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.lib.tests.utils import get_request | ||||
| from authentik.policies.denied import AccessDeniedResponse | ||||
| @ -152,9 +152,7 @@ class TestSourceFlowManager(TestCase): | ||||
|         """Test error handling when a source selected flow is non-applicable due to a policy""" | ||||
|         self.source.user_matching_mode = SourceUserMatchingModes.USERNAME_LINK | ||||
|  | ||||
|         flow = Flow.objects.create( | ||||
|             name="test", slug="test", title="test", designation=FlowDesignation.ENROLLMENT | ||||
|         ) | ||||
|         flow = create_test_flow() | ||||
|         policy = ExpressionPolicy.objects.create( | ||||
|             name="false", expression="""ak_message("foo");return False""" | ||||
|         ) | ||||
|  | ||||
| @ -159,7 +159,6 @@ class TestUsersAPI(APITestCase): | ||||
|         response = self.client.get( | ||||
|             reverse("authentik_api:user-paths"), | ||||
|         ) | ||||
|         print(response.content) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         self.assertJSONEqual(response.content.decode(), {"paths": ["users"]}) | ||||
|  | ||||
|  | ||||
| @ -52,5 +52,5 @@ def create_test_cert() -> CertificateKeyPair: | ||||
|         subject_alt_names=["goauthentik.io"], | ||||
|         validity_days=360, | ||||
|     ) | ||||
|     builder.name = generate_id() | ||||
|     builder.common_name = generate_id() | ||||
|     return builder.save() | ||||
|  | ||||
| @ -32,7 +32,7 @@ class BadRequestView(TemplateView): | ||||
|     extra_context = {"title": "Bad Request"} | ||||
|  | ||||
|     response_class = BadRequestTemplateResponse | ||||
|     template_name = "error/generic.html" | ||||
|     template_name = "if/error.html" | ||||
|  | ||||
|  | ||||
| class ForbiddenView(TemplateView): | ||||
| @ -41,7 +41,7 @@ class ForbiddenView(TemplateView): | ||||
|     extra_context = {"title": "Forbidden"} | ||||
|  | ||||
|     response_class = ForbiddenTemplateResponse | ||||
|     template_name = "error/generic.html" | ||||
|     template_name = "if/error.html" | ||||
|  | ||||
|  | ||||
| class NotFoundView(TemplateView): | ||||
| @ -50,7 +50,7 @@ class NotFoundView(TemplateView): | ||||
|     extra_context = {"title": "Not Found"} | ||||
|  | ||||
|     response_class = NotFoundTemplateResponse | ||||
|     template_name = "error/generic.html" | ||||
|     template_name = "if/error.html" | ||||
|  | ||||
|  | ||||
| class ServerErrorView(TemplateView): | ||||
| @ -59,7 +59,7 @@ class ServerErrorView(TemplateView): | ||||
|     extra_context = {"title": "Server Error"} | ||||
|  | ||||
|     response_class = ServerErrorTemplateResponse | ||||
|     template_name = "error/generic.html" | ||||
|     template_name = "if/error.html" | ||||
|  | ||||
|     # pylint: disable=useless-super-delegation | ||||
|     def dispatch(self, *args, **kwargs):  # pragma: no cover | ||||
|  | ||||
| @ -12,10 +12,11 @@ from django_filters.filters import BooleanFilter | ||||
| from drf_spectacular.types import OpenApiTypes | ||||
| from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema | ||||
| from rest_framework.decorators import action | ||||
| from rest_framework.exceptions import ValidationError | ||||
| from rest_framework.fields import CharField, DateTimeField, IntegerField, SerializerMethodField | ||||
| from rest_framework.request import Request | ||||
| from rest_framework.response import Response | ||||
| from rest_framework.serializers import ModelSerializer, ValidationError | ||||
| from rest_framework.serializers import ModelSerializer | ||||
| from rest_framework.viewsets import ModelViewSet | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
|  | ||||
| @ -2,7 +2,7 @@ | ||||
| from datetime import datetime | ||||
| from typing import TYPE_CHECKING, Optional | ||||
|  | ||||
| from authentik.blueprints.manager import ManagedAppConfig | ||||
| from authentik.blueprints.apps import ManagedAppConfig | ||||
| from authentik.lib.generators import generate_id | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|  | ||||
| @ -26,7 +26,7 @@ class CertificateBuilder: | ||||
|         self.common_name = "authentik Self-signed Certificate" | ||||
|         self.cert = CertificateKeyPair() | ||||
|  | ||||
|     def save(self) -> Optional[CertificateKeyPair]: | ||||
|     def save(self) -> CertificateKeyPair: | ||||
|         """Save generated certificate as model""" | ||||
|         if not self.__certificate: | ||||
|             raise ValueError("Certificated hasn't been built yet") | ||||
|  | ||||
							
								
								
									
										51
									
								
								authentik/crypto/management/commands/import_certificate.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								authentik/crypto/management/commands/import_certificate.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,51 @@ | ||||
| """Import certificate""" | ||||
| from sys import exit as sys_exit | ||||
|  | ||||
| from django.core.management.base import BaseCommand, no_translations | ||||
| from rest_framework.exceptions import ValidationError | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.crypto.api import CertificateKeyPairSerializer | ||||
| from authentik.crypto.models import CertificateKeyPair | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| class Command(BaseCommand): | ||||
|     """Import certificate""" | ||||
|  | ||||
|     @no_translations | ||||
|     def handle(self, *args, **options): | ||||
|         """Import certificate""" | ||||
|         keypair = CertificateKeyPair.objects.filter(name=options["name"]).first() | ||||
|         dirty = False | ||||
|         if not keypair: | ||||
|             keypair = CertificateKeyPair(name=options["name"]) | ||||
|             dirty = True | ||||
|         with open(options["certificate"], mode="r", encoding="utf-8") as _cert: | ||||
|             cert_data = _cert.read() | ||||
|             if keypair.certificate_data != cert_data: | ||||
|                 dirty = True | ||||
|             keypair.certificate_data = cert_data | ||||
|         if options["private_key"]: | ||||
|             with open(options["private_key"], mode="r", encoding="utf-8") as _key: | ||||
|                 key_data = _key.read() | ||||
|                 if keypair.key_data != key_data: | ||||
|                     dirty = True | ||||
|                 keypair.key_data = key_data | ||||
|         # Validate that cert and key are actually PEM and valid | ||||
|         serializer = CertificateKeyPairSerializer(instance=keypair) | ||||
|         try: | ||||
|             serializer.validate_certificate_data(keypair.certificate_data) | ||||
|             if keypair.key_data != "": | ||||
|                 serializer.validate_certificate_data(keypair.key_data) | ||||
|         except ValidationError as exc: | ||||
|             self.stderr.write(exc) | ||||
|             sys_exit(1) | ||||
|         if dirty: | ||||
|             keypair.save() | ||||
|  | ||||
|     def add_arguments(self, parser): | ||||
|         parser.add_argument("--certificate", type=str, required=True) | ||||
|         parser.add_argument("--private-key", type=str, required=False) | ||||
|         parser.add_argument("--name", type=str, required=True) | ||||
| @ -6,12 +6,7 @@ from uuid import uuid4 | ||||
|  | ||||
| from cryptography.hazmat.backends import default_backend | ||||
| from cryptography.hazmat.primitives import hashes | ||||
| from cryptography.hazmat.primitives.asymmetric.ec import ( | ||||
|     EllipticCurvePrivateKey, | ||||
|     EllipticCurvePublicKey, | ||||
| ) | ||||
| from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey | ||||
| from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey | ||||
| from cryptography.hazmat.primitives.asymmetric.types import PRIVATE_KEY_TYPES, PUBLIC_KEY_TYPES | ||||
| from cryptography.hazmat.primitives.serialization import load_pem_private_key | ||||
| from cryptography.x509 import Certificate, load_pem_x509_certificate | ||||
| from django.db import models | ||||
| @ -42,8 +37,8 @@ class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel): | ||||
|     ) | ||||
|  | ||||
|     _cert: Optional[Certificate] = None | ||||
|     _private_key: Optional[RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey] = None | ||||
|     _public_key: Optional[RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey] = None | ||||
|     _private_key: Optional[PRIVATE_KEY_TYPES] = None | ||||
|     _public_key: Optional[PUBLIC_KEY_TYPES] = None | ||||
|  | ||||
|     @property | ||||
|     def serializer(self) -> Serializer: | ||||
| @ -61,7 +56,7 @@ class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel): | ||||
|         return self._cert | ||||
|  | ||||
|     @property | ||||
|     def public_key(self) -> Optional[RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey]: | ||||
|     def public_key(self) -> Optional[PUBLIC_KEY_TYPES]: | ||||
|         """Get public key of the private key""" | ||||
|         if not self._public_key: | ||||
|             self._public_key = self.private_key.public_key() | ||||
| @ -70,7 +65,7 @@ class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel): | ||||
|     @property | ||||
|     def private_key( | ||||
|         self, | ||||
|     ) -> Optional[RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey]: | ||||
|     ) -> Optional[PRIVATE_KEY_TYPES]: | ||||
|         """Get python cryptography PrivateKey instance""" | ||||
|         if not self._private_key and self.key_data != "": | ||||
|             try: | ||||
|  | ||||
| @ -85,16 +85,18 @@ class NotificationTransportViewSet(UsedByMixin, ModelViewSet): | ||||
|         """Send example notification using selected transport. Requires | ||||
|         Modify permissions.""" | ||||
|         transport: NotificationTransport = self.get_object() | ||||
|         event = Event.new( | ||||
|             action="notification_test", | ||||
|             user=get_user(request.user), | ||||
|             app=self.__class__.__module__, | ||||
|             context={"foo": "bar"}, | ||||
|         ) | ||||
|         event.save() | ||||
|         notification = Notification( | ||||
|             severity=NotificationSeverity.NOTICE, | ||||
|             body=f"Test Notification from transport {transport.name}", | ||||
|             user=request.user, | ||||
|             event=Event( | ||||
|                 action="Test", | ||||
|                 user=get_user(request.user), | ||||
|                 app=self.__class__.__module__, | ||||
|                 context={"foo": "bar"}, | ||||
|             ), | ||||
|             event=event, | ||||
|         ) | ||||
|         try: | ||||
|             response = NotificationTransportTestSerializer( | ||||
|  | ||||
| @ -1,7 +1,7 @@ | ||||
| """authentik events app""" | ||||
| from prometheus_client import Gauge | ||||
|  | ||||
| from authentik.blueprints.manager import ManagedAppConfig | ||||
| from authentik.blueprints.apps import ManagedAppConfig | ||||
|  | ||||
| GAUGE_TASKS = Gauge( | ||||
|     "authentik_system_tasks", | ||||
|  | ||||
| @ -19,7 +19,7 @@ from authentik.flows.models import FlowToken | ||||
| from authentik.lib.sentry import before_send | ||||
| from authentik.lib.utils.errors import exception_to_string | ||||
|  | ||||
| IGNORED_MODELS = [ | ||||
| IGNORED_MODELS = ( | ||||
|     Event, | ||||
|     Notification, | ||||
|     UserObjectPermission, | ||||
| @ -27,12 +27,14 @@ IGNORED_MODELS = [ | ||||
|     StaticToken, | ||||
|     Session, | ||||
|     FlowToken, | ||||
| ] | ||||
| if settings.DEBUG: | ||||
|     from silk.models import Request, Response, SQLQuery | ||||
| ) | ||||
|  | ||||
|     IGNORED_MODELS += [Request, Response, SQLQuery] | ||||
| IGNORED_MODELS = tuple(IGNORED_MODELS) | ||||
|  | ||||
| def should_log_model(model: Model) -> bool: | ||||
|     """Return true if operation on `model` should be logged""" | ||||
|     if model.__module__.startswith("silk"): | ||||
|         return False | ||||
|     return not isinstance(model, IGNORED_MODELS) | ||||
|  | ||||
|  | ||||
| class AuditMiddleware: | ||||
| @ -109,7 +111,7 @@ class AuditMiddleware: | ||||
|         user: User, request: HttpRequest, sender, instance: Model, created: bool, **_ | ||||
|     ): | ||||
|         """Signal handler for all object's post_save""" | ||||
|         if isinstance(instance, IGNORED_MODELS): | ||||
|         if not should_log_model(instance): | ||||
|             return | ||||
|  | ||||
|         action = EventAction.MODEL_CREATED if created else EventAction.MODEL_UPDATED | ||||
| @ -119,7 +121,7 @@ class AuditMiddleware: | ||||
|     # pylint: disable=unused-argument | ||||
|     def pre_delete_handler(user: User, request: HttpRequest, sender, instance: Model, **_): | ||||
|         """Signal handler for all object's pre_delete""" | ||||
|         if isinstance(instance, IGNORED_MODELS):  # pragma: no cover | ||||
|         if not should_log_model(instance):  # pragma: no cover | ||||
|             return | ||||
|  | ||||
|         EventNewThread( | ||||
|  | ||||
| @ -28,126 +28,6 @@ def convert_user_to_json(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||
|         event.save() | ||||
|  | ||||
|  | ||||
| def notify_configuration_error(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||
|     db_alias = schema_editor.connection.alias | ||||
|     Group = apps.get_model("authentik_core", "Group") | ||||
|     PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding") | ||||
|     EventMatcherPolicy = apps.get_model("authentik_policies_event_matcher", "EventMatcherPolicy") | ||||
|     NotificationRule = apps.get_model("authentik_events", "NotificationRule") | ||||
|     NotificationTransport = apps.get_model("authentik_events", "NotificationTransport") | ||||
|  | ||||
|     admin_group = ( | ||||
|         Group.objects.using(db_alias).filter(name="authentik Admins", is_superuser=True).first() | ||||
|     ) | ||||
|  | ||||
|     policy, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create( | ||||
|         name="default-match-configuration-error", | ||||
|         defaults={"action": EventAction.CONFIGURATION_ERROR}, | ||||
|     ) | ||||
|     trigger, _ = NotificationRule.objects.using(db_alias).update_or_create( | ||||
|         name="default-notify-configuration-error", | ||||
|         defaults={"group": admin_group, "severity": NotificationSeverity.ALERT}, | ||||
|     ) | ||||
|     trigger.transports.set( | ||||
|         NotificationTransport.objects.using(db_alias).filter(name="default-email-transport") | ||||
|     ) | ||||
|     trigger.save() | ||||
|     PolicyBinding.objects.using(db_alias).update_or_create( | ||||
|         target=trigger, | ||||
|         policy=policy, | ||||
|         defaults={ | ||||
|             "order": 0, | ||||
|         }, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def notify_update(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||
|     db_alias = schema_editor.connection.alias | ||||
|     Group = apps.get_model("authentik_core", "Group") | ||||
|     PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding") | ||||
|     EventMatcherPolicy = apps.get_model("authentik_policies_event_matcher", "EventMatcherPolicy") | ||||
|     NotificationRule = apps.get_model("authentik_events", "NotificationRule") | ||||
|     NotificationTransport = apps.get_model("authentik_events", "NotificationTransport") | ||||
|  | ||||
|     admin_group = ( | ||||
|         Group.objects.using(db_alias).filter(name="authentik Admins", is_superuser=True).first() | ||||
|     ) | ||||
|  | ||||
|     policy, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create( | ||||
|         name="default-match-update", | ||||
|         defaults={"action": EventAction.UPDATE_AVAILABLE}, | ||||
|     ) | ||||
|     trigger, _ = NotificationRule.objects.using(db_alias).update_or_create( | ||||
|         name="default-notify-update", | ||||
|         defaults={"group": admin_group, "severity": NotificationSeverity.ALERT}, | ||||
|     ) | ||||
|     trigger.transports.set( | ||||
|         NotificationTransport.objects.using(db_alias).filter(name="default-email-transport") | ||||
|     ) | ||||
|     trigger.save() | ||||
|     PolicyBinding.objects.using(db_alias).update_or_create( | ||||
|         target=trigger, | ||||
|         policy=policy, | ||||
|         defaults={ | ||||
|             "order": 0, | ||||
|         }, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def notify_exception(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||
|     db_alias = schema_editor.connection.alias | ||||
|     Group = apps.get_model("authentik_core", "Group") | ||||
|     PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding") | ||||
|     EventMatcherPolicy = apps.get_model("authentik_policies_event_matcher", "EventMatcherPolicy") | ||||
|     NotificationRule = apps.get_model("authentik_events", "NotificationRule") | ||||
|     NotificationTransport = apps.get_model("authentik_events", "NotificationTransport") | ||||
|  | ||||
|     admin_group = ( | ||||
|         Group.objects.using(db_alias).filter(name="authentik Admins", is_superuser=True).first() | ||||
|     ) | ||||
|  | ||||
|     policy_policy_exc, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create( | ||||
|         name="default-match-policy-exception", | ||||
|         defaults={"action": EventAction.POLICY_EXCEPTION}, | ||||
|     ) | ||||
|     policy_pm_exc, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create( | ||||
|         name="default-match-property-mapping-exception", | ||||
|         defaults={"action": EventAction.PROPERTY_MAPPING_EXCEPTION}, | ||||
|     ) | ||||
|     trigger, _ = NotificationRule.objects.using(db_alias).update_or_create( | ||||
|         name="default-notify-exception", | ||||
|         defaults={"group": admin_group, "severity": NotificationSeverity.ALERT}, | ||||
|     ) | ||||
|     trigger.transports.set( | ||||
|         NotificationTransport.objects.using(db_alias).filter(name="default-email-transport") | ||||
|     ) | ||||
|     trigger.save() | ||||
|     PolicyBinding.objects.using(db_alias).update_or_create( | ||||
|         target=trigger, | ||||
|         policy=policy_policy_exc, | ||||
|         defaults={ | ||||
|             "order": 0, | ||||
|         }, | ||||
|     ) | ||||
|     PolicyBinding.objects.using(db_alias).update_or_create( | ||||
|         target=trigger, | ||||
|         policy=policy_pm_exc, | ||||
|         defaults={ | ||||
|             "order": 1, | ||||
|         }, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def transport_email_global(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||
|     db_alias = schema_editor.connection.alias | ||||
|     NotificationTransport = apps.get_model("authentik_events", "NotificationTransport") | ||||
|  | ||||
|     NotificationTransport.objects.using(db_alias).update_or_create( | ||||
|         name="default-email-transport", | ||||
|         defaults={"mode": TransportMode.EMAIL}, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def token_view_to_secret_view(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||
|     from authentik.events.models import EventAction | ||||
|  | ||||
| @ -432,18 +312,6 @@ class Migration(migrations.Migration): | ||||
|                 "verbose_name_plural": "Notifications", | ||||
|             }, | ||||
|         ), | ||||
|         migrations.RunPython( | ||||
|             code=transport_email_global, | ||||
|         ), | ||||
|         migrations.RunPython( | ||||
|             code=notify_configuration_error, | ||||
|         ), | ||||
|         migrations.RunPython( | ||||
|             code=notify_update, | ||||
|         ), | ||||
|         migrations.RunPython( | ||||
|             code=notify_exception, | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="notificationtransport", | ||||
|             name="send_once", | ||||
|  | ||||
| @ -22,14 +22,20 @@ from django.utils.translation import gettext as _ | ||||
| from requests import RequestException | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik import __version__ | ||||
| from authentik import get_full_version | ||||
| from authentik.core.middleware import ( | ||||
|     SESSION_KEY_IMPERSONATE_ORIGINAL_USER, | ||||
|     SESSION_KEY_IMPERSONATE_USER, | ||||
| ) | ||||
| from authentik.core.models import ExpiringModel, Group, PropertyMapping, User | ||||
| from authentik.events.geo import GEOIP_READER | ||||
| from authentik.events.utils import cleanse_dict, get_user, model_to_dict, sanitize_dict | ||||
| from authentik.events.utils import ( | ||||
|     cleanse_dict, | ||||
|     get_user, | ||||
|     model_to_dict, | ||||
|     sanitize_dict, | ||||
|     sanitize_item, | ||||
| ) | ||||
| from authentik.lib.models import DomainlessURLValidator, SerializerModel | ||||
| from authentik.lib.sentry import SentryIgnoredException | ||||
| from authentik.lib.utils.http import get_client_ip, get_http_session | ||||
| @ -355,10 +361,12 @@ class NotificationTransport(SerializerModel): | ||||
|             "user_username": notification.user.username, | ||||
|         } | ||||
|         if self.webhook_mapping: | ||||
|             default_body = self.webhook_mapping.evaluate( | ||||
|                 user=notification.user, | ||||
|                 request=None, | ||||
|                 notification=notification, | ||||
|             default_body = sanitize_item( | ||||
|                 self.webhook_mapping.evaluate( | ||||
|                     user=notification.user, | ||||
|                     request=None, | ||||
|                     notification=notification, | ||||
|                 ) | ||||
|             ) | ||||
|         try: | ||||
|             response = get_http_session().post( | ||||
| @ -406,7 +414,7 @@ class NotificationTransport(SerializerModel): | ||||
|                     "title": notification.body, | ||||
|                     "color": "#fd4b2d", | ||||
|                     "fields": fields, | ||||
|                     "footer": f"authentik v{__version__}", | ||||
|                     "footer": f"authentik {get_full_version()}", | ||||
|                 } | ||||
|             ], | ||||
|         } | ||||
|  | ||||
| @ -134,26 +134,31 @@ class MonitoredTask(Task): | ||||
|  | ||||
|     # pylint: disable=too-many-arguments | ||||
|     def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo): | ||||
|         if self._result: | ||||
|             if not self._result.uid: | ||||
|                 self._result.uid = self._uid | ||||
|             if self.save_on_success: | ||||
|                 TaskInfo( | ||||
|                     task_name=self.__name__, | ||||
|                     task_description=self.__doc__, | ||||
|                     start_timestamp=self.start, | ||||
|                     finish_timestamp=default_timer(), | ||||
|                     finish_time=datetime.now(), | ||||
|                     result=self._result, | ||||
|                     task_call_module=self.__module__, | ||||
|                     task_call_func=self.__name__, | ||||
|                     task_call_args=args, | ||||
|                     task_call_kwargs=kwargs, | ||||
|                 ).save(self.result_timeout_hours) | ||||
|         return super().after_return(status, retval, task_id, args, kwargs, einfo=einfo) | ||||
|         super().after_return(status, retval, task_id, args, kwargs, einfo=einfo) | ||||
|         if not self._result: | ||||
|             return | ||||
|         if not self._result.uid: | ||||
|             self._result.uid = self._uid | ||||
|         info = TaskInfo( | ||||
|             task_name=self.__name__, | ||||
|             task_description=self.__doc__, | ||||
|             start_timestamp=self.start, | ||||
|             finish_timestamp=default_timer(), | ||||
|             finish_time=datetime.now(), | ||||
|             result=self._result, | ||||
|             task_call_module=self.__module__, | ||||
|             task_call_func=self.__name__, | ||||
|             task_call_args=args, | ||||
|             task_call_kwargs=kwargs, | ||||
|         ) | ||||
|         if self._result.status == TaskResultStatus.SUCCESSFUL and not self.save_on_success: | ||||
|             info.delete() | ||||
|             return | ||||
|         info.save(self.result_timeout_hours) | ||||
|  | ||||
|     # pylint: disable=too-many-arguments | ||||
|     def on_failure(self, exc, task_id, args, kwargs, einfo): | ||||
|         super().on_failure(exc, task_id, args, kwargs, einfo=einfo) | ||||
|         if not self._result: | ||||
|             self._result = TaskResult(status=TaskResultStatus.ERROR, messages=[str(exc)]) | ||||
|         if not self._result.uid: | ||||
| @ -174,7 +179,6 @@ class MonitoredTask(Task): | ||||
|             EventAction.SYSTEM_TASK_EXCEPTION, | ||||
|             message=(f"Task {self.__name__} encountered an error: {exception_to_string(exc)}"), | ||||
|         ).save() | ||||
|         return super().on_failure(exc, task_id, args, kwargs, einfo=einfo) | ||||
|  | ||||
|     def run(self, *args, **kwargs): | ||||
|         raise NotImplementedError | ||||
|  | ||||
| @ -31,8 +31,8 @@ class TestEventsNotifications(TestCase): | ||||
|  | ||||
|     def test_trigger_empty(self): | ||||
|         """Test trigger without any policies attached""" | ||||
|         transport = NotificationTransport.objects.create(name="transport") | ||||
|         trigger = NotificationRule.objects.create(name="trigger", group=self.group) | ||||
|         transport = NotificationTransport.objects.create(name=generate_id()) | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||
|         trigger.transports.add(transport) | ||||
|         trigger.save() | ||||
|  | ||||
| @ -43,8 +43,8 @@ class TestEventsNotifications(TestCase): | ||||
|  | ||||
|     def test_trigger_single(self): | ||||
|         """Test simple transport triggering""" | ||||
|         transport = NotificationTransport.objects.create(name="transport") | ||||
|         trigger = NotificationRule.objects.create(name="trigger", group=self.group) | ||||
|         transport = NotificationTransport.objects.create(name=generate_id()) | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||
|         trigger.transports.add(transport) | ||||
|         trigger.save() | ||||
|         matcher = EventMatcherPolicy.objects.create( | ||||
| @ -59,7 +59,7 @@ class TestEventsNotifications(TestCase): | ||||
|  | ||||
|     def test_trigger_no_group(self): | ||||
|         """Test trigger without group""" | ||||
|         trigger = NotificationRule.objects.create(name="trigger") | ||||
|         trigger = NotificationRule.objects.create(name=generate_id()) | ||||
|         matcher = EventMatcherPolicy.objects.create( | ||||
|             name="matcher", action=EventAction.CUSTOM_PREFIX | ||||
|         ) | ||||
| @ -72,9 +72,9 @@ class TestEventsNotifications(TestCase): | ||||
|  | ||||
|     def test_policy_error_recursive(self): | ||||
|         """Test Policy error which would cause recursion""" | ||||
|         transport = NotificationTransport.objects.create(name="transport") | ||||
|         transport = NotificationTransport.objects.create(name=generate_id()) | ||||
|         NotificationRule.objects.filter(name__startswith="default").delete() | ||||
|         trigger = NotificationRule.objects.create(name="trigger", group=self.group) | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||
|         trigger.transports.add(transport) | ||||
|         trigger.save() | ||||
|         matcher = EventMatcherPolicy.objects.create( | ||||
| @ -95,9 +95,9 @@ class TestEventsNotifications(TestCase): | ||||
|         self.group.users.add(user2) | ||||
|         self.group.save() | ||||
|  | ||||
|         transport = NotificationTransport.objects.create(name="transport", send_once=True) | ||||
|         transport = NotificationTransport.objects.create(name=generate_id(), send_once=True) | ||||
|         NotificationRule.objects.filter(name__startswith="default").delete() | ||||
|         trigger = NotificationRule.objects.create(name="trigger", group=self.group) | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||
|         trigger.transports.add(transport) | ||||
|         trigger.save() | ||||
|         matcher = EventMatcherPolicy.objects.create( | ||||
| @ -118,10 +118,10 @@ class TestEventsNotifications(TestCase): | ||||
|         ) | ||||
|  | ||||
|         transport = NotificationTransport.objects.create( | ||||
|             name="transport", webhook_mapping=mapping, mode=TransportMode.LOCAL | ||||
|             name=generate_id(), webhook_mapping=mapping, mode=TransportMode.LOCAL | ||||
|         ) | ||||
|         NotificationRule.objects.filter(name__startswith="default").delete() | ||||
|         trigger = NotificationRule.objects.create(name="trigger", group=self.group) | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||
|         trigger.transports.add(transport) | ||||
|         matcher = EventMatcherPolicy.objects.create( | ||||
|             name="matcher", action=EventAction.CUSTOM_PREFIX | ||||
|  | ||||
							
								
								
									
										131
									
								
								authentik/events/tests/test_transports.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										131
									
								
								authentik/events/tests/test_transports.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,131 @@ | ||||
| """transport tests""" | ||||
| from unittest.mock import PropertyMock, patch | ||||
|  | ||||
| from django.core import mail | ||||
| from django.core.mail.backends.locmem import EmailBackend | ||||
| from django.test import TestCase | ||||
| from requests_mock import Mocker | ||||
|  | ||||
| from authentik import get_full_version | ||||
| from authentik.core.tests.utils import create_test_admin_user | ||||
| from authentik.events.models import ( | ||||
|     Event, | ||||
|     Notification, | ||||
|     NotificationSeverity, | ||||
|     NotificationTransport, | ||||
|     NotificationWebhookMapping, | ||||
|     TransportMode, | ||||
| ) | ||||
| from authentik.lib.generators import generate_id | ||||
|  | ||||
|  | ||||
| class TestEventTransports(TestCase): | ||||
|     """Test Event Transports""" | ||||
|  | ||||
|     def setUp(self) -> None: | ||||
|         self.user = create_test_admin_user() | ||||
|         self.event = Event.new("foo", "testing", foo="bar,").set_user(self.user) | ||||
|         self.event.save() | ||||
|         self.notification = Notification.objects.create( | ||||
|             severity=NotificationSeverity.ALERT, | ||||
|             body="foo", | ||||
|             event=self.event, | ||||
|             user=self.user, | ||||
|         ) | ||||
|  | ||||
|     def test_transport_webhook(self): | ||||
|         """Test webhook transport""" | ||||
|         transport: NotificationTransport = NotificationTransport.objects.create( | ||||
|             name=generate_id(), | ||||
|             mode=TransportMode.WEBHOOK, | ||||
|             webhook_url="http://localhost:1234/test", | ||||
|         ) | ||||
|         with Mocker() as mocker: | ||||
|             mocker.post("http://localhost:1234/test") | ||||
|             transport.send(self.notification) | ||||
|             self.assertEqual(mocker.call_count, 1) | ||||
|             self.assertEqual(mocker.request_history[0].method, "POST") | ||||
|             self.assertJSONEqual( | ||||
|                 mocker.request_history[0].body.decode(), | ||||
|                 { | ||||
|                     "body": "foo", | ||||
|                     "severity": "alert", | ||||
|                     "user_email": self.user.email, | ||||
|                     "user_username": self.user.username, | ||||
|                 }, | ||||
|             ) | ||||
|  | ||||
|     def test_transport_webhook_mapping(self): | ||||
|         """Test webhook transport with custom mapping""" | ||||
|         mapping = NotificationWebhookMapping.objects.create( | ||||
|             name=generate_id(), expression="return request.user" | ||||
|         ) | ||||
|         transport: NotificationTransport = NotificationTransport.objects.create( | ||||
|             name=generate_id(), | ||||
|             mode=TransportMode.WEBHOOK, | ||||
|             webhook_url="http://localhost:1234/test", | ||||
|             webhook_mapping=mapping, | ||||
|         ) | ||||
|         with Mocker() as mocker: | ||||
|             mocker.post("http://localhost:1234/test") | ||||
|             transport.send(self.notification) | ||||
|             self.assertEqual(mocker.call_count, 1) | ||||
|             self.assertEqual(mocker.request_history[0].method, "POST") | ||||
|             self.assertJSONEqual( | ||||
|                 mocker.request_history[0].body.decode(), | ||||
|                 {"email": self.user.email, "pk": self.user.pk, "username": self.user.username}, | ||||
|             ) | ||||
|  | ||||
|     def test_transport_webhook_slack(self): | ||||
|         """Test webhook transport (slack)""" | ||||
|         transport: NotificationTransport = NotificationTransport.objects.create( | ||||
|             name=generate_id(), | ||||
|             mode=TransportMode.WEBHOOK_SLACK, | ||||
|             webhook_url="http://localhost:1234/test", | ||||
|         ) | ||||
|         with Mocker() as mocker: | ||||
|             mocker.post("http://localhost:1234/test") | ||||
|             transport.send(self.notification) | ||||
|             self.assertEqual(mocker.call_count, 1) | ||||
|             self.assertEqual(mocker.request_history[0].method, "POST") | ||||
|             self.assertJSONEqual( | ||||
|                 mocker.request_history[0].body.decode(), | ||||
|                 { | ||||
|                     "username": "authentik", | ||||
|                     "icon_url": "https://goauthentik.io/img/icon.png", | ||||
|                     "attachments": [ | ||||
|                         { | ||||
|                             "author_name": "authentik", | ||||
|                             "author_link": "https://goauthentik.io", | ||||
|                             "author_icon": "https://goauthentik.io/img/icon.png", | ||||
|                             "title": "custom_foo", | ||||
|                             "color": "#fd4b2d", | ||||
|                             "fields": [ | ||||
|                                 {"title": "Severity", "value": "alert", "short": True}, | ||||
|                                 { | ||||
|                                     "title": "Dispatched for user", | ||||
|                                     "value": self.user.username, | ||||
|                                     "short": True, | ||||
|                                 }, | ||||
|                                 {"title": "foo", "value": "bar,"}, | ||||
|                             ], | ||||
|                             "footer": f"authentik {get_full_version()}", | ||||
|                         } | ||||
|                     ], | ||||
|                 }, | ||||
|             ) | ||||
|  | ||||
|     def test_transport_email(self): | ||||
|         """Test email transport""" | ||||
|         transport: NotificationTransport = NotificationTransport.objects.create( | ||||
|             name=generate_id(), | ||||
|             mode=TransportMode.EMAIL, | ||||
|         ) | ||||
|         with patch( | ||||
|             "authentik.stages.email.models.EmailStage.backend_class", | ||||
|             PropertyMock(return_value=EmailBackend), | ||||
|         ): | ||||
|             transport.send(self.notification) | ||||
|             self.assertEqual(len(mail.outbox), 1) | ||||
|             self.assertEqual(mail.outbox[0].subject, "authentik Notification: custom_foo") | ||||
|             self.assertIn(self.notification.body, mail.outbox[0].alternatives[0][0]) | ||||
| @ -23,21 +23,31 @@ from authentik.policies.types import PolicyRequest | ||||
| ALLOWED_SPECIAL_KEYS = re.compile("passing", flags=re.I) | ||||
|  | ||||
|  | ||||
| def cleanse_item(key: str, value: Any) -> Any: | ||||
|     """Cleanse a single item""" | ||||
|     if isinstance(value, dict): | ||||
|         return cleanse_dict(value) | ||||
|     if isinstance(value, list): | ||||
|         for idx, item in enumerate(value): | ||||
|             value[idx] = cleanse_item(key, item) | ||||
|         return value | ||||
|     try: | ||||
|         if SafeExceptionReporterFilter.hidden_settings.search( | ||||
|             key | ||||
|         ) and not ALLOWED_SPECIAL_KEYS.search(key): | ||||
|             return SafeExceptionReporterFilter.cleansed_substitute | ||||
|     except TypeError:  # pragma: no cover | ||||
|         return value | ||||
|     return value | ||||
|  | ||||
|  | ||||
| def cleanse_dict(source: dict[Any, Any]) -> dict[Any, Any]: | ||||
|     """Cleanse a dictionary, recursively""" | ||||
|     final_dict = {} | ||||
|     for key, value in source.items(): | ||||
|         try: | ||||
|             if SafeExceptionReporterFilter.hidden_settings.search( | ||||
|                 key | ||||
|             ) and not ALLOWED_SPECIAL_KEYS.search(key): | ||||
|                 final_dict[key] = SafeExceptionReporterFilter.cleansed_substitute | ||||
|             else: | ||||
|                 final_dict[key] = value | ||||
|         except TypeError:  # pragma: no cover | ||||
|             final_dict[key] = value | ||||
|         if isinstance(value, dict): | ||||
|             final_dict[key] = cleanse_dict(value) | ||||
|         new_value = cleanse_item(key, value) | ||||
|         if new_value is not ...: | ||||
|             final_dict[key] = new_value | ||||
|     return final_dict | ||||
|  | ||||
|  | ||||
| @ -70,6 +80,45 @@ def get_user(user: User, original_user: Optional[User] = None) -> dict[str, Any] | ||||
|     return user_data | ||||
|  | ||||
|  | ||||
| # pylint: disable=too-many-return-statements | ||||
| def sanitize_item(value: Any) -> Any: | ||||
|     """Sanitize a single item, ensure it is JSON parsable""" | ||||
|     if is_dataclass(value): | ||||
|         # Because asdict calls `copy.deepcopy(obj)` on everything that's not tuple/dict, | ||||
|         # and deepcopy doesn't work with HttpRequests (neither django nor rest_framework). | ||||
|         # Currently, the only dataclass that actually holds an http request is a PolicyRequest | ||||
|         if isinstance(value, PolicyRequest): | ||||
|             value.http_request = None | ||||
|         value = asdict(value) | ||||
|     if isinstance(value, dict): | ||||
|         return sanitize_dict(value) | ||||
|     if isinstance(value, list): | ||||
|         new_values = [] | ||||
|         for item in value: | ||||
|             new_value = sanitize_item(item) | ||||
|             if new_value: | ||||
|                 new_values.append(new_value) | ||||
|         return new_values | ||||
|     if isinstance(value, (User, AnonymousUser)): | ||||
|         return sanitize_dict(get_user(value)) | ||||
|     if isinstance(value, models.Model): | ||||
|         return sanitize_dict(model_to_dict(value)) | ||||
|     if isinstance(value, UUID): | ||||
|         return value.hex | ||||
|     if isinstance(value, (HttpRequest, WSGIRequest)): | ||||
|         return ... | ||||
|     if isinstance(value, City): | ||||
|         return GEOIP_READER.city_to_dict(value) | ||||
|     if isinstance(value, Path): | ||||
|         return str(value) | ||||
|     if isinstance(value, type): | ||||
|         return { | ||||
|             "type": value.__name__, | ||||
|             "module": value.__module__, | ||||
|         } | ||||
|     return value | ||||
|  | ||||
|  | ||||
| def sanitize_dict(source: dict[Any, Any]) -> dict[Any, Any]: | ||||
|     """clean source of all Models that would interfere with the JSONField. | ||||
|     Models are replaced with a dictionary of { | ||||
| @ -79,32 +128,7 @@ def sanitize_dict(source: dict[Any, Any]) -> dict[Any, Any]: | ||||
|     }""" | ||||
|     final_dict = {} | ||||
|     for key, value in source.items(): | ||||
|         if is_dataclass(value): | ||||
|             # Because asdict calls `copy.deepcopy(obj)` on everything that's not tuple/dict, | ||||
|             # and deepcopy doesn't work with HttpRequests (neither django nor rest_framework). | ||||
|             # Currently, the only dataclass that actually holds an http request is a PolicyRequest | ||||
|             if isinstance(value, PolicyRequest): | ||||
|                 value.http_request = None | ||||
|             value = asdict(value) | ||||
|         if isinstance(value, dict): | ||||
|             final_dict[key] = sanitize_dict(value) | ||||
|         elif isinstance(value, (User, AnonymousUser)): | ||||
|             final_dict[key] = sanitize_dict(get_user(value)) | ||||
|         elif isinstance(value, models.Model): | ||||
|             final_dict[key] = sanitize_dict(model_to_dict(value)) | ||||
|         elif isinstance(value, UUID): | ||||
|             final_dict[key] = value.hex | ||||
|         elif isinstance(value, (HttpRequest, WSGIRequest)): | ||||
|             continue | ||||
|         elif isinstance(value, City): | ||||
|             final_dict[key] = GEOIP_READER.city_to_dict(value) | ||||
|         elif isinstance(value, Path): | ||||
|             final_dict[key] = str(value) | ||||
|         elif isinstance(value, type): | ||||
|             final_dict[key] = { | ||||
|                 "type": value.__name__, | ||||
|                 "module": value.__module__, | ||||
|             } | ||||
|         else: | ||||
|             final_dict[key] = value | ||||
|         new_value = sanitize_item(value) | ||||
|         if new_value is not ...: | ||||
|             final_dict[key] = new_value | ||||
|     return final_dict | ||||
|  | ||||
| @ -1,26 +1,22 @@ | ||||
| """Flow API Views""" | ||||
| from dataclasses import dataclass | ||||
|  | ||||
| from django.core.cache import cache | ||||
| from django.db.models import Model | ||||
| from django.http import HttpResponse | ||||
| from django.http.response import HttpResponseBadRequest | ||||
| from django.urls import reverse | ||||
| from django.utils.translation import gettext as _ | ||||
| from drf_spectacular.types import OpenApiTypes | ||||
| from drf_spectacular.utils import OpenApiResponse, extend_schema | ||||
| from guardian.shortcuts import get_objects_for_user | ||||
| from rest_framework.decorators import action | ||||
| from rest_framework.fields import ReadOnlyField | ||||
| from rest_framework.parsers import MultiPartParser | ||||
| from rest_framework.request import Request | ||||
| from rest_framework.response import Response | ||||
| from rest_framework.serializers import CharField, ModelSerializer, Serializer, SerializerMethodField | ||||
| from rest_framework.serializers import ModelSerializer, SerializerMethodField | ||||
| from rest_framework.viewsets import ModelViewSet | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.api.decorators import permission_required | ||||
| from authentik.blueprints.v1.exporter import Exporter | ||||
| from authentik.blueprints.v1.exporter import FlowExporter | ||||
| from authentik.blueprints.v1.importer import Importer | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.core.api.utils import ( | ||||
| @ -29,6 +25,7 @@ from authentik.core.api.utils import ( | ||||
|     FileUploadSerializer, | ||||
|     LinkSerializer, | ||||
| ) | ||||
| from authentik.flows.api.flows_diagram import FlowDiagram, FlowDiagramSerializer | ||||
| from authentik.flows.exceptions import FlowNonApplicableException | ||||
| from authentik.flows.models import Flow | ||||
| from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner, cache_key | ||||
| @ -80,30 +77,6 @@ class FlowSerializer(ModelSerializer): | ||||
|         } | ||||
|  | ||||
|  | ||||
| class FlowDiagramSerializer(Serializer): | ||||
|     """response of the flow's diagram action""" | ||||
|  | ||||
|     diagram = CharField(read_only=True) | ||||
|  | ||||
|     def create(self, validated_data: dict) -> Model: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def update(self, instance: Model, validated_data: dict) -> Model: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class DiagramElement: | ||||
|     """Single element used in a diagram""" | ||||
|  | ||||
|     identifier: str | ||||
|     type: str | ||||
|     rest: str | ||||
|  | ||||
|     def __str__(self) -> str: | ||||
|         return f"{self.identifier}=>{self.type}: {self.rest}" | ||||
|  | ||||
|  | ||||
| class FlowViewSet(UsedByMixin, ModelViewSet): | ||||
|     """Flow Viewset""" | ||||
|  | ||||
| @ -198,7 +171,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet): | ||||
|     def export(self, request: Request, slug: str) -> Response: | ||||
|         """Export flow to .yaml file""" | ||||
|         flow = self.get_object() | ||||
|         exporter = Exporter(flow) | ||||
|         exporter = FlowExporter(flow) | ||||
|         response = HttpResponse(content=exporter.export_to_string()) | ||||
|         response["Content-Disposition"] = f'attachment; filename="{flow.slug}.yaml"' | ||||
|         return response | ||||
| @ -208,84 +181,9 @@ class FlowViewSet(UsedByMixin, ModelViewSet): | ||||
|     # pylint: disable=unused-argument | ||||
|     def diagram(self, request: Request, slug: str) -> Response: | ||||
|         """Return diagram for flow with slug `slug`, in the format used by flowchart.js""" | ||||
|         flow = self.get_object() | ||||
|         header = [ | ||||
|             DiagramElement("st", "start", "Start"), | ||||
|         ] | ||||
|         body: list[DiagramElement] = [] | ||||
|         footer = [] | ||||
|         # Collect all elements we need | ||||
|         # First, policies bound to the flow itself | ||||
|         for p_index, policy_binding in enumerate( | ||||
|             get_objects_for_user(request.user, "authentik_policies.view_policybinding") | ||||
|             .filter(target=flow) | ||||
|             .exclude(policy__isnull=True) | ||||
|             .order_by("order") | ||||
|         ): | ||||
|             body.append( | ||||
|                 DiagramElement( | ||||
|                     f"flow_policy_{p_index}", | ||||
|                     "condition", | ||||
|                     _("Policy (%(type)s)" % {"type": policy_binding.policy._meta.verbose_name}) | ||||
|                     + "\n" | ||||
|                     + policy_binding.policy.name, | ||||
|                 ) | ||||
|             ) | ||||
|         # Collect all stages | ||||
|         for s_index, stage_binding in enumerate( | ||||
|             get_objects_for_user(request.user, "authentik_flows.view_flowstagebinding") | ||||
|             .filter(target=flow) | ||||
|             .order_by("order") | ||||
|         ): | ||||
|             # First all policies bound to stages since they execute before stages | ||||
|             for p_index, policy_binding in enumerate( | ||||
|                 get_objects_for_user(request.user, "authentik_policies.view_policybinding") | ||||
|                 .filter(target=stage_binding) | ||||
|                 .exclude(policy__isnull=True) | ||||
|                 .order_by("order") | ||||
|             ): | ||||
|                 body.append( | ||||
|                     DiagramElement( | ||||
|                         f"stage_{s_index}_policy_{p_index}", | ||||
|                         "condition", | ||||
|                         _("Policy (%(type)s)" % {"type": policy_binding.policy._meta.verbose_name}) | ||||
|                         + "\n" | ||||
|                         + policy_binding.policy.name, | ||||
|                     ) | ||||
|                 ) | ||||
|             body.append( | ||||
|                 DiagramElement( | ||||
|                     f"stage_{s_index}", | ||||
|                     "operation", | ||||
|                     _("Stage (%(type)s)" % {"type": stage_binding.stage._meta.verbose_name}) | ||||
|                     + "\n" | ||||
|                     + stage_binding.stage.name, | ||||
|                 ) | ||||
|             ) | ||||
|         # If the 2nd last element is a policy, we need to have an item to point to | ||||
|         # for a negative case | ||||
|         body.append( | ||||
|             DiagramElement("e", "end", "End|future"), | ||||
|         ) | ||||
|         if len(body) == 1: | ||||
|             footer.append("st(right)->e") | ||||
|         else: | ||||
|             # Actual diagram flow | ||||
|             footer.append(f"st(right)->{body[0].identifier}") | ||||
|             for index in range(len(body) - 1): | ||||
|                 element: DiagramElement = body[index] | ||||
|                 if element.type == "condition": | ||||
|                     # Policy passes, link policy yes to next stage | ||||
|                     footer.append(f"{element.identifier}(yes, right)->{body[index + 1].identifier}") | ||||
|                     # Policy doesn't pass, go to stage after next stage | ||||
|                     no_element = body[index + 1] | ||||
|                     if no_element.type != "end": | ||||
|                         no_element = body[index + 2] | ||||
|                     footer.append(f"{element.identifier}(no, bottom)->{no_element.identifier}") | ||||
|                 elif element.type == "operation": | ||||
|                     footer.append(f"{element.identifier}(bottom)->{body[index + 1].identifier}") | ||||
|         diagram = "\n".join([str(x) for x in header + body + footer]) | ||||
|         return Response({"diagram": diagram}) | ||||
|         diagram = FlowDiagram(self.get_object(), request.user) | ||||
|         output = diagram.build() | ||||
|         return Response({"diagram": output}) | ||||
|  | ||||
|     @permission_required("authentik_flows.change_flow") | ||||
|     @extend_schema( | ||||
|  | ||||
							
								
								
									
										206
									
								
								authentik/flows/api/flows_diagram.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										206
									
								
								authentik/flows/api/flows_diagram.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,206 @@ | ||||
| """Flows Diagram API""" | ||||
| from dataclasses import dataclass, field | ||||
| from typing import Optional | ||||
|  | ||||
| from django.utils.translation import gettext as _ | ||||
| from guardian.shortcuts import get_objects_for_user | ||||
| from rest_framework.serializers import CharField | ||||
|  | ||||
| from authentik.core.api.utils import PassiveSerializer | ||||
| from authentik.core.models import User | ||||
| from authentik.flows.models import Flow, FlowStageBinding | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class DiagramElement: | ||||
|     """Single element used in a diagram""" | ||||
|  | ||||
|     identifier: str | ||||
|     description: str | ||||
|     action: Optional[str] = None | ||||
|     source: Optional[list["DiagramElement"]] = None | ||||
|  | ||||
|     style: list[str] = field(default_factory=lambda: ["[", "]"]) | ||||
|  | ||||
|     def __str__(self) -> str: | ||||
|         element = f'{self.identifier}{self.style[0]}"{self.description}"{self.style[1]}' | ||||
|         if self.action is not None: | ||||
|             if self.action != "": | ||||
|                 element = f"--{self.action}--> {element}" | ||||
|             else: | ||||
|                 element = f"--> {element}" | ||||
|         if self.source: | ||||
|             source_element = [] | ||||
|             for source in self.source: | ||||
|                 source_element.append(f"{source.identifier} {element}") | ||||
|             return "\n".join(source_element) | ||||
|         return element | ||||
|  | ||||
|  | ||||
| class FlowDiagramSerializer(PassiveSerializer): | ||||
|     """response of the flow's diagram action""" | ||||
|  | ||||
|     diagram = CharField(read_only=True) | ||||
|  | ||||
|  | ||||
| class FlowDiagram: | ||||
|     """Generate flow chart fow a flow""" | ||||
|  | ||||
|     flow: Flow | ||||
|     user: User | ||||
|  | ||||
|     def __init__(self, flow: Flow, user: User) -> None: | ||||
|         self.flow = flow | ||||
|         self.user = user | ||||
|  | ||||
|     def get_flow_policies(self, parent_elements: list[DiagramElement]) -> list[DiagramElement]: | ||||
|         """Collect all policies bound to the flow""" | ||||
|         elements = [] | ||||
|         for p_index, policy_binding in enumerate( | ||||
|             get_objects_for_user(self.user, "authentik_policies.view_policybinding") | ||||
|             .filter(target=self.flow) | ||||
|             .exclude(policy__isnull=True) | ||||
|             .order_by("order") | ||||
|         ): | ||||
|             element = DiagramElement( | ||||
|                 f"flow_policy_{p_index}", | ||||
|                 _("Policy (%(type)s)" % {"type": policy_binding.policy._meta.verbose_name}) | ||||
|                 + "\n" | ||||
|                 + policy_binding.policy.name, | ||||
|                 _("Binding %(order)d" % {"order": policy_binding.order}), | ||||
|                 parent_elements, | ||||
|                 style=["{{", "}}"], | ||||
|             ) | ||||
|             elements.append(element) | ||||
|         return elements | ||||
|  | ||||
|     def get_stage_policies( | ||||
|         self, | ||||
|         stage_index: int, | ||||
|         stage_binding: FlowStageBinding, | ||||
|         parent_elements: list[DiagramElement], | ||||
|     ) -> list[DiagramElement]: | ||||
|         """First all policies bound to stages since they execute before stages""" | ||||
|         elements = [] | ||||
|         for p_index, policy_binding in enumerate( | ||||
|             get_objects_for_user(self.user, "authentik_policies.view_policybinding") | ||||
|             .filter(target=stage_binding) | ||||
|             .exclude(policy__isnull=True) | ||||
|             .order_by("order") | ||||
|         ): | ||||
|             element = DiagramElement( | ||||
|                 f"stage_{stage_index}_policy_{p_index}", | ||||
|                 _("Policy (%(type)s)" % {"type": policy_binding.policy._meta.verbose_name}) | ||||
|                 + "\n" | ||||
|                 + policy_binding.policy.name, | ||||
|                 "", | ||||
|                 parent_elements, | ||||
|                 style=["{{", "}}"], | ||||
|             ) | ||||
|             elements.append(element) | ||||
|         return elements | ||||
|  | ||||
|     def get_stages(self, parent_elements: list[DiagramElement]) -> list[str | DiagramElement]: | ||||
|         """Collect all stages""" | ||||
|         elements = [] | ||||
|         stages = [] | ||||
|         for s_index, stage_binding in enumerate( | ||||
|             get_objects_for_user(self.user, "authentik_flows.view_flowstagebinding") | ||||
|             .filter(target=self.flow) | ||||
|             .order_by("order") | ||||
|         ): | ||||
|             stage_policies = self.get_stage_policies(s_index, stage_binding, parent_elements) | ||||
|             elements.extend(stage_policies) | ||||
|  | ||||
|             action = "" | ||||
|             if len(stage_policies) > 0: | ||||
|                 action = _("Policy passed") | ||||
|  | ||||
|             element = DiagramElement( | ||||
|                 f"stage_{s_index}", | ||||
|                 _("Stage (%(type)s)" % {"type": stage_binding.stage._meta.verbose_name}) | ||||
|                 + "\n" | ||||
|                 + stage_binding.stage.name, | ||||
|                 action, | ||||
|                 stage_policies, | ||||
|                 style=["([", "])"], | ||||
|             ) | ||||
|             stages.append(element) | ||||
|  | ||||
|             parent_elements = [element] | ||||
|  | ||||
|             # This adds connections for policy denies, but retroactively, as we can't really | ||||
|             # look ahead | ||||
|             # Check if we have a stage behind us and if it has any sources | ||||
|             if s_index > 0: | ||||
|                 last_stage: DiagramElement = stages[s_index - 1] | ||||
|                 if last_stage.source and len(last_stage.source) > 0: | ||||
|                     # If it has any sources, add a connection from each of that stage's sources | ||||
|                     # to this stage | ||||
|                     for source in last_stage.source: | ||||
|                         elements.append( | ||||
|                             DiagramElement( | ||||
|                                 element.identifier, | ||||
|                                 element.description, | ||||
|                                 _("Policy denied"), | ||||
|                                 [source], | ||||
|                                 style=element.style, | ||||
|                             ) | ||||
|                         ) | ||||
|  | ||||
|         if len(stages) > 0: | ||||
|             elements.append( | ||||
|                 DiagramElement( | ||||
|                     "done", | ||||
|                     _("End of the flow"), | ||||
|                     "", | ||||
|                     [stages[-1]], | ||||
|                     style=["[[", "]]"], | ||||
|                 ), | ||||
|             ) | ||||
|         return stages + elements | ||||
|  | ||||
|     def build(self) -> str: | ||||
|         """Build flowchart""" | ||||
|         all_elements = [ | ||||
|             "graph TD", | ||||
|         ] | ||||
|  | ||||
|         pre_flow_policies_element = DiagramElement( | ||||
|             "flow_pre", _("Pre-flow policies"), style=["[[", "]]"] | ||||
|         ) | ||||
|         flow_policies = self.get_flow_policies([pre_flow_policies_element]) | ||||
|         if len(flow_policies) > 0: | ||||
|             all_elements.append(pre_flow_policies_element) | ||||
|             all_elements.extend(flow_policies) | ||||
|             all_elements.append( | ||||
|                 DiagramElement( | ||||
|                     "done", | ||||
|                     _("End of the flow"), | ||||
|                     _("Policy denied"), | ||||
|                     flow_policies, | ||||
|                 ) | ||||
|             ) | ||||
|  | ||||
|         flow_element = DiagramElement( | ||||
|             "flow_start", | ||||
|             _("Flow") + "\n" + self.flow.name, | ||||
|             "" if len(flow_policies) > 0 else None, | ||||
|             source=flow_policies, | ||||
|             style=["[[", "]]"], | ||||
|         ) | ||||
|         all_elements.append(flow_element) | ||||
|  | ||||
|         stages = self.get_stages([flow_element]) | ||||
|         all_elements.extend(stages) | ||||
|         if len(stages) < 1: | ||||
|             all_elements.append( | ||||
|                 DiagramElement( | ||||
|                     "done", | ||||
|                     _("End of the flow"), | ||||
|                     "", | ||||
|                     [flow_element], | ||||
|                     style=["[[", "]]"], | ||||
|                 ), | ||||
|             ) | ||||
|         return "\n".join([str(x) for x in all_elements]) | ||||
| @ -1,7 +1,7 @@ | ||||
| """authentik flows app config""" | ||||
| from prometheus_client import Gauge, Histogram | ||||
|  | ||||
| from authentik.blueprints.manager import ManagedAppConfig | ||||
| from authentik.blueprints.apps import ManagedAppConfig | ||||
| from authentik.lib.utils.reflection import all_subclasses | ||||
|  | ||||
| GAUGE_FLOWS_CACHED = Gauge( | ||||
| @ -28,7 +28,7 @@ class AuthentikFlowsConfig(ManagedAppConfig): | ||||
|         """Load flows signals""" | ||||
|         self.import_module("authentik.flows.signals") | ||||
|  | ||||
|     def reconcile_stages_loaded(self): | ||||
|     def reconcile_load_stages(self): | ||||
|         """Ensure all stages are loaded""" | ||||
|         from authentik.flows.models import Stage | ||||
|  | ||||
|  | ||||
| @ -1,14 +1,14 @@ | ||||
| """Challenge helpers""" | ||||
| from dataclasses import asdict, is_dataclass | ||||
| from enum import Enum | ||||
| from traceback import format_tb | ||||
| from typing import TYPE_CHECKING, Optional, TypedDict | ||||
| from uuid import UUID | ||||
|  | ||||
| from django.core.serializers.json import DjangoJSONEncoder | ||||
| from django.db import models | ||||
| from django.http import JsonResponse | ||||
| from rest_framework.fields import ChoiceField, DictField | ||||
| from rest_framework.serializers import CharField | ||||
| from rest_framework.fields import CharField, ChoiceField, DictField | ||||
|  | ||||
| from authentik.core.api.utils import PassiveSerializer | ||||
|  | ||||
| @ -90,6 +90,34 @@ class WithUserInfoChallenge(Challenge): | ||||
|     pending_user_avatar = CharField() | ||||
|  | ||||
|  | ||||
| class FlowErrorChallenge(WithUserInfoChallenge): | ||||
|     """Challenge class when an unhandled error occurs during a stage. Normal users | ||||
|     are shown an error message, superusers are shown a full stacktrace.""" | ||||
|  | ||||
|     component = CharField(default="xak-flow-error") | ||||
|  | ||||
|     request_id = CharField() | ||||
|  | ||||
|     error = CharField(required=False) | ||||
|     traceback = CharField(required=False) | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         request = kwargs.pop("request", None) | ||||
|         error = kwargs.pop("error", None) | ||||
|         super().__init__(*args, **kwargs) | ||||
|         if not request or not error: | ||||
|             return | ||||
|         self.request_id = request.request_id | ||||
|         from authentik.core.models import USER_ATTRIBUTE_DEBUG | ||||
|  | ||||
|         if request.user and request.user.is_authenticated: | ||||
|             if request.user.is_superuser or request.user.group_attributes(request).get( | ||||
|                 USER_ATTRIBUTE_DEBUG, False | ||||
|             ): | ||||
|                 self.error = error | ||||
|                 self.traceback = "".join(format_tb(self.error.__traceback__)) | ||||
|  | ||||
|  | ||||
| class AccessDeniedChallenge(WithUserInfoChallenge): | ||||
|     """Challenge when a flow's active stage calls `stage_invalid()`.""" | ||||
|  | ||||
|  | ||||
| @ -32,7 +32,7 @@ class FlowPlanProcess(PROCESS_CLASS):  # pragma: no cover | ||||
|  | ||||
|     def run(self): | ||||
|         """Execute 1000 flow plans""" | ||||
|         print(f"Proc {self.index} Running") | ||||
|         LOGGER.info(f"Proc {self.index} Running") | ||||
|  | ||||
|         def test_inner(): | ||||
|             planner = FlowPlanner(self.flow) | ||||
|  | ||||
| @ -19,25 +19,6 @@ def update_flow_designation(apps: Apps, schema_editor: BaseDatabaseSchemaEditor) | ||||
|             flow.save() | ||||
|  | ||||
|  | ||||
| # First stage for default-source-enrollment flow (prompt stage) | ||||
| # needs to have its policy re-evaluated | ||||
| def update_default_source_enrollment_flow_binding( | ||||
|     apps: Apps, schema_editor: BaseDatabaseSchemaEditor | ||||
| ): | ||||
|     Flow = apps.get_model("authentik_flows", "Flow") | ||||
|     FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding") | ||||
|     db_alias = schema_editor.connection.alias | ||||
|  | ||||
|     flows = Flow.objects.using(db_alias).filter(slug="default-source-enrollment") | ||||
|     if not flows.exists(): | ||||
|         return | ||||
|     flow = flows.first() | ||||
|  | ||||
|     binding = FlowStageBinding.objects.get(target=flow, order=0) | ||||
|     binding.re_evaluate_policies = True | ||||
|     binding.save() | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     replaces = [ | ||||
| @ -101,9 +82,6 @@ class Migration(migrations.Migration): | ||||
|                 help_text="When this option is enabled, the planner will re-evaluate policies bound to this binding.", | ||||
|             ), | ||||
|         ), | ||||
|         migrations.RunPython( | ||||
|             code=update_default_source_enrollment_flow_binding, | ||||
|         ), | ||||
|         migrations.AlterField( | ||||
|             model_name="flowstagebinding", | ||||
|             name="re_evaluate_policies", | ||||
|  | ||||
| @ -1,22 +0,0 @@ | ||||
| {% load i18n %} | ||||
|  | ||||
| <style> | ||||
|     .ak-exception { | ||||
|         font-family: monospace; | ||||
|         overflow-x: scroll; | ||||
|     } | ||||
| </style> | ||||
|  | ||||
| <header class="pf-c-login__main-header"> | ||||
|     <h1 class="pf-c-title pf-m-3xl"> | ||||
|         {% trans 'Whoops!' %} | ||||
|     </h1> | ||||
| </header> | ||||
| <div class="pf-c-login__main-body"> | ||||
|     <h3> | ||||
|         {% trans 'Something went wrong! Please try again later.' %} | ||||
|     </h3> | ||||
|     {% if debug %} | ||||
|     <pre class="ak-exception">{{ tb }}{{ error }}</pre> | ||||
|     {% endif %} | ||||
| </div> | ||||
| @ -9,22 +9,20 @@ from authentik.policies.dummy.models import DummyPolicy | ||||
| from authentik.policies.models import PolicyBinding | ||||
| from authentik.stages.dummy.models import DummyStage | ||||
|  | ||||
| DIAGRAM_EXPECTED = """st=>start: Start | ||||
| stage_0=>operation: Stage (Dummy Stage) | ||||
| dummy1 | ||||
| stage_1_policy_0=>condition: Policy (Dummy Policy) | ||||
| test | ||||
| stage_1=>operation: Stage (Dummy Stage) | ||||
| dummy2 | ||||
| e=>end: End|future | ||||
| st(right)->stage_0 | ||||
| stage_0(bottom)->stage_1_policy_0 | ||||
| stage_1_policy_0(yes, right)->stage_1 | ||||
| stage_1_policy_0(no, bottom)->e | ||||
| stage_1(bottom)->e""" | ||||
| DIAGRAM_SHORT_EXPECTED = """st=>start: Start | ||||
| e=>end: End|future | ||||
| st(right)->e""" | ||||
| DIAGRAM_EXPECTED = """graph TD | ||||
| flow_start[["Flow | ||||
| test-default-context"]] | ||||
| --> stage_0(["Stage (Dummy Stage) | ||||
| dummy1"]) | ||||
| stage_1_policy_0 --Policy passed--> stage_1(["Stage (Dummy Stage) | ||||
| dummy2"]) | ||||
| stage_0 --> stage_1_policy_0{{"Policy (Dummy Policy) | ||||
| dummy2-policy"}} | ||||
| stage_1 --> done[["End of the flow"]]""" | ||||
| DIAGRAM_SHORT_EXPECTED = """graph TD | ||||
| flow_start[["Flow | ||||
| test-default-context"]] | ||||
| flow_start --> done[["End of the flow"]]""" | ||||
|  | ||||
|  | ||||
| class TestFlowsAPI(APITestCase): | ||||
| @ -55,7 +53,9 @@ class TestFlowsAPI(APITestCase): | ||||
|             slug="test-default-context", | ||||
|             designation=FlowDesignation.AUTHENTICATION, | ||||
|         ) | ||||
|         false_policy = DummyPolicy.objects.create(name="test", result=False, wait_min=1, wait_max=2) | ||||
|         false_policy = DummyPolicy.objects.create( | ||||
|             name="dummy2-policy", result=False, wait_min=1, wait_max=2 | ||||
|         ) | ||||
|  | ||||
|         FlowStageBinding.objects.create( | ||||
|             target=flow, stage=DummyStage.objects.create(name="dummy1"), order=0 | ||||
|  | ||||
| @ -6,10 +6,9 @@ from django.test.client import RequestFactory | ||||
| from django.urls.base import reverse | ||||
| from rest_framework.test import APITestCase | ||||
|  | ||||
| from authentik.core.tests.utils import create_test_admin_user | ||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_flow | ||||
| from authentik.flows.challenge import ChallengeTypes | ||||
| from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding, InvalidResponseAction | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.flows.models import FlowDesignation, FlowStageBinding, InvalidResponseAction | ||||
| from authentik.stages.dummy.models import DummyStage | ||||
| from authentik.stages.identification.models import IdentificationStage, UserFields | ||||
|  | ||||
| @ -24,11 +23,7 @@ class TestFlowInspector(APITestCase): | ||||
|  | ||||
|     def test(self): | ||||
|         """test inspector""" | ||||
|         flow = Flow.objects.create( | ||||
|             name=generate_id(), | ||||
|             slug=generate_id(), | ||||
|             designation=FlowDesignation.AUTHENTICATION, | ||||
|         ) | ||||
|         flow = create_test_flow(FlowDesignation.AUTHENTICATION) | ||||
|  | ||||
|         # Stage 1 is an identification stage | ||||
|         ident_stage = IdentificationStage.objects.create( | ||||
| @ -55,7 +50,7 @@ class TestFlowInspector(APITestCase): | ||||
|                 "flow_info": { | ||||
|                     "background": flow.background_url, | ||||
|                     "cancel_url": reverse("authentik_flows:cancel"), | ||||
|                     "title": "", | ||||
|                     "title": flow.title, | ||||
|                     "layout": "stacked", | ||||
|                 }, | ||||
|                 "type": ChallengeTypes.NATIVE.value, | ||||
|  | ||||
| @ -8,9 +8,10 @@ from django.urls import reverse | ||||
| from guardian.shortcuts import get_anonymous_user | ||||
|  | ||||
| from authentik.core.models import User | ||||
| from authentik.core.tests.utils import create_test_flow | ||||
| from authentik.flows.exceptions import EmptyFlowException, FlowNonApplicableException | ||||
| from authentik.flows.markers import ReevaluateMarker, StageMarker | ||||
| from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding | ||||
| from authentik.flows.models import FlowDesignation, FlowStageBinding | ||||
| from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner, cache_key | ||||
| from authentik.lib.tests.utils import dummy_get_response | ||||
| from authentik.policies.dummy.models import DummyPolicy | ||||
| @ -32,11 +33,7 @@ class TestFlowPlanner(TestCase): | ||||
|  | ||||
|     def test_empty_plan(self): | ||||
|         """Test that empty plan raises exception""" | ||||
|         flow = Flow.objects.create( | ||||
|             name="test-empty", | ||||
|             slug="test-empty", | ||||
|             designation=FlowDesignation.AUTHENTICATION, | ||||
|         ) | ||||
|         flow = create_test_flow() | ||||
|         request = self.request_factory.get( | ||||
|             reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), | ||||
|         ) | ||||
| @ -52,11 +49,7 @@ class TestFlowPlanner(TestCase): | ||||
|     ) | ||||
|     def test_non_applicable_plan(self): | ||||
|         """Test that empty plan raises exception""" | ||||
|         flow = Flow.objects.create( | ||||
|             name="test-empty", | ||||
|             slug="test-empty", | ||||
|             designation=FlowDesignation.AUTHENTICATION, | ||||
|         ) | ||||
|         flow = create_test_flow() | ||||
|         request = self.request_factory.get( | ||||
|             reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), | ||||
|         ) | ||||
| @ -69,11 +62,7 @@ class TestFlowPlanner(TestCase): | ||||
|     @patch("authentik.flows.planner.cache", CACHE_MOCK) | ||||
|     def test_planner_cache(self): | ||||
|         """Test planner cache""" | ||||
|         flow = Flow.objects.create( | ||||
|             name="test-cache", | ||||
|             slug="test-cache", | ||||
|             designation=FlowDesignation.AUTHENTICATION, | ||||
|         ) | ||||
|         flow = create_test_flow(FlowDesignation.AUTHENTICATION) | ||||
|         FlowStageBinding.objects.create( | ||||
|             target=flow, stage=DummyStage.objects.create(name="dummy"), order=0 | ||||
|         ) | ||||
| @ -92,11 +81,7 @@ class TestFlowPlanner(TestCase): | ||||
|  | ||||
|     def test_planner_default_context(self): | ||||
|         """Test planner with default_context""" | ||||
|         flow = Flow.objects.create( | ||||
|             name="test-default-context", | ||||
|             slug="test-default-context", | ||||
|             designation=FlowDesignation.AUTHENTICATION, | ||||
|         ) | ||||
|         flow = create_test_flow() | ||||
|         FlowStageBinding.objects.create( | ||||
|             target=flow, stage=DummyStage.objects.create(name="dummy"), order=0 | ||||
|         ) | ||||
| @ -113,11 +98,7 @@ class TestFlowPlanner(TestCase): | ||||
|  | ||||
|     def test_planner_marker_reevaluate(self): | ||||
|         """Test that the planner creates the proper marker""" | ||||
|         flow = Flow.objects.create( | ||||
|             name="test-default-context", | ||||
|             slug="test-default-context", | ||||
|             designation=FlowDesignation.AUTHENTICATION, | ||||
|         ) | ||||
|         flow = create_test_flow() | ||||
|  | ||||
|         FlowStageBinding.objects.create( | ||||
|             target=flow, | ||||
| @ -138,11 +119,7 @@ class TestFlowPlanner(TestCase): | ||||
|  | ||||
|     def test_planner_reevaluate_actual(self): | ||||
|         """Test planner with re-evaluate""" | ||||
|         flow = Flow.objects.create( | ||||
|             name="test-default-context", | ||||
|             slug="test-default-context", | ||||
|             designation=FlowDesignation.AUTHENTICATION, | ||||
|         ) | ||||
|         flow = create_test_flow() | ||||
|         false_policy = DummyPolicy.objects.create(result=False, wait_min=1, wait_max=2) | ||||
|  | ||||
|         binding = FlowStageBinding.objects.create( | ||||
|  | ||||
| @ -1,7 +1,6 @@ | ||||
| """authentik multi-stage authentication engine""" | ||||
| from copy import deepcopy | ||||
| from traceback import format_tb | ||||
| from typing import Any, Optional | ||||
| from typing import Optional | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.contrib.auth.mixins import LoginRequiredMixin | ||||
| @ -23,12 +22,12 @@ from sentry_sdk.api import set_tag | ||||
| from sentry_sdk.hub import Hub | ||||
| from structlog.stdlib import BoundLogger, get_logger | ||||
|  | ||||
| from authentik.core.models import USER_ATTRIBUTE_DEBUG | ||||
| from authentik.events.models import Event, EventAction, cleanse_dict | ||||
| from authentik.flows.challenge import ( | ||||
|     Challenge, | ||||
|     ChallengeResponse, | ||||
|     ChallengeTypes, | ||||
|     FlowErrorChallenge, | ||||
|     HttpChallengeResponse, | ||||
|     RedirectChallenge, | ||||
|     ShellChallenge, | ||||
| @ -153,6 +152,7 @@ class FlowExecutorView(APIView): | ||||
|         token: Optional[FlowToken] = FlowToken.filter_not_expired(key=key).first() | ||||
|         if not token: | ||||
|             return None | ||||
|         plan = None | ||||
|         try: | ||||
|             plan = token.plan | ||||
|         except (AttributeError, EOFError, ImportError, IndexError) as exc: | ||||
| @ -253,7 +253,9 @@ class FlowExecutorView(APIView): | ||||
|             action=EventAction.SYSTEM_EXCEPTION, | ||||
|             message=exception_to_string(exc), | ||||
|         ).from_http(self.request) | ||||
|         return to_stage_response(self.request, FlowErrorResponse(self.request, exc)) | ||||
|         return to_stage_response( | ||||
|             self.request, HttpChallengeResponse(FlowErrorChallenge(self.request, exc)) | ||||
|         ) | ||||
|  | ||||
|     @extend_schema( | ||||
|         responses={ | ||||
| @ -440,30 +442,6 @@ class FlowExecutorView(APIView): | ||||
|                 del self.request.session[key] | ||||
|  | ||||
|  | ||||
| class FlowErrorResponse(TemplateResponse): | ||||
|     """Response class when an unhandled error occurs during a stage. Normal users | ||||
|     are shown an error message, superusers are shown a full stacktrace.""" | ||||
|  | ||||
|     error: Exception | ||||
|  | ||||
|     def __init__(self, request: HttpRequest, error: Exception) -> None: | ||||
|         # For some reason pyright complains about keyword argument usage here | ||||
|         # pyright: reportGeneralTypeIssues=false | ||||
|         super().__init__(request=request, template="flows/error.html") | ||||
|         self.error = error | ||||
|  | ||||
|     def resolve_context(self, context: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]: | ||||
|         if not context: | ||||
|             context = {} | ||||
|         context["error"] = self.error | ||||
|         if self._request.user and self._request.user.is_authenticated: | ||||
|             if self._request.user.is_superuser or self._request.user.group_attributes( | ||||
|                 self._request | ||||
|             ).get(USER_ATTRIBUTE_DEBUG, False): | ||||
|                 context["tb"] = "".join(format_tb(self.error.__traceback__)) | ||||
|         return context | ||||
|  | ||||
|  | ||||
| class CancelView(View): | ||||
|     """View which canels the currently active plan""" | ||||
|  | ||||
|  | ||||
| @ -20,7 +20,7 @@ ENV_PREFIX = "AUTHENTIK" | ||||
| ENVIRONMENT = os.getenv(f"{ENV_PREFIX}_ENV", "local") | ||||
|  | ||||
|  | ||||
| def get_path_from_dict(root: dict, path: str, sep=".", default=None): | ||||
| def get_path_from_dict(root: dict, path: str, sep=".", default=None) -> Any: | ||||
|     """Recursively walk through `root`, checking each part of `path` split by `sep`. | ||||
|     If at any point a dict does not exist, return default""" | ||||
|     for comp in path.split(sep): | ||||
| @ -62,7 +62,7 @@ class ConfigLoader: | ||||
|                         self.update_from_file(env_file) | ||||
|         self.update_from_env() | ||||
|  | ||||
|     def _log(self, level: str, message: str, **kwargs): | ||||
|     def log(self, level: str, message: str, **kwargs): | ||||
|         """Custom Log method, we want to ensure ConfigLoader always logs JSON even when | ||||
|         'structlog' or 'logging' hasn't been configured yet.""" | ||||
|         output = { | ||||
| @ -95,7 +95,7 @@ class ConfigLoader: | ||||
|                 with open(url.path, "r", encoding="utf8") as _file: | ||||
|                     value = _file.read() | ||||
|             except OSError as exc: | ||||
|                 self._log("error", f"Failed to read config value from {url.path}: {exc}") | ||||
|                 self.log("error", f"Failed to read config value from {url.path}: {exc}") | ||||
|                 value = url.query | ||||
|         return value | ||||
|  | ||||
| @ -105,12 +105,12 @@ class ConfigLoader: | ||||
|             with open(path, encoding="utf8") as file: | ||||
|                 try: | ||||
|                     self.update(self.__config, yaml.safe_load(file)) | ||||
|                     self._log("debug", "Loaded config", file=path) | ||||
|                     self.log("debug", "Loaded config", file=path) | ||||
|                     self.loaded_file.append(path) | ||||
|                 except yaml.YAMLError as exc: | ||||
|                     raise ImproperlyConfigured from exc | ||||
|         except PermissionError as exc: | ||||
|             self._log( | ||||
|             self.log( | ||||
|                 "warning", | ||||
|                 "Permission denied while reading file", | ||||
|                 path=path, | ||||
| @ -144,7 +144,7 @@ class ConfigLoader: | ||||
|             current_obj[dot_parts[-1]] = value | ||||
|             idx += 1 | ||||
|         if idx > 0: | ||||
|             self._log("debug", "Loaded environment variables", count=idx) | ||||
|             self.log("debug", "Loaded environment variables", count=idx) | ||||
|             self.update(self.__config, outer) | ||||
|  | ||||
|     @contextmanager | ||||
| @ -152,8 +152,10 @@ class ConfigLoader: | ||||
|         """Context manager for unittests to patch a value""" | ||||
|         original_value = self.y(path) | ||||
|         self.y_set(path, value) | ||||
|         yield | ||||
|         self.y_set(path, original_value) | ||||
|         try: | ||||
|             yield | ||||
|         finally: | ||||
|             self.y_set(path, original_value) | ||||
|  | ||||
|     @property | ||||
|     def raw(self) -> dict: | ||||
| @ -178,7 +180,7 @@ class ConfigLoader: | ||||
|             # pyright: reportGeneralTypeIssues=false | ||||
|             if comp not in root: | ||||
|                 root[comp] = {} | ||||
|             root = root.get(comp) | ||||
|             root = root.get(comp, {}) | ||||
|         root[path_parts[-1]] = value | ||||
|  | ||||
|     def y_bool(self, path: str, default=False) -> bool: | ||||
|  | ||||
| @ -36,7 +36,7 @@ error_reporting: | ||||
|   enabled: false | ||||
|   environment: customer | ||||
|   send_pii: false | ||||
|   sample_rate: 0.3 | ||||
|   sample_rate: 0.1 | ||||
|  | ||||
| # Global email settings | ||||
| email: | ||||
| @ -80,3 +80,8 @@ default_token_length: 128 | ||||
| impersonation: true | ||||
|  | ||||
| blueprints_dir: /blueprints | ||||
|  | ||||
| web: | ||||
|   # No default here as it's set dynamically | ||||
|   # workers: 2 | ||||
|   threads: 4 | ||||
|  | ||||
| @ -1,16 +1,20 @@ | ||||
| """authentik expression policy evaluator""" | ||||
| import re | ||||
| from ipaddress import ip_address, ip_network | ||||
| from textwrap import indent | ||||
| from typing import Any, Iterable, Optional | ||||
|  | ||||
| from django.core.exceptions import FieldError | ||||
| from django_otp import devices_for_user | ||||
| from rest_framework.serializers import ValidationError | ||||
| from sentry_sdk.hub import Hub | ||||
| from sentry_sdk.tracing import Span | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.core.models import User | ||||
| from authentik.events.models import Event | ||||
| from authentik.lib.utils.http import get_http_session | ||||
| from authentik.policies.types import PolicyRequest | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
| @ -26,7 +30,8 @@ class BaseEvaluator: | ||||
|     # Filename used for exec | ||||
|     _filename: str | ||||
|  | ||||
|     def __init__(self): | ||||
|     def __init__(self, filename: Optional[str] = None): | ||||
|         self._filename = filename if filename else "BaseEvaluator" | ||||
|         # update website/docs/expressions/_objects.md | ||||
|         # update website/docs/expressions/_functions.md | ||||
|         self._globals = { | ||||
| @ -35,11 +40,14 @@ class BaseEvaluator: | ||||
|             "list_flatten": BaseEvaluator.expr_flatten, | ||||
|             "ak_is_group_member": BaseEvaluator.expr_is_group_member, | ||||
|             "ak_user_by": BaseEvaluator.expr_user_by, | ||||
|             "ak_logger": get_logger(), | ||||
|             "ak_user_has_authenticator": BaseEvaluator.expr_func_user_has_authenticator, | ||||
|             "ak_create_event": self.expr_event_create, | ||||
|             "ak_logger": get_logger(self._filename), | ||||
|             "requests": get_http_session(), | ||||
|             "ip_address": ip_address, | ||||
|             "ip_network": ip_network, | ||||
|         } | ||||
|         self._context = {} | ||||
|         self._filename = "BaseEvalautor" | ||||
|  | ||||
|     @staticmethod | ||||
|     def expr_flatten(value: list[Any] | Any) -> Optional[Any]: | ||||
| @ -60,6 +68,11 @@ class BaseEvaluator: | ||||
|         """Expression Filter to run re.sub""" | ||||
|         return re.sub(regex, repl, value) | ||||
|  | ||||
|     @staticmethod | ||||
|     def expr_is_group_member(user: User, **group_filters) -> bool: | ||||
|         """Check if `user` is member of group with name `group_name`""" | ||||
|         return user.ak_groups.filter(**group_filters).exists() | ||||
|  | ||||
|     @staticmethod | ||||
|     def expr_user_by(**filters) -> Optional[User]: | ||||
|         """Get user by filters""" | ||||
| @ -72,15 +85,37 @@ class BaseEvaluator: | ||||
|             return None | ||||
|  | ||||
|     @staticmethod | ||||
|     def expr_is_group_member(user: User, **group_filters) -> bool: | ||||
|         """Check if `user` is member of group with name `group_name`""" | ||||
|         return user.ak_groups.filter(**group_filters).exists() | ||||
|     def expr_func_user_has_authenticator(user: User, device_type: Optional[str] = None) -> bool: | ||||
|         """Check if a user has any authenticator devices, optionally matching *device_type*""" | ||||
|         user_devices = devices_for_user(user) | ||||
|         if device_type: | ||||
|             for device in user_devices: | ||||
|                 device_class = device.__class__.__name__.lower().replace("device", "") | ||||
|                 if device_class == device_type: | ||||
|                     return True | ||||
|             return False | ||||
|         return len(list(user_devices)) > 0 | ||||
|  | ||||
|     def expr_event_create(self, action: str, **kwargs): | ||||
|         """Create event with supplied data and try to extract as much relevant data | ||||
|         from the context""" | ||||
|         kwargs["context"] = self._context | ||||
|         event = Event.new( | ||||
|             action, | ||||
|             app=self._filename, | ||||
|             **kwargs, | ||||
|         ) | ||||
|         if "request" in self._context and isinstance(PolicyRequest, self._context["request"]): | ||||
|             policy_request: PolicyRequest = self._context["request"] | ||||
|             if policy_request.http_request: | ||||
|                 event.from_http(policy_request) | ||||
|                 return | ||||
|         event.save() | ||||
|  | ||||
|     def wrap_expression(self, expression: str, params: Iterable[str]) -> str: | ||||
|         """Wrap expression in a function, call it, and save the result as `result`""" | ||||
|         handler_signature = ",".join(params) | ||||
|         full_expression = "" | ||||
|         full_expression += "from ipaddress import ip_address, ip_network\n" | ||||
|         full_expression += f"def handler({handler_signature}):\n" | ||||
|         full_expression += indent(expression, "    ") | ||||
|         full_expression += f"\nresult = handler({handler_signature})" | ||||
|  | ||||
| @ -95,7 +95,7 @@ def traces_sampler(sampling_context: dict) -> float: | ||||
|     # Ignore all healthcheck routes | ||||
|     if path.startswith("/-/health") or path.startswith("/-/metrics"): | ||||
|         return 0 | ||||
|     return float(CONFIG.y("error_reporting.sample_rate", 0.5)) | ||||
|     return float(CONFIG.y("error_reporting.sample_rate", 0.1)) | ||||
|  | ||||
|  | ||||
| def before_send(event: dict, hint: dict) -> Optional[dict]: | ||||
|  | ||||
| @ -20,9 +20,8 @@ def model_tester_factory(test_model: type[Stage]) -> Callable: | ||||
|         try: | ||||
|             model_class = None | ||||
|             if test_model._meta.abstract: | ||||
|                 model_class = test_model.__bases__[0]() | ||||
|             else: | ||||
|                 model_class = test_model() | ||||
|                 return | ||||
|             model_class = test_model() | ||||
|             self.assertTrue(issubclass(model_class.serializer, BaseSerializer)) | ||||
|         except NotImplementedError: | ||||
|             pass | ||||
|  | ||||
| @ -12,5 +12,4 @@ class TestReflectionUtils(TestCase): | ||||
|  | ||||
|     def test_path_to_class(self): | ||||
|         """Test path_to_class""" | ||||
|         self.assertIsNone(path_to_class(None)) | ||||
|         self.assertEqual(path_to_class("datetime.datetime"), datetime) | ||||
|  | ||||
| @ -29,10 +29,8 @@ def class_to_path(cls: type) -> str: | ||||
|     return f"{cls.__module__}.{cls.__name__}" | ||||
|  | ||||
|  | ||||
| def path_to_class(path: str | None) -> type | None: | ||||
| def path_to_class(path: str = "") -> type: | ||||
|     """Import module and return class""" | ||||
|     if not path: | ||||
|         return None | ||||
|     parts = path.split(".") | ||||
|     package = ".".join(parts[:-1]) | ||||
|     _class = getattr(import_module(package), parts[-1]) | ||||
|  | ||||
| @ -8,7 +8,7 @@ def bad_request_message( | ||||
|     request: HttpRequest, | ||||
|     message: str, | ||||
|     title="Bad Request", | ||||
|     template="error/generic.html", | ||||
|     template="if/error.html", | ||||
| ) -> TemplateResponse: | ||||
|     """Return generic error page with message, with status code set to 400""" | ||||
|     return TemplateResponse( | ||||
|  | ||||
| @ -2,7 +2,7 @@ | ||||
| from prometheus_client import Gauge | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.blueprints.manager import ManagedAppConfig | ||||
| from authentik.blueprints.apps import ManagedAppConfig | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| @ -5,7 +5,7 @@ from enum import IntEnum | ||||
| from typing import Any, Optional | ||||
|  | ||||
| from channels.exceptions import DenyConnection | ||||
| from dacite import from_dict | ||||
| from dacite.core import from_dict | ||||
| from dacite.data import Data | ||||
| from guardian.shortcuts import get_objects_for_user | ||||
| from structlog.stdlib import BoundLogger, get_logger | ||||
|  | ||||
| @ -2,7 +2,7 @@ | ||||
| from dataclasses import asdict, dataclass, field | ||||
| from typing import TYPE_CHECKING | ||||
|  | ||||
| from dacite import from_dict | ||||
| from dacite.core import from_dict | ||||
| from kubernetes.client import ApiextensionsV1Api, CustomObjectsApi | ||||
|  | ||||
| from authentik.outposts.controllers.base import FIELD_MANAGER | ||||
|  | ||||
| @ -4,7 +4,7 @@ from datetime import datetime | ||||
| from typing import Iterable, Optional | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from dacite import from_dict | ||||
| from dacite.core import from_dict | ||||
| from django.contrib.auth.models import Permission | ||||
| from django.core.cache import cache | ||||
| from django.db import IntegrityError, models, transaction | ||||
| @ -74,7 +74,7 @@ class OutpostConfig: | ||||
|     kubernetes_ingress_secret_name: str = field(default="authentik-outpost-tls") | ||||
|     kubernetes_service_type: str = field(default="ClusterIP") | ||||
|     kubernetes_disabled_components: list[str] = field(default_factory=list) | ||||
|     kubernetes_image_pull_secrets: Optional[list[str]] = field(default_factory=list) | ||||
|     kubernetes_image_pull_secrets: list[str] = field(default_factory=list) | ||||
|  | ||||
|  | ||||
| class OutpostModel(Model): | ||||
|  | ||||
| @ -1,6 +1,5 @@ | ||||
| """outpost tasks""" | ||||
| from os import R_OK, access | ||||
| from os.path import expanduser | ||||
| from pathlib import Path | ||||
| from socket import gethostname | ||||
| from typing import Any, Optional | ||||
| @ -75,10 +74,14 @@ def outpost_service_connection_state(connection_pk: Any): | ||||
|     ) | ||||
|     if not connection: | ||||
|         return | ||||
|     cls = None | ||||
|     if isinstance(connection, DockerServiceConnection): | ||||
|         cls = DockerClient | ||||
|     if isinstance(connection, KubernetesServiceConnection): | ||||
|         cls = KubernetesClient | ||||
|     if not cls: | ||||
|         LOGGER.warning("No class found for service connection", connection=connection) | ||||
|         return | ||||
|     try: | ||||
|         with cls(connection) as client: | ||||
|             state = client.fetch_state() | ||||
| @ -240,25 +243,25 @@ def _outpost_single_update(outpost: Outpost, layer=None): | ||||
| def outpost_local_connection(): | ||||
|     """Checks the local environment and create Service connections.""" | ||||
|     if not CONFIG.y_bool("outposts.discover"): | ||||
|         LOGGER.debug("Outpost integration discovery is disabled") | ||||
|         LOGGER.info("Outpost integration discovery is disabled") | ||||
|         return | ||||
|     # Explicitly check against token filename, as that's | ||||
|     # only present when the integration is enabled | ||||
|     if Path(SERVICE_TOKEN_FILENAME).exists(): | ||||
|         LOGGER.debug("Detected in-cluster Kubernetes Config") | ||||
|         LOGGER.info("Detected in-cluster Kubernetes Config") | ||||
|         if not KubernetesServiceConnection.objects.filter(local=True).exists(): | ||||
|             LOGGER.debug("Created Service Connection for in-cluster") | ||||
|             KubernetesServiceConnection.objects.create( | ||||
|                 name="Local Kubernetes Cluster", local=True, kubeconfig={} | ||||
|             ) | ||||
|     # For development, check for the existence of a kubeconfig file | ||||
|     kubeconfig_path = expanduser(KUBE_CONFIG_DEFAULT_LOCATION) | ||||
|     if Path(kubeconfig_path).exists(): | ||||
|         LOGGER.debug("Detected kubeconfig") | ||||
|     kubeconfig_path = Path(KUBE_CONFIG_DEFAULT_LOCATION).expanduser() | ||||
|     if kubeconfig_path.exists(): | ||||
|         LOGGER.info("Detected kubeconfig") | ||||
|         kubeconfig_local_name = f"k8s-{gethostname()}" | ||||
|         if not KubernetesServiceConnection.objects.filter(name=kubeconfig_local_name).exists(): | ||||
|             LOGGER.debug("Creating kubeconfig Service Connection") | ||||
|             with open(kubeconfig_path, "r", encoding="utf8") as _kubeconfig: | ||||
|             with kubeconfig_path.open("r", encoding="utf8") as _kubeconfig: | ||||
|                 KubernetesServiceConnection.objects.create( | ||||
|                     name=kubeconfig_local_name, | ||||
|                     kubeconfig=yaml.safe_load(_kubeconfig), | ||||
| @ -266,7 +269,7 @@ def outpost_local_connection(): | ||||
|     unix_socket_path = urlparse(DEFAULT_UNIX_SOCKET).path | ||||
|     socket = Path(unix_socket_path) | ||||
|     if socket.exists() and access(socket, R_OK): | ||||
|         LOGGER.debug("Detected local docker socket") | ||||
|         LOGGER.info("Detected local docker socket") | ||||
|         if len(DockerServiceConnection.objects.filter(local=True)) == 0: | ||||
|             LOGGER.debug("Created Service Connection for docker") | ||||
|             DockerServiceConnection.objects.create( | ||||
|  | ||||
| @ -6,7 +6,7 @@ from channels.testing import WebsocketCommunicator | ||||
| from django.test import TransactionTestCase | ||||
|  | ||||
| from authentik import __version__ | ||||
| from authentik.flows.models import Flow, FlowDesignation | ||||
| from authentik.core.tests.utils import create_test_flow | ||||
| from authentik.outposts.channels import WebsocketMessage, WebsocketMessageInstruction | ||||
| from authentik.outposts.models import Outpost, OutpostType | ||||
| from authentik.providers.proxy.models import ProxyProvider | ||||
| @ -21,9 +21,7 @@ class TestOutpostWS(TransactionTestCase): | ||||
|             name="test", | ||||
|             internal_host="http://localhost", | ||||
|             external_host="http://localhost", | ||||
|             authorization_flow=Flow.objects.create( | ||||
|                 name="foo", slug="foo", designation=FlowDesignation.AUTHORIZATION | ||||
|             ), | ||||
|             authorization_flow=create_test_flow(), | ||||
|         ) | ||||
|         self.outpost: Outpost = Outpost.objects.create( | ||||
|             name="test", | ||||
|  | ||||
| @ -1,7 +1,7 @@ | ||||
| """authentik policies app config""" | ||||
| from prometheus_client import Gauge, Histogram | ||||
|  | ||||
| from authentik.blueprints.manager import ManagedAppConfig | ||||
| from authentik.blueprints.apps import ManagedAppConfig | ||||
|  | ||||
| GAUGE_POLICIES_CACHED = Gauge( | ||||
|     "authentik_policies_cached", | ||||
|  | ||||
| @ -1,12 +1,10 @@ | ||||
| """authentik expression policy evaluator""" | ||||
| from ipaddress import ip_address, ip_network | ||||
| from ipaddress import ip_address | ||||
| from typing import TYPE_CHECKING, Optional | ||||
|  | ||||
| from django.http import HttpRequest | ||||
| from django_otp import devices_for_user | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.core.models import User | ||||
| from authentik.flows.planner import PLAN_CONTEXT_SSO | ||||
| from authentik.lib.expression.evaluator import BaseEvaluator | ||||
| from authentik.lib.utils.http import get_client_ip | ||||
| @ -27,16 +25,14 @@ class PolicyEvaluator(BaseEvaluator): | ||||
|  | ||||
|     policy: Optional["ExpressionPolicy"] = None | ||||
|  | ||||
|     def __init__(self, policy_name: str): | ||||
|         super().__init__() | ||||
|     def __init__(self, policy_name: Optional[str] = None): | ||||
|         super().__init__(policy_name or "PolicyEvaluator") | ||||
|         self._messages = [] | ||||
|         self._context["ak_logger"] = get_logger(policy_name) | ||||
|         # update website/docs/expressions/_objects.md | ||||
|         # update website/docs/expressions/_functions.md | ||||
|         self._context["ak_message"] = self.expr_func_message | ||||
|         self._context["ak_user_has_authenticator"] = self.expr_func_user_has_authenticator | ||||
|         self._context["ak_call_policy"] = self.expr_func_call_policy | ||||
|         self._context["ip_address"] = ip_address | ||||
|         self._context["ip_network"] = ip_network | ||||
|         self._filename = policy_name or "PolicyEvaluator" | ||||
|  | ||||
|     def expr_func_message(self, message: str): | ||||
|         """Wrapper to append to messages list, which is returned with PolicyResult""" | ||||
| @ -52,19 +48,6 @@ class PolicyEvaluator(BaseEvaluator): | ||||
|         proc = PolicyProcess(PolicyBinding(policy=policy), request=req, connection=None) | ||||
|         return proc.profiling_wrapper() | ||||
|  | ||||
|     def expr_func_user_has_authenticator( | ||||
|         self, user: User, device_type: Optional[str] = None | ||||
|     ) -> bool: | ||||
|         """Check if a user has any authenticator devices, optionally matching *device_type*""" | ||||
|         user_devices = devices_for_user(user) | ||||
|         if device_type: | ||||
|             for device in user_devices: | ||||
|                 device_class = device.__class__.__name__.lower().replace("device", "") | ||||
|                 if device_class == device_type: | ||||
|                     return True | ||||
|             return False | ||||
|         return len(list(user_devices)) > 0 | ||||
|  | ||||
|     def set_policy_request(self, request: PolicyRequest): | ||||
|         """Update context based on policy request (if http request is given, update that too)""" | ||||
|         # update website/docs/expressions/_objects.md | ||||
|  | ||||
| @ -1,5 +1,5 @@ | ||||
| """Authentik reputation_policy app config""" | ||||
| from authentik.blueprints.manager import ManagedAppConfig | ||||
| from authentik.blueprints.apps import ManagedAppConfig | ||||
|  | ||||
|  | ||||
| class AuthentikPolicyReputationConfig(ManagedAppConfig): | ||||
|  | ||||
| @ -70,7 +70,6 @@ class PolicyAccessView(AccessMixin, View): | ||||
|         # Check if user is unauthenticated, so we pass the application | ||||
|         # for the identification stage | ||||
|         if not request.user.is_authenticated: | ||||
|             LOGGER.warning("user not authenticated") | ||||
|             return self.handle_no_permission() | ||||
|         # Check permissions | ||||
|         result = self.user_has_access() | ||||
|  | ||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	