Compare commits
	
		
			143 Commits
		
	
	
		
			version/20
			...
			version/20
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| ff2baf502b | |||
| 3b182ca223 | |||
| 8da8890a8e | |||
| 23023ec727 | |||
| 7d84a71a01 | |||
| 192001f193 | |||
| 63734682d2 | |||
| a0cd2d55f8 | |||
| a72c7adfc0 | |||
| e88e02ec85 | |||
| f7661c8bbd | |||
| 9add8479ca | |||
| 4c39e08dd4 | |||
| 44ce2ebece | |||
| f5a8859d00 | |||
| 9ef0e8bc5f | |||
| 60eeafd111 | |||
| 6f3d6efa22 | |||
| 8d3275817b | |||
| ca40d31dac | |||
| 438aac8879 | |||
| 2dfa6c2c82 | |||
| c11435780d | |||
| ee54328589 | |||
| 817d538b8f | |||
| 210775776f | |||
| 2a4ce75bc4 | |||
| b26111fb42 | |||
| e30103aa9f | |||
| dc9203789e | |||
| d70ce2776f | |||
| ad7d65e903 | |||
| 67d54c5209 | |||
| bb244b8338 | |||
| fa04883ac1 | |||
| 6739ded5a9 | |||
| 9a7e5d934e | |||
| 6dc6d19d2d | |||
| 36cbc44ed6 | |||
| 0c591a50e3 | |||
| 7ee655a318 | |||
| 8447e9b9c2 | |||
| 09f92e5bad | |||
| f9a419107a | |||
| 8f0572d11e | |||
| 7ebf793953 | |||
| 63783ee77b | |||
| eba339ba27 | |||
| 0adb5a79f6 | |||
| fa81adf254 | |||
| 558c7bba2a | |||
| 8cd1a42fb9 | |||
| 8cf0e78aa0 | |||
| 3f69a57013 | |||
| f7f12cab10 | |||
| cacaa378c8 | |||
| 33fe85eb96 | |||
| a9744cbf48 | |||
| b91d8a676c | |||
| f19cd1c003 | |||
| 65341cecd0 | |||
| c0cb891078 | |||
| fc1c1a849a | |||
| 5a81ae956f | |||
| 0cac034512 | |||
| 5666995a15 | |||
| 8d3059e4f3 | |||
| a90dc34494 | |||
| 2c6d82593e | |||
| 34bcc2df1a | |||
| c00f2907ea | |||
| b4d528a789 | |||
| d9172cb296 | |||
| bee36cde59 | |||
| d4e7d9d64a | |||
| 7b0265207a | |||
| 7c076579fd | |||
| 7171706d7f | |||
| 9cd46ecbeb | |||
| 5f09ba675d | |||
| 630b926e2a | |||
| 9c6be60ad9 | |||
| a0397fdcf4 | |||
| 59e13e8026 | |||
| 374b51e956 | |||
| 8faa1bf865 | |||
| fc75867218 | |||
| 6d94c2c925 | |||
| eb51dd1379 | |||
| 13a4559c37 | |||
| 4fcf7285d7 | |||
| 0ba9f25155 | |||
| 453c751c7f | |||
| d1eaaef254 | |||
| 3eb466ff4b | |||
| 9f2529c886 | |||
| fb25b28976 | |||
| 612163b82f | |||
| 3c43690a96 | |||
| dd74565c7b | |||
| fb69f67f47 | |||
| 18b48684eb | |||
| 098b0aef6e | |||
| 4ed8171130 | |||
| 335131affc | |||
| bba17a8a67 | |||
| 082df0ec51 | |||
| a03dde8a90 | |||
| 5f04a187ea | |||
| 2b68363452 | |||
| acf1ded1d4 | |||
| a286f999e2 | |||
| 4b6c1da51d | |||
| a81d5a3d41 | |||
| 4d17111233 | |||
| 64cb9812e0 | |||
| ed037b2e3a | |||
| d2be6a8e3a | |||
| a9667eb0f4 | |||
| 7f3988f3c9 | |||
| 4c095a6f2a | |||
| c10b5c3c8c | |||
| 9d920580a1 | |||
| 34ef4af799 | |||
| 5da47b69dd | |||
| 0e0dd2437b | |||
| e42386b150 | |||
| f21f81022e | |||
| e73a468921 | |||
| c0ac053380 | |||
| 4e670295d1 | |||
| 8d7d8d613c | |||
| 4d632a8679 | |||
| ef219198d4 | |||
| ada53362d5 | |||
| a6398f46da | |||
| 56babb2649 | |||
| d25a051eae | |||
| 4a9b788703 | |||
| d4ef321ac2 | |||
| 80c1dbdfbb | |||
| b0af062d74 | |||
| b4e75218f5 | 
@ -1,5 +1,5 @@
 | 
				
			|||||||
[bumpversion]
 | 
					[bumpversion]
 | 
				
			||||||
current_version = 2022.5.3
 | 
					current_version = 2022.6.2
 | 
				
			||||||
tag = True
 | 
					tag = True
 | 
				
			||||||
commit = True
 | 
					commit = True
 | 
				
			||||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)\-?(?P<release>.*)
 | 
					parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)\-?(?P<release>.*)
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										2
									
								
								.github/pull_request_template.md
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/pull_request_template.md
									
									
									
									
										vendored
									
									
								
							@ -1,7 +1,7 @@
 | 
				
			|||||||
<!--
 | 
					<!--
 | 
				
			||||||
👋 Hello there! Welcome.
 | 
					👋 Hello there! Welcome.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Please check the [Contributing guidelines](https://github.com/goauthentik/authentik/blob/master/CONTRIBUTING.md#how-can-i-contribute).
 | 
					Please check the [Contributing guidelines](https://github.com/goauthentik/authentik/blob/main/CONTRIBUTING.md#how-can-i-contribute).
 | 
				
			||||||
-->
 | 
					-->
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Details
 | 
					# Details
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										4
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							@ -3,14 +3,14 @@ name: authentik-ci-main
 | 
				
			|||||||
on:
 | 
					on:
 | 
				
			||||||
  push:
 | 
					  push:
 | 
				
			||||||
    branches:
 | 
					    branches:
 | 
				
			||||||
      - master
 | 
					      - main
 | 
				
			||||||
      - next
 | 
					      - next
 | 
				
			||||||
      - version-*
 | 
					      - version-*
 | 
				
			||||||
    paths-ignore:
 | 
					    paths-ignore:
 | 
				
			||||||
      - website
 | 
					      - website
 | 
				
			||||||
  pull_request:
 | 
					  pull_request:
 | 
				
			||||||
    branches:
 | 
					    branches:
 | 
				
			||||||
      - master
 | 
					      - main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
env:
 | 
					env:
 | 
				
			||||||
  POSTGRES_DB: authentik
 | 
					  POSTGRES_DB: authentik
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										6
									
								
								.github/workflows/ci-outpost.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								.github/workflows/ci-outpost.yml
									
									
									
									
										vendored
									
									
								
							@ -3,12 +3,12 @@ name: authentik-ci-outpost
 | 
				
			|||||||
on:
 | 
					on:
 | 
				
			||||||
  push:
 | 
					  push:
 | 
				
			||||||
    branches:
 | 
					    branches:
 | 
				
			||||||
      - master
 | 
					      - main
 | 
				
			||||||
      - next
 | 
					      - next
 | 
				
			||||||
      - version-*
 | 
					      - version-*
 | 
				
			||||||
  pull_request:
 | 
					  pull_request:
 | 
				
			||||||
    branches:
 | 
					    branches:
 | 
				
			||||||
      - master
 | 
					      - main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
jobs:
 | 
					jobs:
 | 
				
			||||||
  lint-golint:
 | 
					  lint-golint:
 | 
				
			||||||
@ -110,7 +110,7 @@ jobs:
 | 
				
			|||||||
      - uses: actions/setup-go@v3
 | 
					      - uses: actions/setup-go@v3
 | 
				
			||||||
        with:
 | 
					        with:
 | 
				
			||||||
          go-version: "^1.17"
 | 
					          go-version: "^1.17"
 | 
				
			||||||
      - uses: actions/setup-node@v3.2.0
 | 
					      - uses: actions/setup-node@v3.3.0
 | 
				
			||||||
        with:
 | 
					        with:
 | 
				
			||||||
          node-version: '16'
 | 
					          node-version: '16'
 | 
				
			||||||
          cache: 'npm'
 | 
					          cache: 'npm'
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										12
									
								
								.github/workflows/ci-web.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										12
									
								
								.github/workflows/ci-web.yml
									
									
									
									
										vendored
									
									
								
							@ -3,19 +3,19 @@ name: authentik-ci-web
 | 
				
			|||||||
on:
 | 
					on:
 | 
				
			||||||
  push:
 | 
					  push:
 | 
				
			||||||
    branches:
 | 
					    branches:
 | 
				
			||||||
      - master
 | 
					      - main
 | 
				
			||||||
      - next
 | 
					      - next
 | 
				
			||||||
      - version-*
 | 
					      - version-*
 | 
				
			||||||
  pull_request:
 | 
					  pull_request:
 | 
				
			||||||
    branches:
 | 
					    branches:
 | 
				
			||||||
      - master
 | 
					      - main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
jobs:
 | 
					jobs:
 | 
				
			||||||
  lint-eslint:
 | 
					  lint-eslint:
 | 
				
			||||||
    runs-on: ubuntu-latest
 | 
					    runs-on: ubuntu-latest
 | 
				
			||||||
    steps:
 | 
					    steps:
 | 
				
			||||||
      - uses: actions/checkout@v3
 | 
					      - uses: actions/checkout@v3
 | 
				
			||||||
      - uses: actions/setup-node@v3.2.0
 | 
					      - uses: actions/setup-node@v3.3.0
 | 
				
			||||||
        with:
 | 
					        with:
 | 
				
			||||||
          node-version: '16'
 | 
					          node-version: '16'
 | 
				
			||||||
          cache: 'npm'
 | 
					          cache: 'npm'
 | 
				
			||||||
@ -31,7 +31,7 @@ jobs:
 | 
				
			|||||||
    runs-on: ubuntu-latest
 | 
					    runs-on: ubuntu-latest
 | 
				
			||||||
    steps:
 | 
					    steps:
 | 
				
			||||||
      - uses: actions/checkout@v3
 | 
					      - uses: actions/checkout@v3
 | 
				
			||||||
      - uses: actions/setup-node@v3.2.0
 | 
					      - uses: actions/setup-node@v3.3.0
 | 
				
			||||||
        with:
 | 
					        with:
 | 
				
			||||||
          node-version: '16'
 | 
					          node-version: '16'
 | 
				
			||||||
          cache: 'npm'
 | 
					          cache: 'npm'
 | 
				
			||||||
@ -47,7 +47,7 @@ jobs:
 | 
				
			|||||||
    runs-on: ubuntu-latest
 | 
					    runs-on: ubuntu-latest
 | 
				
			||||||
    steps:
 | 
					    steps:
 | 
				
			||||||
      - uses: actions/checkout@v3
 | 
					      - uses: actions/checkout@v3
 | 
				
			||||||
      - uses: actions/setup-node@v3.2.0
 | 
					      - uses: actions/setup-node@v3.3.0
 | 
				
			||||||
        with:
 | 
					        with:
 | 
				
			||||||
          node-version: '16'
 | 
					          node-version: '16'
 | 
				
			||||||
          cache: 'npm'
 | 
					          cache: 'npm'
 | 
				
			||||||
@ -73,7 +73,7 @@ jobs:
 | 
				
			|||||||
    runs-on: ubuntu-latest
 | 
					    runs-on: ubuntu-latest
 | 
				
			||||||
    steps:
 | 
					    steps:
 | 
				
			||||||
      - uses: actions/checkout@v3
 | 
					      - uses: actions/checkout@v3
 | 
				
			||||||
      - uses: actions/setup-node@v3.2.0
 | 
					      - uses: actions/setup-node@v3.3.0
 | 
				
			||||||
        with:
 | 
					        with:
 | 
				
			||||||
          node-version: '16'
 | 
					          node-version: '16'
 | 
				
			||||||
          cache: 'npm'
 | 
					          cache: 'npm'
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										6
									
								
								.github/workflows/ci-website.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								.github/workflows/ci-website.yml
									
									
									
									
										vendored
									
									
								
							@ -3,19 +3,19 @@ name: authentik-ci-website
 | 
				
			|||||||
on:
 | 
					on:
 | 
				
			||||||
  push:
 | 
					  push:
 | 
				
			||||||
    branches:
 | 
					    branches:
 | 
				
			||||||
      - master
 | 
					      - main
 | 
				
			||||||
      - next
 | 
					      - next
 | 
				
			||||||
      - version-*
 | 
					      - version-*
 | 
				
			||||||
  pull_request:
 | 
					  pull_request:
 | 
				
			||||||
    branches:
 | 
					    branches:
 | 
				
			||||||
      - master
 | 
					      - main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
jobs:
 | 
					jobs:
 | 
				
			||||||
  lint-prettier:
 | 
					  lint-prettier:
 | 
				
			||||||
    runs-on: ubuntu-latest
 | 
					    runs-on: ubuntu-latest
 | 
				
			||||||
    steps:
 | 
					    steps:
 | 
				
			||||||
      - uses: actions/checkout@v3
 | 
					      - uses: actions/checkout@v3
 | 
				
			||||||
      - uses: actions/setup-node@v3.2.0
 | 
					      - uses: actions/setup-node@v3.3.0
 | 
				
			||||||
        with:
 | 
					        with:
 | 
				
			||||||
          node-version: '16'
 | 
					          node-version: '16'
 | 
				
			||||||
          cache: 'npm'
 | 
					          cache: 'npm'
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										4
									
								
								.github/workflows/codeql-analysis.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.github/workflows/codeql-analysis.yml
									
									
									
									
										vendored
									
									
								
							@ -2,10 +2,10 @@ name: "CodeQL"
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
on:
 | 
					on:
 | 
				
			||||||
  push:
 | 
					  push:
 | 
				
			||||||
    branches: [ master, '*', next, version* ]
 | 
					    branches: [ main, '*', next, version* ]
 | 
				
			||||||
  pull_request:
 | 
					  pull_request:
 | 
				
			||||||
    # The branches below must be a subset of the branches above
 | 
					    # The branches below must be a subset of the branches above
 | 
				
			||||||
    branches: [ master ]
 | 
					    branches: [ main ]
 | 
				
			||||||
  schedule:
 | 
					  schedule:
 | 
				
			||||||
    - cron: '30 6 * * 5'
 | 
					    - cron: '30 6 * * 5'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										2
									
								
								.github/workflows/ghcr-retention.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/ghcr-retention.yml
									
									
									
									
										vendored
									
									
								
							@ -19,4 +19,4 @@ jobs:
 | 
				
			|||||||
          org-name: goauthentik
 | 
					          org-name: goauthentik
 | 
				
			||||||
          untagged-only: false
 | 
					          untagged-only: false
 | 
				
			||||||
          token: ${{ secrets.GHCR_CLEANUP_TOKEN }}
 | 
					          token: ${{ secrets.GHCR_CLEANUP_TOKEN }}
 | 
				
			||||||
          skip-tags: gh-next,gh-master
 | 
					          skip-tags: gh-next,gh-main
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										12
									
								
								.github/workflows/release-publish.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										12
									
								
								.github/workflows/release-publish.yml
									
									
									
									
										vendored
									
									
								
							@ -30,9 +30,9 @@ jobs:
 | 
				
			|||||||
        with:
 | 
					        with:
 | 
				
			||||||
          push: ${{ github.event_name == 'release' }}
 | 
					          push: ${{ github.event_name == 'release' }}
 | 
				
			||||||
          tags: |
 | 
					          tags: |
 | 
				
			||||||
            beryju/authentik:2022.5.3,
 | 
					            beryju/authentik:2022.6.2,
 | 
				
			||||||
            beryju/authentik:latest,
 | 
					            beryju/authentik:latest,
 | 
				
			||||||
            ghcr.io/goauthentik/server:2022.5.3,
 | 
					            ghcr.io/goauthentik/server:2022.6.2,
 | 
				
			||||||
            ghcr.io/goauthentik/server:latest
 | 
					            ghcr.io/goauthentik/server:latest
 | 
				
			||||||
          platforms: linux/amd64,linux/arm64
 | 
					          platforms: linux/amd64,linux/arm64
 | 
				
			||||||
          context: .
 | 
					          context: .
 | 
				
			||||||
@ -69,9 +69,9 @@ jobs:
 | 
				
			|||||||
        with:
 | 
					        with:
 | 
				
			||||||
          push: ${{ github.event_name == 'release' }}
 | 
					          push: ${{ github.event_name == 'release' }}
 | 
				
			||||||
          tags: |
 | 
					          tags: |
 | 
				
			||||||
            beryju/authentik-${{ matrix.type }}:2022.5.3,
 | 
					            beryju/authentik-${{ matrix.type }}:2022.6.2,
 | 
				
			||||||
            beryju/authentik-${{ matrix.type }}:latest,
 | 
					            beryju/authentik-${{ matrix.type }}:latest,
 | 
				
			||||||
            ghcr.io/goauthentik/${{ matrix.type }}:2022.5.3,
 | 
					            ghcr.io/goauthentik/${{ matrix.type }}:2022.6.2,
 | 
				
			||||||
            ghcr.io/goauthentik/${{ matrix.type }}:latest
 | 
					            ghcr.io/goauthentik/${{ matrix.type }}:latest
 | 
				
			||||||
          file: ${{ matrix.type }}.Dockerfile
 | 
					          file: ${{ matrix.type }}.Dockerfile
 | 
				
			||||||
          platforms: linux/amd64,linux/arm64
 | 
					          platforms: linux/amd64,linux/arm64
 | 
				
			||||||
@ -91,7 +91,7 @@ jobs:
 | 
				
			|||||||
      - uses: actions/setup-go@v3
 | 
					      - uses: actions/setup-go@v3
 | 
				
			||||||
        with:
 | 
					        with:
 | 
				
			||||||
          go-version: "^1.17"
 | 
					          go-version: "^1.17"
 | 
				
			||||||
      - uses: actions/setup-node@v3.2.0
 | 
					      - uses: actions/setup-node@v3.3.0
 | 
				
			||||||
        with:
 | 
					        with:
 | 
				
			||||||
          node-version: '16'
 | 
					          node-version: '16'
 | 
				
			||||||
          cache: 'npm'
 | 
					          cache: 'npm'
 | 
				
			||||||
@ -152,7 +152,7 @@ jobs:
 | 
				
			|||||||
          SENTRY_PROJECT: authentik
 | 
					          SENTRY_PROJECT: authentik
 | 
				
			||||||
          SENTRY_URL: https://sentry.beryju.org
 | 
					          SENTRY_URL: https://sentry.beryju.org
 | 
				
			||||||
        with:
 | 
					        with:
 | 
				
			||||||
          version: authentik@2022.5.3
 | 
					          version: authentik@2022.6.2
 | 
				
			||||||
          environment: beryjuorg-prod
 | 
					          environment: beryjuorg-prod
 | 
				
			||||||
          sourcemaps: './web/dist'
 | 
					          sourcemaps: './web/dist'
 | 
				
			||||||
          url_prefix: '~/static/dist'
 | 
					          url_prefix: '~/static/dist'
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										2
									
								
								.github/workflows/translation-compile.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/translation-compile.yml
									
									
									
									
										vendored
									
									
								
							@ -1,7 +1,7 @@
 | 
				
			|||||||
name: authentik-backend-translate-compile
 | 
					name: authentik-backend-translate-compile
 | 
				
			||||||
on:
 | 
					on:
 | 
				
			||||||
  push:
 | 
					  push:
 | 
				
			||||||
    branches: [ master ]
 | 
					    branches: [ main ]
 | 
				
			||||||
    paths:
 | 
					    paths:
 | 
				
			||||||
      - '/locale/'
 | 
					      - '/locale/'
 | 
				
			||||||
  pull_request:
 | 
					  pull_request:
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										4
									
								
								.github/workflows/web-api-publish.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.github/workflows/web-api-publish.yml
									
									
									
									
										vendored
									
									
								
							@ -1,7 +1,7 @@
 | 
				
			|||||||
name: authentik-web-api-publish
 | 
					name: authentik-web-api-publish
 | 
				
			||||||
on:
 | 
					on:
 | 
				
			||||||
  push:
 | 
					  push:
 | 
				
			||||||
    branches: [ master ]
 | 
					    branches: [ main ]
 | 
				
			||||||
    paths:
 | 
					    paths:
 | 
				
			||||||
      - 'schema.yml'
 | 
					      - 'schema.yml'
 | 
				
			||||||
  workflow_dispatch:
 | 
					  workflow_dispatch:
 | 
				
			||||||
@ -11,7 +11,7 @@ jobs:
 | 
				
			|||||||
    steps:
 | 
					    steps:
 | 
				
			||||||
      - uses: actions/checkout@v3
 | 
					      - uses: actions/checkout@v3
 | 
				
			||||||
      # Setup .npmrc file to publish to npm
 | 
					      # Setup .npmrc file to publish to npm
 | 
				
			||||||
      - uses: actions/setup-node@v3.2.0
 | 
					      - uses: actions/setup-node@v3.3.0
 | 
				
			||||||
        with:
 | 
					        with:
 | 
				
			||||||
          node-version: '16'
 | 
					          node-version: '16'
 | 
				
			||||||
          registry-url: 'https://registry.npmjs.org'
 | 
					          registry-url: 'https://registry.npmjs.org'
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										1
									
								
								.vscode/settings.json
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.vscode/settings.json
									
									
									
									
										vendored
									
									
								
							@ -1,5 +1,6 @@
 | 
				
			|||||||
{
 | 
					{
 | 
				
			||||||
    "cSpell.words": [
 | 
					    "cSpell.words": [
 | 
				
			||||||
 | 
					        "akadmin",
 | 
				
			||||||
        "asgi",
 | 
					        "asgi",
 | 
				
			||||||
        "authentik",
 | 
					        "authentik",
 | 
				
			||||||
        "authn",
 | 
					        "authn",
 | 
				
			||||||
 | 
				
			|||||||
@ -60,7 +60,7 @@ representative at an online or offline event.
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
 | 
					Instances of abusive, harassing, or otherwise unacceptable behavior may be
 | 
				
			||||||
reported to the community leaders responsible for enforcement at
 | 
					reported to the community leaders responsible for enforcement at
 | 
				
			||||||
hello@beryju.org.
 | 
					hello@goauthentik.io.
 | 
				
			||||||
All complaints will be reviewed and investigated promptly and fairly.
 | 
					All complaints will be reviewed and investigated promptly and fairly.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
All community leaders are obligated to respect the privacy and security of the
 | 
					All community leaders are obligated to respect the privacy and security of the
 | 
				
			||||||
 | 
				
			|||||||
@ -29,7 +29,7 @@ RUN pip install --no-cache-dir poetry && \
 | 
				
			|||||||
    poetry export -f requirements.txt --dev --output requirements-dev.txt
 | 
					    poetry export -f requirements.txt --dev --output requirements-dev.txt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Stage 4: Build go proxy
 | 
					# Stage 4: Build go proxy
 | 
				
			||||||
FROM docker.io/golang:1.18.2-bullseye AS builder
 | 
					FROM docker.io/golang:1.18.3-bullseye AS builder
 | 
				
			||||||
 | 
					
 | 
				
			||||||
WORKDIR /work
 | 
					WORKDIR /work
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										5
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								Makefile
									
									
									
									
									
								
							@ -55,7 +55,7 @@ i18n-extract-core:
 | 
				
			|||||||
	./manage.py makemessages --ignore web --ignore internal --ignore web --ignore web-api --ignore website -l en
 | 
						./manage.py makemessages --ignore web --ignore internal --ignore web --ignore web-api --ignore website -l en
 | 
				
			||||||
 | 
					
 | 
				
			||||||
gen-build:
 | 
					gen-build:
 | 
				
			||||||
	./manage.py spectacular --file schema.yml
 | 
						AUTHENTIK_DEBUG=true ./manage.py spectacular --file schema.yml
 | 
				
			||||||
 | 
					
 | 
				
			||||||
gen-clean:
 | 
					gen-clean:
 | 
				
			||||||
	rm -rf web/api/src/
 | 
						rm -rf web/api/src/
 | 
				
			||||||
@ -103,6 +103,9 @@ run:
 | 
				
			|||||||
## Web
 | 
					## Web
 | 
				
			||||||
#########################
 | 
					#########################
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					web-build: web-install
 | 
				
			||||||
 | 
						cd web && npm run build
 | 
				
			||||||
 | 
					
 | 
				
			||||||
web: web-lint-fix web-lint web-extract
 | 
					web: web-lint-fix web-lint web-extract
 | 
				
			||||||
 | 
					
 | 
				
			||||||
web-install:
 | 
					web-install:
 | 
				
			||||||
 | 
				
			|||||||
@ -9,7 +9,7 @@
 | 
				
			|||||||
[](https://github.com/goauthentik/authentik/actions/workflows/ci-outpost.yml)
 | 
					[](https://github.com/goauthentik/authentik/actions/workflows/ci-outpost.yml)
 | 
				
			||||||
[](https://github.com/goauthentik/authentik/actions/workflows/ci-web.yml)
 | 
					[](https://github.com/goauthentik/authentik/actions/workflows/ci-web.yml)
 | 
				
			||||||
[](https://codecov.io/gh/goauthentik/authentik)
 | 
					[](https://codecov.io/gh/goauthentik/authentik)
 | 
				
			||||||
[](https://goauthentik.testspace.com/)
 | 
					[](https://goauthentik.testspace.com/)
 | 
				
			||||||

 | 
					
 | 
				
			||||||

 | 
					
 | 
				
			||||||
[](https://www.transifex.com/beryjuorg/authentik/)
 | 
					[](https://www.transifex.com/beryjuorg/authentik/)
 | 
				
			||||||
 | 
				
			|||||||
@ -6,9 +6,9 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
| Version    | Supported          |
 | 
					| Version    | Supported          |
 | 
				
			||||||
| ---------- | ------------------ |
 | 
					| ---------- | ------------------ |
 | 
				
			||||||
| 2022.3.x   | :white_check_mark: |
 | 
					 | 
				
			||||||
| 2022.4.x   | :white_check_mark: |
 | 
					| 2022.4.x   | :white_check_mark: |
 | 
				
			||||||
 | 
					| 2022.5.x   | :white_check_mark: |
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## Reporting a Vulnerability
 | 
					## Reporting a Vulnerability
 | 
				
			||||||
 | 
					
 | 
				
			||||||
To report a vulnerability, send an email to [security@beryju.org](mailto:security@beryju.org)
 | 
					To report a vulnerability, send an email to [security@goauthentik.io](mailto:security@goauthentik.io)
 | 
				
			||||||
 | 
				
			|||||||
@ -2,7 +2,7 @@
 | 
				
			|||||||
from os import environ
 | 
					from os import environ
 | 
				
			||||||
from typing import Optional
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
__version__ = "2022.5.3"
 | 
					__version__ = "2022.6.2"
 | 
				
			||||||
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
 | 
					ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -12,7 +12,4 @@ class AuthentikAdminConfig(AppConfig):
 | 
				
			|||||||
    verbose_name = "authentik Admin"
 | 
					    verbose_name = "authentik Admin"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def ready(self):
 | 
					    def ready(self):
 | 
				
			||||||
        from authentik.admin.tasks import clear_update_notifications
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        clear_update_notifications.delay()
 | 
					 | 
				
			||||||
        import_module("authentik.admin.signals")
 | 
					        import_module("authentik.admin.signals")
 | 
				
			||||||
 | 
				
			|||||||
@ -63,6 +63,7 @@ class ApplicationSerializer(ModelSerializer):
 | 
				
			|||||||
            "provider",
 | 
					            "provider",
 | 
				
			||||||
            "provider_obj",
 | 
					            "provider_obj",
 | 
				
			||||||
            "launch_url",
 | 
					            "launch_url",
 | 
				
			||||||
 | 
					            "open_in_new_tab",
 | 
				
			||||||
            "meta_launch_url",
 | 
					            "meta_launch_url",
 | 
				
			||||||
            "meta_icon",
 | 
					            "meta_icon",
 | 
				
			||||||
            "meta_description",
 | 
					            "meta_description",
 | 
				
			||||||
 | 
				
			|||||||
@ -8,7 +8,7 @@ from rest_framework.decorators import action
 | 
				
			|||||||
from rest_framework.filters import OrderingFilter, SearchFilter
 | 
					from rest_framework.filters import OrderingFilter, SearchFilter
 | 
				
			||||||
from rest_framework.request import Request
 | 
					from rest_framework.request import Request
 | 
				
			||||||
from rest_framework.response import Response
 | 
					from rest_framework.response import Response
 | 
				
			||||||
from rest_framework.serializers import ModelSerializer, SerializerMethodField
 | 
					from rest_framework.serializers import ModelSerializer, ReadOnlyField, SerializerMethodField
 | 
				
			||||||
from rest_framework.viewsets import GenericViewSet
 | 
					from rest_framework.viewsets import GenericViewSet
 | 
				
			||||||
from structlog.stdlib import get_logger
 | 
					from structlog.stdlib import get_logger
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -26,6 +26,7 @@ LOGGER = get_logger()
 | 
				
			|||||||
class SourceSerializer(ModelSerializer, MetaNameSerializer):
 | 
					class SourceSerializer(ModelSerializer, MetaNameSerializer):
 | 
				
			||||||
    """Source Serializer"""
 | 
					    """Source Serializer"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    managed = ReadOnlyField()
 | 
				
			||||||
    component = SerializerMethodField()
 | 
					    component = SerializerMethodField()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_component(self, obj: Source) -> str:
 | 
					    def get_component(self, obj: Source) -> str:
 | 
				
			||||||
@ -51,6 +52,7 @@ class SourceSerializer(ModelSerializer, MetaNameSerializer):
 | 
				
			|||||||
            "meta_model_name",
 | 
					            "meta_model_name",
 | 
				
			||||||
            "policy_engine_mode",
 | 
					            "policy_engine_mode",
 | 
				
			||||||
            "user_matching_mode",
 | 
					            "user_matching_mode",
 | 
				
			||||||
 | 
					            "managed",
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -67,6 +69,7 @@ class SourceViewSet(
 | 
				
			|||||||
    serializer_class = SourceSerializer
 | 
					    serializer_class = SourceSerializer
 | 
				
			||||||
    lookup_field = "slug"
 | 
					    lookup_field = "slug"
 | 
				
			||||||
    search_fields = ["slug", "name"]
 | 
					    search_fields = ["slug", "name"]
 | 
				
			||||||
 | 
					    filterset_fields = ["slug", "name", "managed"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_queryset(self):  # pragma: no cover
 | 
					    def get_queryset(self):  # pragma: no cover
 | 
				
			||||||
        return Source.objects.select_subclasses()
 | 
					        return Source.objects.select_subclasses()
 | 
				
			||||||
 | 
				
			|||||||
@ -43,7 +43,10 @@ from authentik.api.decorators import permission_required
 | 
				
			|||||||
from authentik.core.api.groups import GroupSerializer
 | 
					from authentik.core.api.groups import GroupSerializer
 | 
				
			||||||
from authentik.core.api.used_by import UsedByMixin
 | 
					from authentik.core.api.used_by import UsedByMixin
 | 
				
			||||||
from authentik.core.api.utils import LinkSerializer, PassiveSerializer, is_dict
 | 
					from authentik.core.api.utils import LinkSerializer, PassiveSerializer, is_dict
 | 
				
			||||||
from authentik.core.middleware import SESSION_IMPERSONATE_ORIGINAL_USER, SESSION_IMPERSONATE_USER
 | 
					from authentik.core.middleware import (
 | 
				
			||||||
 | 
					    SESSION_KEY_IMPERSONATE_ORIGINAL_USER,
 | 
				
			||||||
 | 
					    SESSION_KEY_IMPERSONATE_USER,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
from authentik.core.models import (
 | 
					from authentik.core.models import (
 | 
				
			||||||
    USER_ATTRIBUTE_SA,
 | 
					    USER_ATTRIBUTE_SA,
 | 
				
			||||||
    USER_ATTRIBUTE_TOKEN_EXPIRING,
 | 
					    USER_ATTRIBUTE_TOKEN_EXPIRING,
 | 
				
			||||||
@ -336,11 +339,12 @@ class UserViewSet(UsedByMixin, ModelViewSet):
 | 
				
			|||||||
        serializer = SessionUserSerializer(
 | 
					        serializer = SessionUserSerializer(
 | 
				
			||||||
            data={"user": UserSelfSerializer(instance=request.user, context=context).data}
 | 
					            data={"user": UserSelfSerializer(instance=request.user, context=context).data}
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        if SESSION_IMPERSONATE_USER in request._request.session:
 | 
					        if SESSION_KEY_IMPERSONATE_USER in request._request.session:
 | 
				
			||||||
            serializer.initial_data["original"] = UserSelfSerializer(
 | 
					            serializer.initial_data["original"] = UserSelfSerializer(
 | 
				
			||||||
                instance=request._request.session[SESSION_IMPERSONATE_ORIGINAL_USER],
 | 
					                instance=request._request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER],
 | 
				
			||||||
                context=context,
 | 
					                context=context,
 | 
				
			||||||
            ).data
 | 
					            ).data
 | 
				
			||||||
 | 
					        self.request.session.save()
 | 
				
			||||||
        return Response(serializer.initial_data)
 | 
					        return Response(serializer.initial_data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @permission_required("authentik_core.reset_user_password")
 | 
					    @permission_required("authentik_core.reset_user_password")
 | 
				
			||||||
@ -367,7 +371,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
 | 
				
			|||||||
        except (ValidationError, IntegrityError) as exc:
 | 
					        except (ValidationError, IntegrityError) as exc:
 | 
				
			||||||
            LOGGER.debug("Failed to set password", exc=exc)
 | 
					            LOGGER.debug("Failed to set password", exc=exc)
 | 
				
			||||||
            return Response(status=400)
 | 
					            return Response(status=400)
 | 
				
			||||||
        if user.pk == request.user.pk and SESSION_IMPERSONATE_USER not in self.request.session:
 | 
					        if user.pk == request.user.pk and SESSION_KEY_IMPERSONATE_USER not in self.request.session:
 | 
				
			||||||
            LOGGER.debug("Updating session hash after password change")
 | 
					            LOGGER.debug("Updating session hash after password change")
 | 
				
			||||||
            update_session_auth_hash(self.request, user)
 | 
					            update_session_auth_hash(self.request, user)
 | 
				
			||||||
        return Response(status=204)
 | 
					        return Response(status=204)
 | 
				
			||||||
 | 
				
			|||||||
@ -2,10 +2,6 @@
 | 
				
			|||||||
from importlib import import_module
 | 
					from importlib import import_module
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.apps import AppConfig
 | 
					from django.apps import AppConfig
 | 
				
			||||||
from django.db import ProgrammingError
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from authentik.core.signals import GAUGE_MODELS
 | 
					 | 
				
			||||||
from authentik.lib.utils.reflection import get_apps
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class AuthentikCoreConfig(AppConfig):
 | 
					class AuthentikCoreConfig(AppConfig):
 | 
				
			||||||
@ -19,12 +15,3 @@ class AuthentikCoreConfig(AppConfig):
 | 
				
			|||||||
    def ready(self):
 | 
					    def ready(self):
 | 
				
			||||||
        import_module("authentik.core.signals")
 | 
					        import_module("authentik.core.signals")
 | 
				
			||||||
        import_module("authentik.core.managed")
 | 
					        import_module("authentik.core.managed")
 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            for app in get_apps():
 | 
					 | 
				
			||||||
                for model in app.get_models():
 | 
					 | 
				
			||||||
                    GAUGE_MODELS.labels(
 | 
					 | 
				
			||||||
                        model_name=model._meta.model_name,
 | 
					 | 
				
			||||||
                        app=model._meta.app_label,
 | 
					 | 
				
			||||||
                    ).set(model.objects.count())
 | 
					 | 
				
			||||||
        except ProgrammingError:
 | 
					 | 
				
			||||||
            pass
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -12,5 +12,6 @@ class CoreManager(ObjectManager):
 | 
				
			|||||||
                Source,
 | 
					                Source,
 | 
				
			||||||
                "goauthentik.io/sources/inbuilt",
 | 
					                "goauthentik.io/sources/inbuilt",
 | 
				
			||||||
                name="authentik Built-in",
 | 
					                name="authentik Built-in",
 | 
				
			||||||
 | 
					                slug="authentik-built-in",
 | 
				
			||||||
            ),
 | 
					            ),
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
 | 
				
			|||||||
@ -7,8 +7,8 @@ from uuid import uuid4
 | 
				
			|||||||
from django.http import HttpRequest, HttpResponse
 | 
					from django.http import HttpRequest, HttpResponse
 | 
				
			||||||
from sentry_sdk.api import set_tag
 | 
					from sentry_sdk.api import set_tag
 | 
				
			||||||
 | 
					
 | 
				
			||||||
SESSION_IMPERSONATE_USER = "authentik_impersonate_user"
 | 
					SESSION_KEY_IMPERSONATE_USER = "authentik/impersonate/user"
 | 
				
			||||||
SESSION_IMPERSONATE_ORIGINAL_USER = "authentik_impersonate_original_user"
 | 
					SESSION_KEY_IMPERSONATE_ORIGINAL_USER = "authentik/impersonate/original_user"
 | 
				
			||||||
LOCAL = local()
 | 
					LOCAL = local()
 | 
				
			||||||
RESPONSE_HEADER_ID = "X-authentik-id"
 | 
					RESPONSE_HEADER_ID = "X-authentik-id"
 | 
				
			||||||
KEY_AUTH_VIA = "auth_via"
 | 
					KEY_AUTH_VIA = "auth_via"
 | 
				
			||||||
@ -25,10 +25,10 @@ class ImpersonateMiddleware:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def __call__(self, request: HttpRequest) -> HttpResponse:
 | 
					    def __call__(self, request: HttpRequest) -> HttpResponse:
 | 
				
			||||||
        # No permission checks are done here, they need to be checked before
 | 
					        # No permission checks are done here, they need to be checked before
 | 
				
			||||||
        # SESSION_IMPERSONATE_USER is set.
 | 
					        # SESSION_KEY_IMPERSONATE_USER is set.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if SESSION_IMPERSONATE_USER in request.session:
 | 
					        if SESSION_KEY_IMPERSONATE_USER in request.session:
 | 
				
			||||||
            request.user = request.session[SESSION_IMPERSONATE_USER]
 | 
					            request.user = request.session[SESSION_KEY_IMPERSONATE_USER]
 | 
				
			||||||
            # Ensure that the user is active, otherwise nothing will work
 | 
					            # Ensure that the user is active, otherwise nothing will work
 | 
				
			||||||
            request.user.is_active = True
 | 
					            request.user.is_active = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -20,8 +20,15 @@ def create_default_user(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
 | 
				
			|||||||
    akadmin, _ = User.objects.using(db_alias).get_or_create(
 | 
					    akadmin, _ = User.objects.using(db_alias).get_or_create(
 | 
				
			||||||
        username="akadmin", email="root@localhost", name="authentik Default Admin"
 | 
					        username="akadmin", email="root@localhost", name="authentik Default Admin"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    if "TF_BUILD" in environ or "AK_ADMIN_PASS" in environ or settings.TEST:
 | 
					    password = None
 | 
				
			||||||
        akadmin.set_password(environ.get("AK_ADMIN_PASS", "akadmin"), signal=False)  # noqa # nosec
 | 
					    if "TF_BUILD" in environ or settings.TEST:
 | 
				
			||||||
 | 
					        password = "akadmin"  # noqa # nosec
 | 
				
			||||||
 | 
					    if "AK_ADMIN_PASS" in environ:
 | 
				
			||||||
 | 
					        password = environ["AK_ADMIN_PASS"]
 | 
				
			||||||
 | 
					    if "AUTHENTIK_BOOTSTRAP_PASSWORD" in environ:
 | 
				
			||||||
 | 
					        password = environ["AUTHENTIK_BOOTSTRAP_PASSWORD"]
 | 
				
			||||||
 | 
					    if password:
 | 
				
			||||||
 | 
					        akadmin.set_password(password, signal=False)
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        akadmin.set_unusable_password()
 | 
					        akadmin.set_unusable_password()
 | 
				
			||||||
    akadmin.save()
 | 
					    akadmin.save()
 | 
				
			||||||
 | 
				
			|||||||
@ -16,8 +16,15 @@ def create_default_user(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
 | 
				
			|||||||
    akadmin, _ = User.objects.using(db_alias).get_or_create(
 | 
					    akadmin, _ = User.objects.using(db_alias).get_or_create(
 | 
				
			||||||
        username="akadmin", email="root@localhost", name="authentik Default Admin"
 | 
					        username="akadmin", email="root@localhost", name="authentik Default Admin"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    if "TF_BUILD" in environ or "AK_ADMIN_PASS" in environ or settings.TEST:
 | 
					    password = None
 | 
				
			||||||
        akadmin.set_password(environ.get("AK_ADMIN_PASS", "akadmin"), signal=False)  # noqa # nosec
 | 
					    if "TF_BUILD" in environ or settings.TEST:
 | 
				
			||||||
 | 
					        password = "akadmin"  # noqa # nosec
 | 
				
			||||||
 | 
					    if "AK_ADMIN_PASS" in environ:
 | 
				
			||||||
 | 
					        password = environ["AK_ADMIN_PASS"]
 | 
				
			||||||
 | 
					    if "AUTHENTIK_BOOTSTRAP_PASSWORD" in environ:
 | 
				
			||||||
 | 
					        password = environ["AUTHENTIK_BOOTSTRAP_PASSWORD"]
 | 
				
			||||||
 | 
					    if password:
 | 
				
			||||||
 | 
					        akadmin.set_password(password, signal=False)
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        akadmin.set_unusable_password()
 | 
					        akadmin.set_unusable_password()
 | 
				
			||||||
    akadmin.save()
 | 
					    akadmin.save()
 | 
				
			||||||
 | 
				
			|||||||
@ -44,14 +44,19 @@ def create_default_user_token(apps: Apps, schema_editor: BaseDatabaseSchemaEdito
 | 
				
			|||||||
    akadmin = User.objects.using(db_alias).filter(username="akadmin")
 | 
					    akadmin = User.objects.using(db_alias).filter(username="akadmin")
 | 
				
			||||||
    if not akadmin.exists():
 | 
					    if not akadmin.exists():
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
    if "AK_ADMIN_TOKEN" not in environ:
 | 
					    key = None
 | 
				
			||||||
 | 
					    if "AK_ADMIN_TOKEN" in environ:
 | 
				
			||||||
 | 
					        key = environ["AK_ADMIN_TOKEN"]
 | 
				
			||||||
 | 
					    if "AUTHENTIK_BOOTSTRAP_TOKEN" in environ:
 | 
				
			||||||
 | 
					        key = environ["AUTHENTIK_BOOTSTRAP_TOKEN"]
 | 
				
			||||||
 | 
					    if not key:
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
    Token.objects.using(db_alias).create(
 | 
					    Token.objects.using(db_alias).create(
 | 
				
			||||||
        identifier="authentik-boostrap-token",
 | 
					        identifier="authentik-bootstrap-token",
 | 
				
			||||||
        user=akadmin.first(),
 | 
					        user=akadmin.first(),
 | 
				
			||||||
        intent=TokenIntents.INTENT_API,
 | 
					        intent=TokenIntents.INTENT_API,
 | 
				
			||||||
        expiring=False,
 | 
					        expiring=False,
 | 
				
			||||||
        key=environ["AK_ADMIN_TOKEN"],
 | 
					        key=key,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -0,0 +1,20 @@
 | 
				
			|||||||
 | 
					# Generated by Django 4.0.5 on 2022-06-04 06:54
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django.db import migrations, models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Migration(migrations.Migration):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    dependencies = [
 | 
				
			||||||
 | 
					        ("authentik_core", "0019_application_group"),
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    operations = [
 | 
				
			||||||
 | 
					        migrations.AddField(
 | 
				
			||||||
 | 
					            model_name="application",
 | 
				
			||||||
 | 
					            name="open_in_new_tab",
 | 
				
			||||||
 | 
					            field=models.BooleanField(
 | 
				
			||||||
 | 
					                default=False, help_text="Open launch URL in a new browser tab or window."
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
@ -15,14 +15,19 @@ def create_default_user_token(apps: Apps, schema_editor: BaseDatabaseSchemaEdito
 | 
				
			|||||||
    akadmin = User.objects.using(db_alias).filter(username="akadmin")
 | 
					    akadmin = User.objects.using(db_alias).filter(username="akadmin")
 | 
				
			||||||
    if not akadmin.exists():
 | 
					    if not akadmin.exists():
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
    if "AK_ADMIN_TOKEN" not in environ:
 | 
					    key = None
 | 
				
			||||||
 | 
					    if "AK_ADMIN_TOKEN" in environ:
 | 
				
			||||||
 | 
					        key = environ["AK_ADMIN_TOKEN"]
 | 
				
			||||||
 | 
					    if "AUTHENTIK_BOOTSTRAP_TOKEN" in environ:
 | 
				
			||||||
 | 
					        key = environ["AUTHENTIK_BOOTSTRAP_TOKEN"]
 | 
				
			||||||
 | 
					    if not key:
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
    Token.objects.using(db_alias).create(
 | 
					    Token.objects.using(db_alias).create(
 | 
				
			||||||
        identifier="authentik-boostrap-token",
 | 
					        identifier="authentik-bootstrap-token",
 | 
				
			||||||
        user=akadmin.first(),
 | 
					        user=akadmin.first(),
 | 
				
			||||||
        intent=TokenIntents.INTENT_API,
 | 
					        intent=TokenIntents.INTENT_API,
 | 
				
			||||||
        expiring=False,
 | 
					        expiring=False,
 | 
				
			||||||
        key=environ["AK_ADMIN_TOKEN"],
 | 
					        key=key,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -192,7 +192,7 @@ class User(GuardianUserMixin, AbstractUser):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def uid(self) -> str:
 | 
					    def uid(self) -> str:
 | 
				
			||||||
        """Generate a globall unique UID, based on the user ID and the hashed secret key"""
 | 
					        """Generate a globally unique UID, based on the user ID and the hashed secret key"""
 | 
				
			||||||
        return sha256(f"{self.id}-{settings.SECRET_KEY}".encode("ascii")).hexdigest()
 | 
					        return sha256(f"{self.id}-{settings.SECRET_KEY}".encode("ascii")).hexdigest()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
@ -278,6 +278,11 @@ class Application(PolicyBindingModel):
 | 
				
			|||||||
    meta_launch_url = models.TextField(
 | 
					    meta_launch_url = models.TextField(
 | 
				
			||||||
        default="", blank=True, validators=[DomainlessURLValidator()]
 | 
					        default="", blank=True, validators=[DomainlessURLValidator()]
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    open_in_new_tab = models.BooleanField(
 | 
				
			||||||
 | 
					        default=False, help_text=_("Open launch URL in a new browser tab or window.")
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # For template applications, this can be set to /static/authentik/applications/*
 | 
					    # For template applications, this can be set to /static/authentik/applications/*
 | 
				
			||||||
    meta_icon = models.FileField(
 | 
					    meta_icon = models.FileField(
 | 
				
			||||||
        upload_to="application-icons/",
 | 
					        upload_to="application-icons/",
 | 
				
			||||||
 | 
				
			|||||||
@ -1,7 +1,6 @@
 | 
				
			|||||||
"""authentik core signals"""
 | 
					"""authentik core signals"""
 | 
				
			||||||
from typing import TYPE_CHECKING
 | 
					from typing import TYPE_CHECKING
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.apps import apps
 | 
					 | 
				
			||||||
from django.contrib.auth.signals import user_logged_in, user_logged_out
 | 
					from django.contrib.auth.signals import user_logged_in, user_logged_out
 | 
				
			||||||
from django.contrib.sessions.backends.cache import KEY_PREFIX
 | 
					from django.contrib.sessions.backends.cache import KEY_PREFIX
 | 
				
			||||||
from django.core.cache import cache
 | 
					from django.core.cache import cache
 | 
				
			||||||
@ -10,30 +9,16 @@ from django.db.models import Model
 | 
				
			|||||||
from django.db.models.signals import post_save, pre_delete
 | 
					from django.db.models.signals import post_save, pre_delete
 | 
				
			||||||
from django.dispatch import receiver
 | 
					from django.dispatch import receiver
 | 
				
			||||||
from django.http.request import HttpRequest
 | 
					from django.http.request import HttpRequest
 | 
				
			||||||
from prometheus_client import Gauge
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from authentik.root.monitoring import monitoring_set
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Arguments: user: User, password: str
 | 
					# Arguments: user: User, password: str
 | 
				
			||||||
password_changed = Signal()
 | 
					password_changed = Signal()
 | 
				
			||||||
 | 
					# Arguments: credentials: dict[str, any], request: HttpRequest, stage: Stage
 | 
				
			||||||
GAUGE_MODELS = Gauge("authentik_models", "Count of various objects", ["model_name", "app"])
 | 
					login_failed = Signal()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if TYPE_CHECKING:
 | 
					if TYPE_CHECKING:
 | 
				
			||||||
    from authentik.core.models import AuthenticatedSession, User
 | 
					    from authentik.core.models import AuthenticatedSession, User
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@receiver(monitoring_set)
 | 
					 | 
				
			||||||
# pylint: disable=unused-argument
 | 
					 | 
				
			||||||
def monitoring_set_models(sender, **kwargs):
 | 
					 | 
				
			||||||
    """set models gauges"""
 | 
					 | 
				
			||||||
    for model in apps.get_models():
 | 
					 | 
				
			||||||
        GAUGE_MODELS.labels(
 | 
					 | 
				
			||||||
            model_name=model._meta.model_name,
 | 
					 | 
				
			||||||
            app=model._meta.app_label,
 | 
					 | 
				
			||||||
        ).set(model.objects.count())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@receiver(post_save)
 | 
					@receiver(post_save)
 | 
				
			||||||
# pylint: disable=unused-argument
 | 
					# pylint: disable=unused-argument
 | 
				
			||||||
def post_save_application(sender: type[Model], instance, created: bool, **_):
 | 
					def post_save_application(sender: type[Model], instance, created: bool, **_):
 | 
				
			||||||
 | 
				
			|||||||
@ -5,6 +5,7 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
{% block head_before %}
 | 
					{% block head_before %}
 | 
				
			||||||
{{ block.super }}
 | 
					{{ block.super }}
 | 
				
			||||||
 | 
					<link rel="prefetch" href="{{ flow.background_url }}" />
 | 
				
			||||||
{% if flow.compatibility_mode and not inspector %}
 | 
					{% if flow.compatibility_mode and not inspector %}
 | 
				
			||||||
<script>ShadyDOM = { force: !navigator.webdriver };</script>
 | 
					<script>ShadyDOM = { force: !navigator.webdriver };</script>
 | 
				
			||||||
{% endif %}
 | 
					{% endif %}
 | 
				
			||||||
@ -19,7 +20,7 @@ window.authentik.flow = {
 | 
				
			|||||||
{% block head %}
 | 
					{% block head %}
 | 
				
			||||||
<script src="{% static 'dist/flow/FlowInterface.js' %}" type="module"></script>
 | 
					<script src="{% static 'dist/flow/FlowInterface.js' %}" type="module"></script>
 | 
				
			||||||
<style>
 | 
					<style>
 | 
				
			||||||
.pf-c-background-image::before {
 | 
					:root {
 | 
				
			||||||
    --ak-flow-background: url("{{ flow.background_url }}");
 | 
					    --ak-flow-background: url("{{ flow.background_url }}");
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
</style>
 | 
					</style>
 | 
				
			||||||
 | 
				
			|||||||
@ -4,13 +4,19 @@
 | 
				
			|||||||
{% load i18n %}
 | 
					{% load i18n %}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
{% block head_before %}
 | 
					{% block head_before %}
 | 
				
			||||||
 | 
					<link rel="prefetch" href="/static/dist/assets/images/flow_background.jpg" />
 | 
				
			||||||
<link rel="stylesheet" type="text/css" href="{% static 'dist/patternfly.min.css' %}">
 | 
					<link rel="stylesheet" type="text/css" href="{% static 'dist/patternfly.min.css' %}">
 | 
				
			||||||
{% endblock %}
 | 
					{% endblock %}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
{% block head %}
 | 
					{% block head %}
 | 
				
			||||||
<style>
 | 
					<style>
 | 
				
			||||||
.pf-c-background-image::before {
 | 
					:root {
 | 
				
			||||||
    --ak-flow-background: url("/static/dist/assets/images/flow_background.jpg");
 | 
					    --ak-flow-background: url("/static/dist/assets/images/flow_background.jpg");
 | 
				
			||||||
 | 
					    --pf-c-background-image--BackgroundImage: var(--ak-flow-background);
 | 
				
			||||||
 | 
					    --pf-c-background-image--BackgroundImage-2x: var(--ak-flow-background);
 | 
				
			||||||
 | 
					    --pf-c-background-image--BackgroundImage--sm: var(--ak-flow-background);
 | 
				
			||||||
 | 
					    --pf-c-background-image--BackgroundImage--sm-2x: var(--ak-flow-background);
 | 
				
			||||||
 | 
					    --pf-c-background-image--BackgroundImage--lg: var(--ak-flow-background);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
/* Form with user */
 | 
					/* Form with user */
 | 
				
			||||||
.form-control-static {
 | 
					.form-control-static {
 | 
				
			||||||
 | 
				
			|||||||
@ -29,6 +29,7 @@ class TestApplicationsAPI(APITestCase):
 | 
				
			|||||||
            name="allowed",
 | 
					            name="allowed",
 | 
				
			||||||
            slug="allowed",
 | 
					            slug="allowed",
 | 
				
			||||||
            meta_launch_url="https://goauthentik.io/%(username)s",
 | 
					            meta_launch_url="https://goauthentik.io/%(username)s",
 | 
				
			||||||
 | 
					            open_in_new_tab=True,
 | 
				
			||||||
            provider=self.provider,
 | 
					            provider=self.provider,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        self.denied = Application.objects.create(name="denied", slug="denied")
 | 
					        self.denied = Application.objects.create(name="denied", slug="denied")
 | 
				
			||||||
@ -100,6 +101,7 @@ class TestApplicationsAPI(APITestCase):
 | 
				
			|||||||
                        },
 | 
					                        },
 | 
				
			||||||
                        "launch_url": f"https://goauthentik.io/{self.user.username}",
 | 
					                        "launch_url": f"https://goauthentik.io/{self.user.username}",
 | 
				
			||||||
                        "meta_launch_url": "https://goauthentik.io/%(username)s",
 | 
					                        "meta_launch_url": "https://goauthentik.io/%(username)s",
 | 
				
			||||||
 | 
					                        "open_in_new_tab": True,
 | 
				
			||||||
                        "meta_icon": None,
 | 
					                        "meta_icon": None,
 | 
				
			||||||
                        "meta_description": "",
 | 
					                        "meta_description": "",
 | 
				
			||||||
                        "meta_publisher": "",
 | 
					                        "meta_publisher": "",
 | 
				
			||||||
@ -148,6 +150,7 @@ class TestApplicationsAPI(APITestCase):
 | 
				
			|||||||
                        },
 | 
					                        },
 | 
				
			||||||
                        "launch_url": f"https://goauthentik.io/{self.user.username}",
 | 
					                        "launch_url": f"https://goauthentik.io/{self.user.username}",
 | 
				
			||||||
                        "meta_launch_url": "https://goauthentik.io/%(username)s",
 | 
					                        "meta_launch_url": "https://goauthentik.io/%(username)s",
 | 
				
			||||||
 | 
					                        "open_in_new_tab": True,
 | 
				
			||||||
                        "meta_icon": None,
 | 
					                        "meta_icon": None,
 | 
				
			||||||
                        "meta_description": "",
 | 
					                        "meta_description": "",
 | 
				
			||||||
                        "meta_publisher": "",
 | 
					                        "meta_publisher": "",
 | 
				
			||||||
@ -158,6 +161,7 @@ class TestApplicationsAPI(APITestCase):
 | 
				
			|||||||
                        "meta_description": "",
 | 
					                        "meta_description": "",
 | 
				
			||||||
                        "meta_icon": None,
 | 
					                        "meta_icon": None,
 | 
				
			||||||
                        "meta_launch_url": "",
 | 
					                        "meta_launch_url": "",
 | 
				
			||||||
 | 
					                        "open_in_new_tab": False,
 | 
				
			||||||
                        "meta_publisher": "",
 | 
					                        "meta_publisher": "",
 | 
				
			||||||
                        "group": "",
 | 
					                        "group": "",
 | 
				
			||||||
                        "name": "denied",
 | 
					                        "name": "denied",
 | 
				
			||||||
 | 
				
			|||||||
@ -5,7 +5,10 @@ from django.shortcuts import get_object_or_404, redirect
 | 
				
			|||||||
from django.views import View
 | 
					from django.views import View
 | 
				
			||||||
from structlog.stdlib import get_logger
 | 
					from structlog.stdlib import get_logger
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.core.middleware import SESSION_IMPERSONATE_ORIGINAL_USER, SESSION_IMPERSONATE_USER
 | 
					from authentik.core.middleware import (
 | 
				
			||||||
 | 
					    SESSION_KEY_IMPERSONATE_ORIGINAL_USER,
 | 
				
			||||||
 | 
					    SESSION_KEY_IMPERSONATE_USER,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
from authentik.core.models import User
 | 
					from authentik.core.models import User
 | 
				
			||||||
from authentik.events.models import Event, EventAction
 | 
					from authentik.events.models import Event, EventAction
 | 
				
			||||||
from authentik.lib.config import CONFIG
 | 
					from authentik.lib.config import CONFIG
 | 
				
			||||||
@ -27,8 +30,8 @@ class ImpersonateInitView(View):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        user_to_be = get_object_or_404(User, pk=user_id)
 | 
					        user_to_be = get_object_or_404(User, pk=user_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        request.session[SESSION_IMPERSONATE_ORIGINAL_USER] = request.user
 | 
					        request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER] = request.user
 | 
				
			||||||
        request.session[SESSION_IMPERSONATE_USER] = user_to_be
 | 
					        request.session[SESSION_KEY_IMPERSONATE_USER] = user_to_be
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Event.new(EventAction.IMPERSONATION_STARTED).from_http(request, user_to_be)
 | 
					        Event.new(EventAction.IMPERSONATION_STARTED).from_http(request, user_to_be)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -41,16 +44,16 @@ class ImpersonateEndView(View):
 | 
				
			|||||||
    def get(self, request: HttpRequest) -> HttpResponse:
 | 
					    def get(self, request: HttpRequest) -> HttpResponse:
 | 
				
			||||||
        """End Impersonation handler"""
 | 
					        """End Impersonation handler"""
 | 
				
			||||||
        if (
 | 
					        if (
 | 
				
			||||||
            SESSION_IMPERSONATE_USER not in request.session
 | 
					            SESSION_KEY_IMPERSONATE_USER not in request.session
 | 
				
			||||||
            or SESSION_IMPERSONATE_ORIGINAL_USER not in request.session
 | 
					            or SESSION_KEY_IMPERSONATE_ORIGINAL_USER not in request.session
 | 
				
			||||||
        ):
 | 
					        ):
 | 
				
			||||||
            LOGGER.debug("Can't end impersonation", user=request.user)
 | 
					            LOGGER.debug("Can't end impersonation", user=request.user)
 | 
				
			||||||
            return redirect("authentik_core:if-user")
 | 
					            return redirect("authentik_core:if-user")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        original_user = request.session[SESSION_IMPERSONATE_ORIGINAL_USER]
 | 
					        original_user = request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        del request.session[SESSION_IMPERSONATE_USER]
 | 
					        del request.session[SESSION_KEY_IMPERSONATE_USER]
 | 
				
			||||||
        del request.session[SESSION_IMPERSONATE_ORIGINAL_USER]
 | 
					        del request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Event.new(EventAction.IMPERSONATION_ENDED).from_http(request, original_user)
 | 
					        Event.new(EventAction.IMPERSONATION_ENDED).from_http(request, original_user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -76,11 +76,8 @@ class GeoIPReader:
 | 
				
			|||||||
            except (GeoIP2Error, ValueError):
 | 
					            except (GeoIP2Error, ValueError):
 | 
				
			||||||
                return None
 | 
					                return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def city_dict(self, ip_address: str) -> Optional[GeoIPDict]:
 | 
					    def city_to_dict(self, city: City) -> GeoIPDict:
 | 
				
			||||||
        """Wrapper for self.city that returns a dict"""
 | 
					        """Convert City to dict"""
 | 
				
			||||||
        city = self.city(ip_address)
 | 
					 | 
				
			||||||
        if not city:
 | 
					 | 
				
			||||||
            return None
 | 
					 | 
				
			||||||
        city_dict: GeoIPDict = {
 | 
					        city_dict: GeoIPDict = {
 | 
				
			||||||
            "continent": city.continent.code,
 | 
					            "continent": city.continent.code,
 | 
				
			||||||
            "country": city.country.iso_code,
 | 
					            "country": city.country.iso_code,
 | 
				
			||||||
@ -92,5 +89,12 @@ class GeoIPReader:
 | 
				
			|||||||
            city_dict["city"] = city.city.name
 | 
					            city_dict["city"] = city.city.name
 | 
				
			||||||
        return city_dict
 | 
					        return city_dict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def city_dict(self, ip_address: str) -> Optional[GeoIPDict]:
 | 
				
			||||||
 | 
					        """Wrapper for self.city that returns a dict"""
 | 
				
			||||||
 | 
					        city = self.city(ip_address)
 | 
				
			||||||
 | 
					        if not city:
 | 
				
			||||||
 | 
					            return None
 | 
				
			||||||
 | 
					        return self.city_to_dict(city)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
GEOIP_READER = GeoIPReader()
 | 
					GEOIP_READER = GeoIPReader()
 | 
				
			||||||
 | 
				
			|||||||
@ -3,6 +3,7 @@ from functools import partial
 | 
				
			|||||||
from typing import Callable
 | 
					from typing import Callable
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.conf import settings
 | 
					from django.conf import settings
 | 
				
			||||||
 | 
					from django.contrib.sessions.models import Session
 | 
				
			||||||
from django.core.exceptions import SuspiciousOperation
 | 
					from django.core.exceptions import SuspiciousOperation
 | 
				
			||||||
from django.db.models import Model
 | 
					from django.db.models import Model
 | 
				
			||||||
from django.db.models.signals import post_save, pre_delete
 | 
					from django.db.models.signals import post_save, pre_delete
 | 
				
			||||||
@ -24,6 +25,7 @@ IGNORED_MODELS = [
 | 
				
			|||||||
    UserObjectPermission,
 | 
					    UserObjectPermission,
 | 
				
			||||||
    AuthenticatedSession,
 | 
					    AuthenticatedSession,
 | 
				
			||||||
    StaticToken,
 | 
					    StaticToken,
 | 
				
			||||||
 | 
					    Session,
 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
if settings.DEBUG:
 | 
					if settings.DEBUG:
 | 
				
			||||||
    from silk.models import Request, Response, SQLQuery
 | 
					    from silk.models import Request, Response, SQLQuery
 | 
				
			||||||
 | 
				
			|||||||
@ -0,0 +1,50 @@
 | 
				
			|||||||
 | 
					# Generated by Django 4.0.4 on 2022-05-30 18:08
 | 
				
			||||||
 | 
					from django.apps.registry import Apps
 | 
				
			||||||
 | 
					from django.db import migrations, models
 | 
				
			||||||
 | 
					from django.db.backends.base.schema import BaseDatabaseSchemaEditor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from authentik.events.models import TransportMode
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def notify_local_transport(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
 | 
				
			||||||
 | 
					    db_alias = schema_editor.connection.alias
 | 
				
			||||||
 | 
					    NotificationTransport = apps.get_model("authentik_events", "NotificationTransport")
 | 
				
			||||||
 | 
					    NotificationRule = apps.get_model("authentik_events", "NotificationRule")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    local_transport, _ = NotificationTransport.objects.using(db_alias).update_or_create(
 | 
				
			||||||
 | 
					        name="default-local-transport",
 | 
				
			||||||
 | 
					        defaults={"mode": TransportMode.LOCAL},
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for trigger in NotificationRule.objects.using(db_alias).filter(
 | 
				
			||||||
 | 
					        name__in=[
 | 
				
			||||||
 | 
					            "default-notify-configuration-error",
 | 
				
			||||||
 | 
					            "default-notify-exception",
 | 
				
			||||||
 | 
					            "default-notify-update",
 | 
				
			||||||
 | 
					        ]
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        trigger.transports.add(local_transport)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Migration(migrations.Migration):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    dependencies = [
 | 
				
			||||||
 | 
					        ("authentik_events", "0001_squashed_0019_alter_notificationtransport_webhook_url"),
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    operations = [
 | 
				
			||||||
 | 
					        migrations.AlterField(
 | 
				
			||||||
 | 
					            model_name="notificationtransport",
 | 
				
			||||||
 | 
					            name="mode",
 | 
				
			||||||
 | 
					            field=models.TextField(
 | 
				
			||||||
 | 
					                choices=[
 | 
				
			||||||
 | 
					                    ("local", "authentik inbuilt notifications"),
 | 
				
			||||||
 | 
					                    ("webhook", "Generic Webhook"),
 | 
				
			||||||
 | 
					                    ("webhook_slack", "Slack Webhook (Slack/Discord)"),
 | 
				
			||||||
 | 
					                    ("email", "Email"),
 | 
				
			||||||
 | 
					                ],
 | 
				
			||||||
 | 
					                default="local",
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        migrations.RunPython(notify_local_transport),
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
@ -23,7 +23,10 @@ from requests import RequestException
 | 
				
			|||||||
from structlog.stdlib import get_logger
 | 
					from structlog.stdlib import get_logger
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik import __version__
 | 
					from authentik import __version__
 | 
				
			||||||
from authentik.core.middleware import SESSION_IMPERSONATE_ORIGINAL_USER, SESSION_IMPERSONATE_USER
 | 
					from authentik.core.middleware import (
 | 
				
			||||||
 | 
					    SESSION_KEY_IMPERSONATE_ORIGINAL_USER,
 | 
				
			||||||
 | 
					    SESSION_KEY_IMPERSONATE_USER,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
from authentik.core.models import ExpiringModel, Group, PropertyMapping, User
 | 
					from authentik.core.models import ExpiringModel, Group, PropertyMapping, User
 | 
				
			||||||
from authentik.events.geo import GEOIP_READER
 | 
					from authentik.events.geo import GEOIP_READER
 | 
				
			||||||
from authentik.events.utils import cleanse_dict, get_user, model_to_dict, sanitize_dict
 | 
					from authentik.events.utils import cleanse_dict, get_user, model_to_dict, sanitize_dict
 | 
				
			||||||
@ -233,15 +236,15 @@ class Event(ExpiringModel):
 | 
				
			|||||||
        if hasattr(request, "user"):
 | 
					        if hasattr(request, "user"):
 | 
				
			||||||
            original_user = None
 | 
					            original_user = None
 | 
				
			||||||
            if hasattr(request, "session"):
 | 
					            if hasattr(request, "session"):
 | 
				
			||||||
                original_user = request.session.get(SESSION_IMPERSONATE_ORIGINAL_USER, None)
 | 
					                original_user = request.session.get(SESSION_KEY_IMPERSONATE_ORIGINAL_USER, None)
 | 
				
			||||||
            self.user = get_user(request.user, original_user)
 | 
					            self.user = get_user(request.user, original_user)
 | 
				
			||||||
        if user:
 | 
					        if user:
 | 
				
			||||||
            self.user = get_user(user)
 | 
					            self.user = get_user(user)
 | 
				
			||||||
        # Check if we're currently impersonating, and add that user
 | 
					        # Check if we're currently impersonating, and add that user
 | 
				
			||||||
        if hasattr(request, "session"):
 | 
					        if hasattr(request, "session"):
 | 
				
			||||||
            if SESSION_IMPERSONATE_ORIGINAL_USER in request.session:
 | 
					            if SESSION_KEY_IMPERSONATE_ORIGINAL_USER in request.session:
 | 
				
			||||||
                self.user = get_user(request.session[SESSION_IMPERSONATE_ORIGINAL_USER])
 | 
					                self.user = get_user(request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER])
 | 
				
			||||||
                self.user["on_behalf_of"] = get_user(request.session[SESSION_IMPERSONATE_USER])
 | 
					                self.user["on_behalf_of"] = get_user(request.session[SESSION_KEY_IMPERSONATE_USER])
 | 
				
			||||||
        # User 255.255.255.255 as fallback if IP cannot be determined
 | 
					        # User 255.255.255.255 as fallback if IP cannot be determined
 | 
				
			||||||
        self.client_ip = get_client_ip(request)
 | 
					        self.client_ip = get_client_ip(request)
 | 
				
			||||||
        # Apply GeoIP Data, when enabled
 | 
					        # Apply GeoIP Data, when enabled
 | 
				
			||||||
@ -289,6 +292,7 @@ class Event(ExpiringModel):
 | 
				
			|||||||
class TransportMode(models.TextChoices):
 | 
					class TransportMode(models.TextChoices):
 | 
				
			||||||
    """Modes that a notification transport can send a notification"""
 | 
					    """Modes that a notification transport can send a notification"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    LOCAL = "local", _("authentik inbuilt notifications")
 | 
				
			||||||
    WEBHOOK = "webhook", _("Generic Webhook")
 | 
					    WEBHOOK = "webhook", _("Generic Webhook")
 | 
				
			||||||
    WEBHOOK_SLACK = "webhook_slack", _("Slack Webhook (Slack/Discord)")
 | 
					    WEBHOOK_SLACK = "webhook_slack", _("Slack Webhook (Slack/Discord)")
 | 
				
			||||||
    EMAIL = "email", _("Email")
 | 
					    EMAIL = "email", _("Email")
 | 
				
			||||||
@ -300,7 +304,7 @@ class NotificationTransport(models.Model):
 | 
				
			|||||||
    uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
 | 
					    uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    name = models.TextField(unique=True)
 | 
					    name = models.TextField(unique=True)
 | 
				
			||||||
    mode = models.TextField(choices=TransportMode.choices)
 | 
					    mode = models.TextField(choices=TransportMode.choices, default=TransportMode.LOCAL)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    webhook_url = models.TextField(blank=True, validators=[DomainlessURLValidator()])
 | 
					    webhook_url = models.TextField(blank=True, validators=[DomainlessURLValidator()])
 | 
				
			||||||
    webhook_mapping = models.ForeignKey(
 | 
					    webhook_mapping = models.ForeignKey(
 | 
				
			||||||
@ -315,6 +319,8 @@ class NotificationTransport(models.Model):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def send(self, notification: "Notification") -> list[str]:
 | 
					    def send(self, notification: "Notification") -> list[str]:
 | 
				
			||||||
        """Send notification to user, called from async task"""
 | 
					        """Send notification to user, called from async task"""
 | 
				
			||||||
 | 
					        if self.mode == TransportMode.LOCAL:
 | 
				
			||||||
 | 
					            return self.send_local(notification)
 | 
				
			||||||
        if self.mode == TransportMode.WEBHOOK:
 | 
					        if self.mode == TransportMode.WEBHOOK:
 | 
				
			||||||
            return self.send_webhook(notification)
 | 
					            return self.send_webhook(notification)
 | 
				
			||||||
        if self.mode == TransportMode.WEBHOOK_SLACK:
 | 
					        if self.mode == TransportMode.WEBHOOK_SLACK:
 | 
				
			||||||
@ -323,6 +329,17 @@ class NotificationTransport(models.Model):
 | 
				
			|||||||
            return self.send_email(notification)
 | 
					            return self.send_email(notification)
 | 
				
			||||||
        raise ValueError(f"Invalid mode {self.mode} set")
 | 
					        raise ValueError(f"Invalid mode {self.mode} set")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def send_local(self, notification: "Notification") -> list[str]:
 | 
				
			||||||
 | 
					        """Local notification delivery"""
 | 
				
			||||||
 | 
					        if self.webhook_mapping:
 | 
				
			||||||
 | 
					            self.webhook_mapping.evaluate(
 | 
				
			||||||
 | 
					                user=notification.user,
 | 
				
			||||||
 | 
					                request=None,
 | 
				
			||||||
 | 
					                notification=notification,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        notification.save()
 | 
				
			||||||
 | 
					        return []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def send_webhook(self, notification: "Notification") -> list[str]:
 | 
					    def send_webhook(self, notification: "Notification") -> list[str]:
 | 
				
			||||||
        """Send notification to generic webhook"""
 | 
					        """Send notification to generic webhook"""
 | 
				
			||||||
        default_body = {
 | 
					        default_body = {
 | 
				
			||||||
 | 
				
			|||||||
@ -2,15 +2,16 @@
 | 
				
			|||||||
from threading import Thread
 | 
					from threading import Thread
 | 
				
			||||||
from typing import Any, Optional
 | 
					from typing import Any, Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.contrib.auth.signals import user_logged_in, user_logged_out, user_login_failed
 | 
					from django.contrib.auth.signals import user_logged_in, user_logged_out
 | 
				
			||||||
from django.db.models.signals import post_save, pre_delete
 | 
					from django.db.models.signals import post_save, pre_delete
 | 
				
			||||||
from django.dispatch import receiver
 | 
					from django.dispatch import receiver
 | 
				
			||||||
from django.http import HttpRequest
 | 
					from django.http import HttpRequest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.core.models import User
 | 
					from authentik.core.models import User
 | 
				
			||||||
from authentik.core.signals import password_changed
 | 
					from authentik.core.signals import login_failed, password_changed
 | 
				
			||||||
from authentik.events.models import Event, EventAction
 | 
					from authentik.events.models import Event, EventAction
 | 
				
			||||||
from authentik.events.tasks import event_notification_handler, gdpr_cleanup
 | 
					from authentik.events.tasks import event_notification_handler, gdpr_cleanup
 | 
				
			||||||
 | 
					from authentik.flows.models import Stage
 | 
				
			||||||
from authentik.flows.planner import PLAN_CONTEXT_SOURCE, FlowPlan
 | 
					from authentik.flows.planner import PLAN_CONTEXT_SOURCE, FlowPlan
 | 
				
			||||||
from authentik.flows.views.executor import SESSION_KEY_PLAN
 | 
					from authentik.flows.views.executor import SESSION_KEY_PLAN
 | 
				
			||||||
from authentik.stages.invitation.models import Invitation
 | 
					from authentik.stages.invitation.models import Invitation
 | 
				
			||||||
@ -77,11 +78,18 @@ def on_user_write(sender, request: HttpRequest, user: User, data: dict[str, Any]
 | 
				
			|||||||
    thread.run()
 | 
					    thread.run()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@receiver(user_login_failed)
 | 
					@receiver(login_failed)
 | 
				
			||||||
# pylint: disable=unused-argument
 | 
					# pylint: disable=unused-argument
 | 
				
			||||||
def on_user_login_failed(sender, credentials: dict[str, str], request: HttpRequest, **_):
 | 
					def on_login_failed(
 | 
				
			||||||
    """Failed Login"""
 | 
					    signal,
 | 
				
			||||||
    thread = EventNewThread(EventAction.LOGIN_FAILED, request, **credentials)
 | 
					    sender,
 | 
				
			||||||
 | 
					    credentials: dict[str, str],
 | 
				
			||||||
 | 
					    request: HttpRequest,
 | 
				
			||||||
 | 
					    stage: Optional[Stage] = None,
 | 
				
			||||||
 | 
					    **kwargs,
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    """Failed Login, authentik custom event"""
 | 
				
			||||||
 | 
					    thread = EventNewThread(EventAction.LOGIN_FAILED, request, **credentials, stage=stage, **kwargs)
 | 
				
			||||||
    thread.run()
 | 
					    thread.run()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -1,8 +1,11 @@
 | 
				
			|||||||
"""Event notification tasks"""
 | 
					"""Event notification tasks"""
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.db.models.query_utils import Q
 | 
					from django.db.models.query_utils import Q
 | 
				
			||||||
from guardian.shortcuts import get_anonymous_user
 | 
					from guardian.shortcuts import get_anonymous_user
 | 
				
			||||||
from structlog.stdlib import get_logger
 | 
					from structlog.stdlib import get_logger
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from authentik.core.exceptions import PropertyMappingExpressionException
 | 
				
			||||||
from authentik.core.models import User
 | 
					from authentik.core.models import User
 | 
				
			||||||
from authentik.events.models import (
 | 
					from authentik.events.models import (
 | 
				
			||||||
    Event,
 | 
					    Event,
 | 
				
			||||||
@ -39,10 +42,9 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
 | 
				
			|||||||
        LOGGER.warning("event doesn't exist yet or anymore", event_uuid=event_uuid)
 | 
					        LOGGER.warning("event doesn't exist yet or anymore", event_uuid=event_uuid)
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
    event: Event = events.first()
 | 
					    event: Event = events.first()
 | 
				
			||||||
    triggers: NotificationRule = NotificationRule.objects.filter(name=trigger_name)
 | 
					    trigger: Optional[NotificationRule] = NotificationRule.objects.filter(name=trigger_name).first()
 | 
				
			||||||
    if not triggers.exists():
 | 
					    if not trigger:
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
    trigger = triggers.first()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if "policy_uuid" in event.context:
 | 
					    if "policy_uuid" in event.context:
 | 
				
			||||||
        policy_uuid = event.context["policy_uuid"]
 | 
					        policy_uuid = event.context["policy_uuid"]
 | 
				
			||||||
@ -81,11 +83,14 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
 | 
				
			|||||||
    for transport in trigger.transports.all():
 | 
					    for transport in trigger.transports.all():
 | 
				
			||||||
        for user in trigger.group.users.all():
 | 
					        for user in trigger.group.users.all():
 | 
				
			||||||
            LOGGER.debug("created notification")
 | 
					            LOGGER.debug("created notification")
 | 
				
			||||||
            notification = Notification.objects.create(
 | 
					 | 
				
			||||||
                severity=trigger.severity, body=event.summary, event=event, user=user
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            notification_transport.apply_async(
 | 
					            notification_transport.apply_async(
 | 
				
			||||||
                args=[notification.pk, transport.pk], queue="authentik_events"
 | 
					                args=[
 | 
				
			||||||
 | 
					                    transport.pk,
 | 
				
			||||||
 | 
					                    str(event.pk),
 | 
				
			||||||
 | 
					                    user.pk,
 | 
				
			||||||
 | 
					                    str(trigger.pk),
 | 
				
			||||||
 | 
					                ],
 | 
				
			||||||
 | 
					                queue="authentik_events",
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            if transport.send_once:
 | 
					            if transport.send_once:
 | 
				
			||||||
                break
 | 
					                break
 | 
				
			||||||
@ -97,19 +102,30 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
 | 
				
			|||||||
    retry_backoff=True,
 | 
					    retry_backoff=True,
 | 
				
			||||||
    base=MonitoredTask,
 | 
					    base=MonitoredTask,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
def notification_transport(self: MonitoredTask, notification_pk: int, transport_pk: int):
 | 
					def notification_transport(
 | 
				
			||||||
 | 
					    self: MonitoredTask, transport_pk: int, event_pk: str, user_pk: int, trigger_pk: str
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
    """Send notification over specified transport"""
 | 
					    """Send notification over specified transport"""
 | 
				
			||||||
    self.save_on_success = False
 | 
					    self.save_on_success = False
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        notification: Notification = Notification.objects.filter(pk=notification_pk).first()
 | 
					        event = Event.objects.filter(pk=event_pk).first()
 | 
				
			||||||
        if not notification:
 | 
					        if not event:
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
 | 
					        user = User.objects.filter(pk=user_pk).first()
 | 
				
			||||||
 | 
					        if not user:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					        trigger = NotificationRule.objects.filter(pk=trigger_pk).first()
 | 
				
			||||||
 | 
					        if not trigger:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					        notification = Notification(
 | 
				
			||||||
 | 
					            severity=trigger.severity, body=event.summary, event=event, user=user
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        transport = NotificationTransport.objects.filter(pk=transport_pk).first()
 | 
					        transport = NotificationTransport.objects.filter(pk=transport_pk).first()
 | 
				
			||||||
        if not transport:
 | 
					        if not transport:
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
        transport.send(notification)
 | 
					        transport.send(notification)
 | 
				
			||||||
        self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL))
 | 
					        self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL))
 | 
				
			||||||
    except NotificationTransportError as exc:
 | 
					    except (NotificationTransportError, PropertyMappingExpressionException) as exc:
 | 
				
			||||||
        self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc))
 | 
					        self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc))
 | 
				
			||||||
        raise exc
 | 
					        raise exc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -11,7 +11,10 @@ from authentik.events.models import (
 | 
				
			|||||||
    Notification,
 | 
					    Notification,
 | 
				
			||||||
    NotificationRule,
 | 
					    NotificationRule,
 | 
				
			||||||
    NotificationTransport,
 | 
					    NotificationTransport,
 | 
				
			||||||
 | 
					    NotificationWebhookMapping,
 | 
				
			||||||
 | 
					    TransportMode,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					from authentik.lib.generators import generate_id
 | 
				
			||||||
from authentik.policies.event_matcher.models import EventMatcherPolicy
 | 
					from authentik.policies.event_matcher.models import EventMatcherPolicy
 | 
				
			||||||
from authentik.policies.exceptions import PolicyException
 | 
					from authentik.policies.exceptions import PolicyException
 | 
				
			||||||
from authentik.policies.models import PolicyBinding
 | 
					from authentik.policies.models import PolicyBinding
 | 
				
			||||||
@ -105,4 +108,26 @@ class TestEventsNotifications(TestCase):
 | 
				
			|||||||
        execute_mock = MagicMock()
 | 
					        execute_mock = MagicMock()
 | 
				
			||||||
        with patch("authentik.events.models.NotificationTransport.send", execute_mock):
 | 
					        with patch("authentik.events.models.NotificationTransport.send", execute_mock):
 | 
				
			||||||
            Event.new(EventAction.CUSTOM_PREFIX).save()
 | 
					            Event.new(EventAction.CUSTOM_PREFIX).save()
 | 
				
			||||||
        self.assertEqual(Notification.objects.count(), 1)
 | 
					        self.assertEqual(execute_mock.call_count, 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_transport_mapping(self):
 | 
				
			||||||
 | 
					        """Test transport mapping"""
 | 
				
			||||||
 | 
					        mapping = NotificationWebhookMapping.objects.create(
 | 
				
			||||||
 | 
					            name=generate_id(),
 | 
				
			||||||
 | 
					            expression="""notification.body = 'foo'""",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        transport = NotificationTransport.objects.create(
 | 
				
			||||||
 | 
					            name="transport", webhook_mapping=mapping, mode=TransportMode.LOCAL
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        NotificationRule.objects.filter(name__startswith="default").delete()
 | 
				
			||||||
 | 
					        trigger = NotificationRule.objects.create(name="trigger", group=self.group)
 | 
				
			||||||
 | 
					        trigger.transports.add(transport)
 | 
				
			||||||
 | 
					        matcher = EventMatcherPolicy.objects.create(
 | 
				
			||||||
 | 
					            name="matcher", action=EventAction.CUSTOM_PREFIX
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        PolicyBinding.objects.create(target=trigger, policy=matcher, order=0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Notification.objects.all().delete()
 | 
				
			||||||
 | 
					        Event.new(EventAction.CUSTOM_PREFIX).save()
 | 
				
			||||||
 | 
					        self.assertEqual(Notification.objects.first().body, "foo")
 | 
				
			||||||
 | 
				
			|||||||
@ -10,9 +10,11 @@ from django.db import models
 | 
				
			|||||||
from django.db.models.base import Model
 | 
					from django.db.models.base import Model
 | 
				
			||||||
from django.http.request import HttpRequest
 | 
					from django.http.request import HttpRequest
 | 
				
			||||||
from django.views.debug import SafeExceptionReporterFilter
 | 
					from django.views.debug import SafeExceptionReporterFilter
 | 
				
			||||||
 | 
					from geoip2.models import City
 | 
				
			||||||
from guardian.utils import get_anonymous_user
 | 
					from guardian.utils import get_anonymous_user
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.core.models import User
 | 
					from authentik.core.models import User
 | 
				
			||||||
 | 
					from authentik.events.geo import GEOIP_READER
 | 
				
			||||||
from authentik.policies.types import PolicyRequest
 | 
					from authentik.policies.types import PolicyRequest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Special keys which are *not* cleaned, even when the default filter
 | 
					# Special keys which are *not* cleaned, even when the default filter
 | 
				
			||||||
@ -93,6 +95,8 @@ def sanitize_dict(source: dict[Any, Any]) -> dict[Any, Any]:
 | 
				
			|||||||
            final_dict[key] = value.hex
 | 
					            final_dict[key] = value.hex
 | 
				
			||||||
        elif isinstance(value, (HttpRequest, WSGIRequest)):
 | 
					        elif isinstance(value, (HttpRequest, WSGIRequest)):
 | 
				
			||||||
            continue
 | 
					            continue
 | 
				
			||||||
 | 
					        elif isinstance(value, City):
 | 
				
			||||||
 | 
					            final_dict[key] = GEOIP_READER.city_to_dict(value)
 | 
				
			||||||
        elif isinstance(value, type):
 | 
					        elif isinstance(value, type):
 | 
				
			||||||
            final_dict[key] = {
 | 
					            final_dict[key] = {
 | 
				
			||||||
                "type": value.__name__,
 | 
					                "type": value.__name__,
 | 
				
			||||||
 | 
				
			|||||||
@ -94,9 +94,9 @@ class Command(BaseCommand):  # pragma: no cover
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def output_overview(self, values):
 | 
					    def output_overview(self, values):
 | 
				
			||||||
        """Output results human readable"""
 | 
					        """Output results human readable"""
 | 
				
			||||||
        total_max: int = max([max(inner) for inner in values])
 | 
					        total_max: int = max(max(inner) for inner in values)
 | 
				
			||||||
        total_min: int = min([min(inner) for inner in values])
 | 
					        total_min: int = min(min(inner) for inner in values)
 | 
				
			||||||
        total_avg = sum([sum(inner) for inner in values]) / sum([len(inner) for inner in values])
 | 
					        total_avg = sum(sum(inner) for inner in values) / sum(len(inner) for inner in values)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        print(f"Version: {__version__}")
 | 
					        print(f"Version: {__version__}")
 | 
				
			||||||
        print(f"Processes: {len(values)}")
 | 
					        print(f"Processes: {len(values)}")
 | 
				
			||||||
 | 
				
			|||||||
@ -117,7 +117,7 @@ class FlowPlanner:
 | 
				
			|||||||
        self.use_cache = True
 | 
					        self.use_cache = True
 | 
				
			||||||
        self.allow_empty_flows = False
 | 
					        self.allow_empty_flows = False
 | 
				
			||||||
        self.flow = flow
 | 
					        self.flow = flow
 | 
				
			||||||
        self._logger = get_logger().bind(flow=flow)
 | 
					        self._logger = get_logger().bind(flow_slug=flow.slug)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def plan(
 | 
					    def plan(
 | 
				
			||||||
        self, request: HttpRequest, default_context: Optional[dict[str, Any]] = None
 | 
					        self, request: HttpRequest, default_context: Optional[dict[str, Any]] = None
 | 
				
			||||||
 | 
				
			|||||||
@ -9,7 +9,7 @@ from django.urls import reverse
 | 
				
			|||||||
from django.views.generic.base import View
 | 
					from django.views.generic.base import View
 | 
				
			||||||
from rest_framework.request import Request
 | 
					from rest_framework.request import Request
 | 
				
			||||||
from sentry_sdk.hub import Hub
 | 
					from sentry_sdk.hub import Hub
 | 
				
			||||||
from structlog.stdlib import get_logger
 | 
					from structlog.stdlib import BoundLogger, get_logger
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.core.models import DEFAULT_AVATAR, User
 | 
					from authentik.core.models import DEFAULT_AVATAR, User
 | 
				
			||||||
from authentik.flows.challenge import (
 | 
					from authentik.flows.challenge import (
 | 
				
			||||||
@ -23,23 +23,30 @@ from authentik.flows.challenge import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
from authentik.flows.models import InvalidResponseAction
 | 
					from authentik.flows.models import InvalidResponseAction
 | 
				
			||||||
from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_PENDING_USER
 | 
					from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_PENDING_USER
 | 
				
			||||||
 | 
					from authentik.lib.utils.reflection import class_to_path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if TYPE_CHECKING:
 | 
					if TYPE_CHECKING:
 | 
				
			||||||
    from authentik.flows.views.executor import FlowExecutorView
 | 
					    from authentik.flows.views.executor import FlowExecutorView
 | 
				
			||||||
 | 
					
 | 
				
			||||||
PLAN_CONTEXT_PENDING_USER_IDENTIFIER = "pending_user_identifier"
 | 
					PLAN_CONTEXT_PENDING_USER_IDENTIFIER = "pending_user_identifier"
 | 
				
			||||||
LOGGER = get_logger()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class StageView(View):
 | 
					class StageView(View):
 | 
				
			||||||
    """Abstract Stage, inherits TemplateView but can be combined with FormView"""
 | 
					    """Abstract Stage"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    executor: "FlowExecutorView"
 | 
					    executor: "FlowExecutorView"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    request: HttpRequest = None
 | 
					    request: HttpRequest = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    logger: BoundLogger
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, executor: "FlowExecutorView", **kwargs):
 | 
					    def __init__(self, executor: "FlowExecutorView", **kwargs):
 | 
				
			||||||
        self.executor = executor
 | 
					        self.executor = executor
 | 
				
			||||||
 | 
					        current_stage = getattr(self.executor, "current_stage", None)
 | 
				
			||||||
 | 
					        self.logger = get_logger().bind(
 | 
				
			||||||
 | 
					            stage=getattr(current_stage, "name", None),
 | 
				
			||||||
 | 
					            stage_view=class_to_path(type(self)),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        super().__init__(**kwargs)
 | 
					        super().__init__(**kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_pending_user(self, for_display=False) -> User:
 | 
					    def get_pending_user(self, for_display=False) -> User:
 | 
				
			||||||
@ -60,6 +67,9 @@ class StageView(View):
 | 
				
			|||||||
            return self.executor.plan.context[PLAN_CONTEXT_PENDING_USER]
 | 
					            return self.executor.plan.context[PLAN_CONTEXT_PENDING_USER]
 | 
				
			||||||
        return self.request.user
 | 
					        return self.request.user
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def cleanup(self):
 | 
				
			||||||
 | 
					        """Cleanup session"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ChallengeStageView(StageView):
 | 
					class ChallengeStageView(StageView):
 | 
				
			||||||
    """Stage view which response with a challenge"""
 | 
					    """Stage view which response with a challenge"""
 | 
				
			||||||
@ -74,12 +84,9 @@ class ChallengeStageView(StageView):
 | 
				
			|||||||
        """Return a challenge for the frontend to solve"""
 | 
					        """Return a challenge for the frontend to solve"""
 | 
				
			||||||
        challenge = self._get_challenge(*args, **kwargs)
 | 
					        challenge = self._get_challenge(*args, **kwargs)
 | 
				
			||||||
        if not challenge.is_valid():
 | 
					        if not challenge.is_valid():
 | 
				
			||||||
            LOGGER.warning(
 | 
					            self.logger.warning(
 | 
				
			||||||
                "f(ch): Invalid challenge",
 | 
					                "f(ch): Invalid challenge",
 | 
				
			||||||
                binding=self.executor.current_binding,
 | 
					 | 
				
			||||||
                errors=challenge.errors,
 | 
					                errors=challenge.errors,
 | 
				
			||||||
                stage_view=self,
 | 
					 | 
				
			||||||
                challenge=challenge,
 | 
					 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
        return HttpChallengeResponse(challenge)
 | 
					        return HttpChallengeResponse(challenge)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -96,10 +103,8 @@ class ChallengeStageView(StageView):
 | 
				
			|||||||
                    self.executor.current_binding.invalid_response_action
 | 
					                    self.executor.current_binding.invalid_response_action
 | 
				
			||||||
                    == InvalidResponseAction.RESTART_WITH_CONTEXT
 | 
					                    == InvalidResponseAction.RESTART_WITH_CONTEXT
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
                LOGGER.debug(
 | 
					                self.logger.debug(
 | 
				
			||||||
                    "f(ch): Invalid response, restarting flow",
 | 
					                    "f(ch): Invalid response, restarting flow",
 | 
				
			||||||
                    binding=self.executor.current_binding,
 | 
					 | 
				
			||||||
                    stage_view=self,
 | 
					 | 
				
			||||||
                    keep_context=keep_context,
 | 
					                    keep_context=keep_context,
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
                return self.executor.restart_flow(keep_context)
 | 
					                return self.executor.restart_flow(keep_context)
 | 
				
			||||||
@ -125,7 +130,7 @@ class ChallengeStageView(StageView):
 | 
				
			|||||||
            }
 | 
					            }
 | 
				
			||||||
        # pylint: disable=broad-except
 | 
					        # pylint: disable=broad-except
 | 
				
			||||||
        except Exception as exc:
 | 
					        except Exception as exc:
 | 
				
			||||||
            LOGGER.warning("failed to template title", exc=exc)
 | 
					            self.logger.warning("failed to template title", exc=exc)
 | 
				
			||||||
            return self.executor.flow.title
 | 
					            return self.executor.flow.title
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _get_challenge(self, *args, **kwargs) -> Challenge:
 | 
					    def _get_challenge(self, *args, **kwargs) -> Challenge:
 | 
				
			||||||
@ -185,11 +190,9 @@ class ChallengeStageView(StageView):
 | 
				
			|||||||
                )
 | 
					                )
 | 
				
			||||||
        challenge_response.initial_data["response_errors"] = full_errors
 | 
					        challenge_response.initial_data["response_errors"] = full_errors
 | 
				
			||||||
        if not challenge_response.is_valid():
 | 
					        if not challenge_response.is_valid():
 | 
				
			||||||
            LOGGER.error(
 | 
					            self.logger.error(
 | 
				
			||||||
                "f(ch): invalid challenge response",
 | 
					                "f(ch): invalid challenge response",
 | 
				
			||||||
                binding=self.executor.current_binding,
 | 
					 | 
				
			||||||
                errors=challenge_response.errors,
 | 
					                errors=challenge_response.errors,
 | 
				
			||||||
                stage_view=self,
 | 
					 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
        return HttpChallengeResponse(challenge_response)
 | 
					        return HttpChallengeResponse(challenge_response)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -28,6 +28,7 @@ ALLOWED_MODELS = (Flow, FlowStageBinding, Stage, Policy, PolicyBinding, Prompt)
 | 
				
			|||||||
def transaction_rollback():
 | 
					def transaction_rollback():
 | 
				
			||||||
    """Enters an atomic transaction and always triggers a rollback at the end of the block."""
 | 
					    """Enters an atomic transaction and always triggers a rollback at the end of the block."""
 | 
				
			||||||
    atomic = transaction.atomic()
 | 
					    atomic = transaction.atomic()
 | 
				
			||||||
 | 
					    # pylint: disable=unnecessary-dunder-call
 | 
				
			||||||
    atomic.__enter__()
 | 
					    atomic.__enter__()
 | 
				
			||||||
    yield
 | 
					    yield
 | 
				
			||||||
    atomic.__exit__(IntegrityError, None, None)
 | 
					    atomic.__exit__(IntegrityError, None, None)
 | 
				
			||||||
 | 
				
			|||||||
@ -49,7 +49,7 @@ from authentik.flows.planner import (
 | 
				
			|||||||
    FlowPlan,
 | 
					    FlowPlan,
 | 
				
			||||||
    FlowPlanner,
 | 
					    FlowPlanner,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from authentik.flows.stage import AccessDeniedChallengeView
 | 
					from authentik.flows.stage import AccessDeniedChallengeView, StageView
 | 
				
			||||||
from authentik.lib.sentry import SentryIgnoredException
 | 
					from authentik.lib.sentry import SentryIgnoredException
 | 
				
			||||||
from authentik.lib.utils.errors import exception_to_string
 | 
					from authentik.lib.utils.errors import exception_to_string
 | 
				
			||||||
from authentik.lib.utils.reflection import all_subclasses, class_to_path
 | 
					from authentik.lib.utils.reflection import all_subclasses, class_to_path
 | 
				
			||||||
@ -59,11 +59,11 @@ from authentik.tenants.models import Tenant
 | 
				
			|||||||
LOGGER = get_logger()
 | 
					LOGGER = get_logger()
 | 
				
			||||||
# Argument used to redirect user after login
 | 
					# Argument used to redirect user after login
 | 
				
			||||||
NEXT_ARG_NAME = "next"
 | 
					NEXT_ARG_NAME = "next"
 | 
				
			||||||
SESSION_KEY_PLAN = "authentik_flows_plan"
 | 
					SESSION_KEY_PLAN = "authentik/flows/plan"
 | 
				
			||||||
SESSION_KEY_APPLICATION_PRE = "authentik_flows_application_pre"
 | 
					SESSION_KEY_APPLICATION_PRE = "authentik/flows/application_pre"
 | 
				
			||||||
SESSION_KEY_GET = "authentik_flows_get"
 | 
					SESSION_KEY_GET = "authentik/flows/get"
 | 
				
			||||||
SESSION_KEY_POST = "authentik_flows_post"
 | 
					SESSION_KEY_POST = "authentik/flows/post"
 | 
				
			||||||
SESSION_KEY_HISTORY = "authentik_flows_history"
 | 
					SESSION_KEY_HISTORY = "authentik/flows/history"
 | 
				
			||||||
QS_KEY_TOKEN = "flow_token"  # nosec
 | 
					QS_KEY_TOKEN = "flow_token"  # nosec
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -380,6 +380,8 @@ class FlowExecutorView(APIView):
 | 
				
			|||||||
            "f(exec): Stage ok",
 | 
					            "f(exec): Stage ok",
 | 
				
			||||||
            stage_class=class_to_path(self.current_stage_view.__class__),
 | 
					            stage_class=class_to_path(self.current_stage_view.__class__),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					        if isinstance(self.current_stage_view, StageView):
 | 
				
			||||||
 | 
					            self.current_stage_view.cleanup()
 | 
				
			||||||
        self.request.session.get(SESSION_KEY_HISTORY, []).append(deepcopy(self.plan))
 | 
					        self.request.session.get(SESSION_KEY_HISTORY, []).append(deepcopy(self.plan))
 | 
				
			||||||
        self.plan.pop()
 | 
					        self.plan.pop()
 | 
				
			||||||
        self.request.session[SESSION_KEY_PLAN] = self.plan
 | 
					        self.request.session[SESSION_KEY_PLAN] = self.plan
 | 
				
			||||||
@ -416,11 +418,14 @@ class FlowExecutorView(APIView):
 | 
				
			|||||||
            SESSION_KEY_APPLICATION_PRE,
 | 
					            SESSION_KEY_APPLICATION_PRE,
 | 
				
			||||||
            SESSION_KEY_PLAN,
 | 
					            SESSION_KEY_PLAN,
 | 
				
			||||||
            SESSION_KEY_GET,
 | 
					            SESSION_KEY_GET,
 | 
				
			||||||
 | 
					            # We might need the initial POST payloads for later requests
 | 
				
			||||||
 | 
					            # SESSION_KEY_POST,
 | 
				
			||||||
            # We don't delete the history on purpose, as a user might
 | 
					            # We don't delete the history on purpose, as a user might
 | 
				
			||||||
            # still be inspecting it.
 | 
					            # still be inspecting it.
 | 
				
			||||||
            # It's only deleted on a fresh executions
 | 
					            # It's only deleted on a fresh executions
 | 
				
			||||||
            # SESSION_KEY_HISTORY,
 | 
					            # SESSION_KEY_HISTORY,
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
 | 
					        self._logger.debug("f(exec): cleaning up")
 | 
				
			||||||
        for key in keys_to_delete:
 | 
					        for key in keys_to_delete:
 | 
				
			||||||
            if key in self.request.session:
 | 
					            if key in self.request.session:
 | 
				
			||||||
                del self.request.session[key]
 | 
					                del self.request.session[key]
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										12
									
								
								authentik/lib/xml.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								authentik/lib/xml.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,12 @@
 | 
				
			|||||||
 | 
					"""XML Utilities"""
 | 
				
			||||||
 | 
					from lxml.etree import XMLParser, fromstring  # nosec
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_lxml_parser():
 | 
				
			||||||
 | 
					    """Get XML parser"""
 | 
				
			||||||
 | 
					    return XMLParser(resolve_entities=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def lxml_from_string(text: str):
 | 
				
			||||||
 | 
					    """Wrapper around fromstring"""
 | 
				
			||||||
 | 
					    return fromstring(text, parser=get_lxml_parser())
 | 
				
			||||||
@ -8,9 +8,3 @@ class AuthentikManagedConfig(AppConfig):
 | 
				
			|||||||
    name = "authentik.managed"
 | 
					    name = "authentik.managed"
 | 
				
			||||||
    label = "authentik_managed"
 | 
					    label = "authentik_managed"
 | 
				
			||||||
    verbose_name = "authentik Managed"
 | 
					    verbose_name = "authentik Managed"
 | 
				
			||||||
 | 
					 | 
				
			||||||
    def ready(self) -> None:
 | 
					 | 
				
			||||||
        from authentik.managed.tasks import managed_reconcile
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # pyright: reportGeneralTypeIssues=false
 | 
					 | 
				
			||||||
        managed_reconcile.delay()  # pylint: disable=no-value-for-parameter
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -2,7 +2,6 @@
 | 
				
			|||||||
from importlib import import_module
 | 
					from importlib import import_module
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.apps import AppConfig
 | 
					from django.apps import AppConfig
 | 
				
			||||||
from django.db import ProgrammingError
 | 
					 | 
				
			||||||
from structlog.stdlib import get_logger
 | 
					from structlog.stdlib import get_logger
 | 
				
			||||||
 | 
					
 | 
				
			||||||
LOGGER = get_logger()
 | 
					LOGGER = get_logger()
 | 
				
			||||||
@ -18,10 +17,3 @@ class AuthentikOutpostConfig(AppConfig):
 | 
				
			|||||||
    def ready(self):
 | 
					    def ready(self):
 | 
				
			||||||
        import_module("authentik.outposts.signals")
 | 
					        import_module("authentik.outposts.signals")
 | 
				
			||||||
        import_module("authentik.outposts.managed")
 | 
					        import_module("authentik.outposts.managed")
 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            from authentik.outposts.tasks import outpost_controller_all, outpost_local_connection
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            outpost_local_connection.delay()
 | 
					 | 
				
			||||||
            outpost_controller_all.delay()
 | 
					 | 
				
			||||||
        except ProgrammingError:
 | 
					 | 
				
			||||||
            pass
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -48,9 +48,7 @@ class PolicySerializer(ModelSerializer, MetaNameSerializer):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def get_bound_to(self, obj: Policy) -> int:
 | 
					    def get_bound_to(self, obj: Policy) -> int:
 | 
				
			||||||
        """Return objects policy is bound to"""
 | 
					        """Return objects policy is bound to"""
 | 
				
			||||||
        if not obj.bindings.exists() and not obj.promptstage_set.exists():
 | 
					        return obj.bindings.count() + obj.promptstage_set.count()
 | 
				
			||||||
            return 0
 | 
					 | 
				
			||||||
        return obj.bindings.count()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def to_representation(self, instance: Policy):
 | 
					    def to_representation(self, instance: Policy):
 | 
				
			||||||
        # pyright: reportGeneralTypeIssues=false
 | 
					        # pyright: reportGeneralTypeIssues=false
 | 
				
			||||||
 | 
				
			|||||||
@ -23,7 +23,7 @@ GAUGE_POLICIES_CACHED = Gauge(
 | 
				
			|||||||
HIST_POLICIES_BUILD_TIME = Histogram(
 | 
					HIST_POLICIES_BUILD_TIME = Histogram(
 | 
				
			||||||
    "authentik_policies_build_time",
 | 
					    "authentik_policies_build_time",
 | 
				
			||||||
    "Execution times complete policy result to an object",
 | 
					    "Execution times complete policy result to an object",
 | 
				
			||||||
    ["object_name", "object_type", "user"],
 | 
					    ["object_pk", "object_type"],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -91,9 +91,8 @@ class PolicyEngine:
 | 
				
			|||||||
            op="authentik.policy.engine.build",
 | 
					            op="authentik.policy.engine.build",
 | 
				
			||||||
            description=self.__pbm,
 | 
					            description=self.__pbm,
 | 
				
			||||||
        ) as span, HIST_POLICIES_BUILD_TIME.labels(
 | 
					        ) as span, HIST_POLICIES_BUILD_TIME.labels(
 | 
				
			||||||
            object_name=self.__pbm,
 | 
					            object_pk=str(self.__pbm.pk),
 | 
				
			||||||
            object_type=f"{self.__pbm._meta.app_label}.{self.__pbm._meta.model_name}",
 | 
					            object_type=f"{self.__pbm._meta.app_label}.{self.__pbm._meta.model_name}",
 | 
				
			||||||
            user=self.request.user,
 | 
					 | 
				
			||||||
        ).time():
 | 
					        ).time():
 | 
				
			||||||
            span: Span
 | 
					            span: Span
 | 
				
			||||||
            span.set_data("pbm", self.__pbm)
 | 
					            span.set_data("pbm", self.__pbm)
 | 
				
			||||||
 | 
				
			|||||||
@ -1,8 +1,8 @@
 | 
				
			|||||||
"""Password flow tests"""
 | 
					"""Password flow tests"""
 | 
				
			||||||
from django.urls.base import reverse
 | 
					from django.urls.base import reverse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.core.models import User
 | 
					from authentik.core.tests.utils import create_test_admin_user, create_test_flow
 | 
				
			||||||
from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding
 | 
					from authentik.flows.models import FlowDesignation, FlowStageBinding
 | 
				
			||||||
from authentik.flows.tests import FlowTestCase
 | 
					from authentik.flows.tests import FlowTestCase
 | 
				
			||||||
from authentik.policies.password.models import PasswordPolicy
 | 
					from authentik.policies.password.models import PasswordPolicy
 | 
				
			||||||
from authentik.stages.prompt.models import FieldTypes, Prompt, PromptStage
 | 
					from authentik.stages.prompt.models import FieldTypes, Prompt, PromptStage
 | 
				
			||||||
@ -12,13 +12,9 @@ class TestPasswordPolicyFlow(FlowTestCase):
 | 
				
			|||||||
    """Test Password Policy"""
 | 
					    """Test Password Policy"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def setUp(self) -> None:
 | 
					    def setUp(self) -> None:
 | 
				
			||||||
        self.user = User.objects.create(username="unittest", email="test@beryju.org")
 | 
					        self.user = create_test_admin_user()
 | 
				
			||||||
 | 
					        self.flow = create_test_flow(FlowDesignation.AUTHENTICATION)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.flow = Flow.objects.create(
 | 
					 | 
				
			||||||
            name="test-prompt",
 | 
					 | 
				
			||||||
            slug="test-prompt",
 | 
					 | 
				
			||||||
            designation=FlowDesignation.AUTHENTICATION,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        password_prompt = Prompt.objects.create(
 | 
					        password_prompt = Prompt.objects.create(
 | 
				
			||||||
            field_key="password",
 | 
					            field_key="password",
 | 
				
			||||||
            label="PASSWORD_LABEL",
 | 
					            label="PASSWORD_LABEL",
 | 
				
			||||||
 | 
				
			|||||||
@ -28,9 +28,8 @@ HIST_POLICIES_EXECUTION_TIME = Histogram(
 | 
				
			|||||||
        "binding_order",
 | 
					        "binding_order",
 | 
				
			||||||
        "binding_target_type",
 | 
					        "binding_target_type",
 | 
				
			||||||
        "binding_target_name",
 | 
					        "binding_target_name",
 | 
				
			||||||
        "object_name",
 | 
					        "object_pk",
 | 
				
			||||||
        "object_type",
 | 
					        "object_type",
 | 
				
			||||||
        "user",
 | 
					 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -89,7 +88,7 @@ class PolicyProcess(PROCESS_CLASS):
 | 
				
			|||||||
        LOGGER.debug(
 | 
					        LOGGER.debug(
 | 
				
			||||||
            "P_ENG(proc): Running policy",
 | 
					            "P_ENG(proc): Running policy",
 | 
				
			||||||
            policy=self.binding.policy,
 | 
					            policy=self.binding.policy,
 | 
				
			||||||
            user=self.request.user,
 | 
					            user=self.request.user.username,
 | 
				
			||||||
            # this is used for filtering in access checking where logs are sent to the admin
 | 
					            # this is used for filtering in access checking where logs are sent to the admin
 | 
				
			||||||
            process="PolicyProcess",
 | 
					            process="PolicyProcess",
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
@ -125,7 +124,7 @@ class PolicyProcess(PROCESS_CLASS):
 | 
				
			|||||||
            # this is used for filtering in access checking where logs are sent to the admin
 | 
					            # this is used for filtering in access checking where logs are sent to the admin
 | 
				
			||||||
            process="PolicyProcess",
 | 
					            process="PolicyProcess",
 | 
				
			||||||
            passing=policy_result.passing,
 | 
					            passing=policy_result.passing,
 | 
				
			||||||
            user=self.request.user,
 | 
					            user=self.request.user.username,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        return policy_result
 | 
					        return policy_result
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -137,9 +136,8 @@ class PolicyProcess(PROCESS_CLASS):
 | 
				
			|||||||
            binding_order=self.binding.order,
 | 
					            binding_order=self.binding.order,
 | 
				
			||||||
            binding_target_type=self.binding.target_type,
 | 
					            binding_target_type=self.binding.target_type,
 | 
				
			||||||
            binding_target_name=self.binding.target_name,
 | 
					            binding_target_name=self.binding.target_name,
 | 
				
			||||||
            object_name=self.request.obj,
 | 
					            object_pk=str(self.request.obj.pk),
 | 
				
			||||||
            object_type=f"{self.request.obj._meta.app_label}.{self.request.obj._meta.model_name}",
 | 
					            object_type=f"{self.request.obj._meta.app_label}.{self.request.obj._meta.model_name}",
 | 
				
			||||||
            user=str(self.request.user),
 | 
					 | 
				
			||||||
        ).time():
 | 
					        ).time():
 | 
				
			||||||
            span: Span
 | 
					            span: Span
 | 
				
			||||||
            span.set_data("policy", self.binding.policy)
 | 
					            span.set_data("policy", self.binding.policy)
 | 
				
			||||||
 | 
				
			|||||||
@ -1,10 +1,11 @@
 | 
				
			|||||||
"""authentik reputation request signals"""
 | 
					"""authentik reputation request signals"""
 | 
				
			||||||
from django.contrib.auth.signals import user_logged_in, user_login_failed
 | 
					from django.contrib.auth.signals import user_logged_in
 | 
				
			||||||
from django.core.cache import cache
 | 
					from django.core.cache import cache
 | 
				
			||||||
from django.dispatch import receiver
 | 
					from django.dispatch import receiver
 | 
				
			||||||
from django.http import HttpRequest
 | 
					from django.http import HttpRequest
 | 
				
			||||||
from structlog.stdlib import get_logger
 | 
					from structlog.stdlib import get_logger
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from authentik.core.signals import login_failed
 | 
				
			||||||
from authentik.lib.config import CONFIG
 | 
					from authentik.lib.config import CONFIG
 | 
				
			||||||
from authentik.lib.utils.http import get_client_ip
 | 
					from authentik.lib.utils.http import get_client_ip
 | 
				
			||||||
from authentik.policies.reputation.models import CACHE_KEY_PREFIX
 | 
					from authentik.policies.reputation.models import CACHE_KEY_PREFIX
 | 
				
			||||||
@ -35,7 +36,7 @@ def update_score(request: HttpRequest, identifier: str, amount: int):
 | 
				
			|||||||
    save_reputation.delay()
 | 
					    save_reputation.delay()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@receiver(user_login_failed)
 | 
					@receiver(login_failed)
 | 
				
			||||||
# pylint: disable=unused-argument
 | 
					# pylint: disable=unused-argument
 | 
				
			||||||
def handle_failed_login(sender, request, credentials, **_):
 | 
					def handle_failed_login(sender, request, credentials, **_):
 | 
				
			||||||
    """Lower Score for failed login attempts"""
 | 
					    """Lower Score for failed login attempts"""
 | 
				
			||||||
 | 
				
			|||||||
@ -1,5 +1,4 @@
 | 
				
			|||||||
"""test reputation signals and policy"""
 | 
					"""test reputation signals and policy"""
 | 
				
			||||||
from django.contrib.auth import authenticate
 | 
					 | 
				
			||||||
from django.core.cache import cache
 | 
					from django.core.cache import cache
 | 
				
			||||||
from django.test import RequestFactory, TestCase
 | 
					from django.test import RequestFactory, TestCase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -7,6 +6,8 @@ from authentik.core.models import User
 | 
				
			|||||||
from authentik.policies.reputation.models import CACHE_KEY_PREFIX, Reputation, ReputationPolicy
 | 
					from authentik.policies.reputation.models import CACHE_KEY_PREFIX, Reputation, ReputationPolicy
 | 
				
			||||||
from authentik.policies.reputation.tasks import save_reputation
 | 
					from authentik.policies.reputation.tasks import save_reputation
 | 
				
			||||||
from authentik.policies.types import PolicyRequest
 | 
					from authentik.policies.types import PolicyRequest
 | 
				
			||||||
 | 
					from authentik.stages.password import BACKEND_INBUILT
 | 
				
			||||||
 | 
					from authentik.stages.password.stage import authenticate
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestReputationPolicy(TestCase):
 | 
					class TestReputationPolicy(TestCase):
 | 
				
			||||||
@ -21,11 +22,14 @@ class TestReputationPolicy(TestCase):
 | 
				
			|||||||
        cache.delete_many(keys)
 | 
					        cache.delete_many(keys)
 | 
				
			||||||
        # We need a user for the one-to-one in userreputation
 | 
					        # We need a user for the one-to-one in userreputation
 | 
				
			||||||
        self.user = User.objects.create(username=self.test_username)
 | 
					        self.user = User.objects.create(username=self.test_username)
 | 
				
			||||||
 | 
					        self.backends = [BACKEND_INBUILT]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_ip_reputation(self):
 | 
					    def test_ip_reputation(self):
 | 
				
			||||||
        """test IP reputation"""
 | 
					        """test IP reputation"""
 | 
				
			||||||
        # Trigger negative reputation
 | 
					        # Trigger negative reputation
 | 
				
			||||||
        authenticate(self.request, username=self.test_username, password=self.test_username)
 | 
					        authenticate(
 | 
				
			||||||
 | 
					            self.request, self.backends, username=self.test_username, password=self.test_username
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        # Test value in cache
 | 
					        # Test value in cache
 | 
				
			||||||
        self.assertEqual(
 | 
					        self.assertEqual(
 | 
				
			||||||
            cache.get(CACHE_KEY_PREFIX + self.test_ip + self.test_username),
 | 
					            cache.get(CACHE_KEY_PREFIX + self.test_ip + self.test_username),
 | 
				
			||||||
@ -38,7 +42,9 @@ class TestReputationPolicy(TestCase):
 | 
				
			|||||||
    def test_user_reputation(self):
 | 
					    def test_user_reputation(self):
 | 
				
			||||||
        """test User reputation"""
 | 
					        """test User reputation"""
 | 
				
			||||||
        # Trigger negative reputation
 | 
					        # Trigger negative reputation
 | 
				
			||||||
        authenticate(self.request, username=self.test_username, password=self.test_username)
 | 
					        authenticate(
 | 
				
			||||||
 | 
					            self.request, self.backends, username=self.test_username, password=self.test_username
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        # Test value in cache
 | 
					        # Test value in cache
 | 
				
			||||||
        self.assertEqual(
 | 
					        self.assertEqual(
 | 
				
			||||||
            cache.get(CACHE_KEY_PREFIX + self.test_ip + self.test_username),
 | 
					            cache.get(CACHE_KEY_PREFIX + self.test_ip + self.test_username),
 | 
				
			||||||
 | 
				
			|||||||
@ -35,6 +35,7 @@ class OAuth2ProviderSerializer(ProviderSerializer):
 | 
				
			|||||||
            "property_mappings",
 | 
					            "property_mappings",
 | 
				
			||||||
            "issuer_mode",
 | 
					            "issuer_mode",
 | 
				
			||||||
            "verification_keys",
 | 
					            "verification_keys",
 | 
				
			||||||
 | 
					            "jwks_sources",
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -47,6 +48,7 @@ class OAuth2ProviderSetupURLs(PassiveSerializer):
 | 
				
			|||||||
    user_info = CharField(read_only=True)
 | 
					    user_info = CharField(read_only=True)
 | 
				
			||||||
    provider_info = CharField(read_only=True)
 | 
					    provider_info = CharField(read_only=True)
 | 
				
			||||||
    logout = CharField(read_only=True)
 | 
					    logout = CharField(read_only=True)
 | 
				
			||||||
 | 
					    jwks = CharField(read_only=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class OAuth2ProviderViewSet(UsedByMixin, ModelViewSet):
 | 
					class OAuth2ProviderViewSet(UsedByMixin, ModelViewSet):
 | 
				
			||||||
@ -118,6 +120,12 @@ class OAuth2ProviderViewSet(UsedByMixin, ModelViewSet):
 | 
				
			|||||||
                    kwargs={"application_slug": provider.application.slug},
 | 
					                    kwargs={"application_slug": provider.application.slug},
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					            data["jwks"] = request.build_absolute_uri(
 | 
				
			||||||
 | 
					                reverse(
 | 
				
			||||||
 | 
					                    "authentik_providers_oauth2:jwks",
 | 
				
			||||||
 | 
					                    kwargs={"application_slug": provider.application.slug},
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
        except Provider.application.RelatedObjectDoesNotExist:  # pylint: disable=no-member
 | 
					        except Provider.application.RelatedObjectDoesNotExist:  # pylint: disable=no-member
 | 
				
			||||||
            pass
 | 
					            pass
 | 
				
			||||||
        return Response(data)
 | 
					        return Response(data)
 | 
				
			||||||
 | 
				
			|||||||
@ -16,7 +16,7 @@ class Migration(migrations.Migration):
 | 
				
			|||||||
            model_name="oauth2provider",
 | 
					            model_name="oauth2provider",
 | 
				
			||||||
            name="verification_keys",
 | 
					            name="verification_keys",
 | 
				
			||||||
            field=models.ManyToManyField(
 | 
					            field=models.ManyToManyField(
 | 
				
			||||||
                help_text="JWTs created with the configured certificates can authenticate with this provider.",
 | 
					                help_text="DEPRECATED. JWTs created with the configured certificates can authenticate with this provider.",
 | 
				
			||||||
                related_name="+",
 | 
					                related_name="+",
 | 
				
			||||||
                to="authentik_crypto.certificatekeypair",
 | 
					                to="authentik_crypto.certificatekeypair",
 | 
				
			||||||
                verbose_name="Allowed certificates for JWT-based client_credentials",
 | 
					                verbose_name="Allowed certificates for JWT-based client_credentials",
 | 
				
			||||||
 | 
				
			|||||||
@ -0,0 +1,41 @@
 | 
				
			|||||||
 | 
					# Generated by Django 4.0.4 on 2022-05-23 20:41
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django.db import migrations, models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Migration(migrations.Migration):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    dependencies = [
 | 
				
			||||||
 | 
					        (
 | 
				
			||||||
 | 
					            "authentik_sources_oauth",
 | 
				
			||||||
 | 
					            "0007_oauthsource_oidc_jwks_oauthsource_oidc_jwks_url_and_more",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        ("authentik_crypto", "0003_certificatekeypair_managed"),
 | 
				
			||||||
 | 
					        ("authentik_providers_oauth2", "0010_alter_oauth2provider_verification_keys"),
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    operations = [
 | 
				
			||||||
 | 
					        migrations.AddField(
 | 
				
			||||||
 | 
					            model_name="oauth2provider",
 | 
				
			||||||
 | 
					            name="jwks_sources",
 | 
				
			||||||
 | 
					            field=models.ManyToManyField(
 | 
				
			||||||
 | 
					                blank=True,
 | 
				
			||||||
 | 
					                default=None,
 | 
				
			||||||
 | 
					                related_name="oauth2_providers",
 | 
				
			||||||
 | 
					                to="authentik_sources_oauth.oauthsource",
 | 
				
			||||||
 | 
					                verbose_name="Any JWT signed by the JWK of the selected source can be used to authenticate.",
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        migrations.AlterField(
 | 
				
			||||||
 | 
					            model_name="oauth2provider",
 | 
				
			||||||
 | 
					            name="verification_keys",
 | 
				
			||||||
 | 
					            field=models.ManyToManyField(
 | 
				
			||||||
 | 
					                blank=True,
 | 
				
			||||||
 | 
					                default=None,
 | 
				
			||||||
 | 
					                help_text="DEPRECATED. JWTs created with the configured certificates can authenticate with this provider.",
 | 
				
			||||||
 | 
					                related_name="oauth2_providers",
 | 
				
			||||||
 | 
					                to="authentik_crypto.certificatekeypair",
 | 
				
			||||||
 | 
					                verbose_name="Allowed certificates for JWT-based client_credentials",
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
@ -27,6 +27,7 @@ from authentik.lib.generators import generate_id, generate_key
 | 
				
			|||||||
from authentik.lib.utils.time import timedelta_from_string, timedelta_string_validator
 | 
					from authentik.lib.utils.time import timedelta_from_string, timedelta_string_validator
 | 
				
			||||||
from authentik.providers.oauth2.apps import AuthentikProviderOAuth2Config
 | 
					from authentik.providers.oauth2.apps import AuthentikProviderOAuth2Config
 | 
				
			||||||
from authentik.providers.oauth2.constants import ACR_AUTHENTIK_DEFAULT
 | 
					from authentik.providers.oauth2.constants import ACR_AUTHENTIK_DEFAULT
 | 
				
			||||||
 | 
					from authentik.sources.oauth.models import OAuthSource
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ClientTypes(models.TextChoices):
 | 
					class ClientTypes(models.TextChoices):
 | 
				
			||||||
@ -225,9 +226,21 @@ class OAuth2Provider(Provider):
 | 
				
			|||||||
        CertificateKeyPair,
 | 
					        CertificateKeyPair,
 | 
				
			||||||
        verbose_name=_("Allowed certificates for JWT-based client_credentials"),
 | 
					        verbose_name=_("Allowed certificates for JWT-based client_credentials"),
 | 
				
			||||||
        help_text=_(
 | 
					        help_text=_(
 | 
				
			||||||
            "JWTs created with the configured certificates can authenticate with this provider."
 | 
					            (
 | 
				
			||||||
 | 
					                "DEPRECATED. JWTs created with the configured "
 | 
				
			||||||
 | 
					                "certificates can authenticate with this provider."
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
        ),
 | 
					        ),
 | 
				
			||||||
        related_name="+",
 | 
					        related_name="oauth2_providers",
 | 
				
			||||||
 | 
					        default=None,
 | 
				
			||||||
 | 
					        blank=True,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    jwks_sources = models.ManyToManyField(
 | 
				
			||||||
 | 
					        OAuthSource,
 | 
				
			||||||
 | 
					        verbose_name=_(
 | 
				
			||||||
 | 
					            "Any JWT signed by the JWK of the selected source can be used to authenticate."
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        related_name="oauth2_providers",
 | 
				
			||||||
        default=None,
 | 
					        default=None,
 | 
				
			||||||
        blank=True,
 | 
					        blank=True,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
				
			|||||||
@ -6,8 +6,8 @@ from django.test import RequestFactory
 | 
				
			|||||||
from django.urls import reverse
 | 
					from django.urls import reverse
 | 
				
			||||||
from jwt import decode
 | 
					from jwt import decode
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.core.models import USER_ATTRIBUTE_SA, Application, Group
 | 
					from authentik.core.models import Application, Group
 | 
				
			||||||
from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
 | 
					from authentik.core.tests.utils import create_test_cert, create_test_flow
 | 
				
			||||||
from authentik.lib.generators import generate_id, generate_key
 | 
					from authentik.lib.generators import generate_id, generate_key
 | 
				
			||||||
from authentik.managed.manager import ObjectManager
 | 
					from authentik.managed.manager import ObjectManager
 | 
				
			||||||
from authentik.policies.models import PolicyBinding
 | 
					from authentik.policies.models import PolicyBinding
 | 
				
			||||||
@ -40,9 +40,6 @@ class TestTokenClientCredentialsJWT(OAuthTestCase):
 | 
				
			|||||||
        self.provider.verification_keys.set([self.cert])
 | 
					        self.provider.verification_keys.set([self.cert])
 | 
				
			||||||
        self.provider.property_mappings.set(ScopeMapping.objects.all())
 | 
					        self.provider.property_mappings.set(ScopeMapping.objects.all())
 | 
				
			||||||
        self.app = Application.objects.create(name="test", slug="test", provider=self.provider)
 | 
					        self.app = Application.objects.create(name="test", slug="test", provider=self.provider)
 | 
				
			||||||
        self.user = create_test_admin_user("sa")
 | 
					 | 
				
			||||||
        self.user.attributes[USER_ATTRIBUTE_SA] = True
 | 
					 | 
				
			||||||
        self.user.save()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_invalid_type(self):
 | 
					    def test_invalid_type(self):
 | 
				
			||||||
        """test invalid type"""
 | 
					        """test invalid type"""
 | 
				
			||||||
@ -76,7 +73,7 @@ class TestTokenClientCredentialsJWT(OAuthTestCase):
 | 
				
			|||||||
        body = loads(response.content.decode())
 | 
					        body = loads(response.content.decode())
 | 
				
			||||||
        self.assertEqual(body["error"], "invalid_grant")
 | 
					        self.assertEqual(body["error"], "invalid_grant")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_invalid_signautre(self):
 | 
					    def test_invalid_signature(self):
 | 
				
			||||||
        """test invalid JWT"""
 | 
					        """test invalid JWT"""
 | 
				
			||||||
        token = self.provider.encode(
 | 
					        token = self.provider.encode(
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
							
								
								
									
										223
									
								
								authentik/providers/oauth2/tests/test_token_cc_jwt_source.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										223
									
								
								authentik/providers/oauth2/tests/test_token_cc_jwt_source.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,223 @@
 | 
				
			|||||||
 | 
					"""Test token view"""
 | 
				
			||||||
 | 
					from datetime import datetime, timedelta
 | 
				
			||||||
 | 
					from json import loads
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django.test import RequestFactory
 | 
				
			||||||
 | 
					from django.urls import reverse
 | 
				
			||||||
 | 
					from jwt import decode
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from authentik.core.models import Application, Group
 | 
				
			||||||
 | 
					from authentik.core.tests.utils import create_test_cert, create_test_flow
 | 
				
			||||||
 | 
					from authentik.lib.generators import generate_id, generate_key
 | 
				
			||||||
 | 
					from authentik.managed.manager import ObjectManager
 | 
				
			||||||
 | 
					from authentik.policies.models import PolicyBinding
 | 
				
			||||||
 | 
					from authentik.providers.oauth2.constants import (
 | 
				
			||||||
 | 
					    GRANT_TYPE_CLIENT_CREDENTIALS,
 | 
				
			||||||
 | 
					    SCOPE_OPENID,
 | 
				
			||||||
 | 
					    SCOPE_OPENID_EMAIL,
 | 
				
			||||||
 | 
					    SCOPE_OPENID_PROFILE,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from authentik.providers.oauth2.models import OAuth2Provider, ScopeMapping
 | 
				
			||||||
 | 
					from authentik.providers.oauth2.tests.utils import OAuthTestCase
 | 
				
			||||||
 | 
					from authentik.providers.oauth2.views.jwks import JWKSView
 | 
				
			||||||
 | 
					from authentik.sources.oauth.models import OAuthSource
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestTokenClientCredentialsJWTSource(OAuthTestCase):
 | 
				
			||||||
 | 
					    """Test token (client_credentials, with JWT) view"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def setUp(self) -> None:
 | 
				
			||||||
 | 
					        super().setUp()
 | 
				
			||||||
 | 
					        ObjectManager().run()
 | 
				
			||||||
 | 
					        self.factory = RequestFactory()
 | 
				
			||||||
 | 
					        self.cert = create_test_cert()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        jwk = JWKSView().get_jwk_for_key(self.cert)
 | 
				
			||||||
 | 
					        self.source: OAuthSource = OAuthSource.objects.create(
 | 
				
			||||||
 | 
					            name=generate_id(),
 | 
				
			||||||
 | 
					            slug=generate_id(),
 | 
				
			||||||
 | 
					            provider_type="openidconnect",
 | 
				
			||||||
 | 
					            consumer_key=generate_id(),
 | 
				
			||||||
 | 
					            consumer_secret=generate_key(),
 | 
				
			||||||
 | 
					            authorization_url="http://foo",
 | 
				
			||||||
 | 
					            access_token_url=f"http://{generate_id()}",
 | 
				
			||||||
 | 
					            profile_url="http://foo",
 | 
				
			||||||
 | 
					            oidc_well_known_url="",
 | 
				
			||||||
 | 
					            oidc_jwks_url="",
 | 
				
			||||||
 | 
					            oidc_jwks={
 | 
				
			||||||
 | 
					                "keys": [jwk],
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.provider: OAuth2Provider = OAuth2Provider.objects.create(
 | 
				
			||||||
 | 
					            name="test",
 | 
				
			||||||
 | 
					            client_id=generate_id(),
 | 
				
			||||||
 | 
					            client_secret=generate_key(),
 | 
				
			||||||
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
 | 
					            redirect_uris="http://testserver",
 | 
				
			||||||
 | 
					            signing_key=self.cert,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.provider.jwks_sources.add(self.source)
 | 
				
			||||||
 | 
					        self.provider.property_mappings.set(ScopeMapping.objects.all())
 | 
				
			||||||
 | 
					        self.app = Application.objects.create(name="test", slug="test", provider=self.provider)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_invalid_type(self):
 | 
				
			||||||
 | 
					        """test invalid type"""
 | 
				
			||||||
 | 
					        response = self.client.post(
 | 
				
			||||||
 | 
					            reverse("authentik_providers_oauth2:token"),
 | 
				
			||||||
 | 
					            {
 | 
				
			||||||
 | 
					                "grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
 | 
				
			||||||
 | 
					                "scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE}",
 | 
				
			||||||
 | 
					                "client_id": self.provider.client_id,
 | 
				
			||||||
 | 
					                "client_assertion_type": "foo",
 | 
				
			||||||
 | 
					                "client_assertion": "foo.bar",
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(response.status_code, 400)
 | 
				
			||||||
 | 
					        body = loads(response.content.decode())
 | 
				
			||||||
 | 
					        self.assertEqual(body["error"], "invalid_grant")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_invalid_jwt(self):
 | 
				
			||||||
 | 
					        """test invalid JWT"""
 | 
				
			||||||
 | 
					        response = self.client.post(
 | 
				
			||||||
 | 
					            reverse("authentik_providers_oauth2:token"),
 | 
				
			||||||
 | 
					            {
 | 
				
			||||||
 | 
					                "grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
 | 
				
			||||||
 | 
					                "scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE}",
 | 
				
			||||||
 | 
					                "client_id": self.provider.client_id,
 | 
				
			||||||
 | 
					                "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
 | 
				
			||||||
 | 
					                "client_assertion": "foo.bar",
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(response.status_code, 400)
 | 
				
			||||||
 | 
					        body = loads(response.content.decode())
 | 
				
			||||||
 | 
					        self.assertEqual(body["error"], "invalid_grant")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_invalid_signature(self):
 | 
				
			||||||
 | 
					        """test invalid JWT"""
 | 
				
			||||||
 | 
					        token = self.provider.encode(
 | 
				
			||||||
 | 
					            {
 | 
				
			||||||
 | 
					                "sub": "foo",
 | 
				
			||||||
 | 
					                "exp": datetime.now() + timedelta(hours=2),
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        response = self.client.post(
 | 
				
			||||||
 | 
					            reverse("authentik_providers_oauth2:token"),
 | 
				
			||||||
 | 
					            {
 | 
				
			||||||
 | 
					                "grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
 | 
				
			||||||
 | 
					                "scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE}",
 | 
				
			||||||
 | 
					                "client_id": self.provider.client_id,
 | 
				
			||||||
 | 
					                "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
 | 
				
			||||||
 | 
					                "client_assertion": token + "foo",
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(response.status_code, 400)
 | 
				
			||||||
 | 
					        body = loads(response.content.decode())
 | 
				
			||||||
 | 
					        self.assertEqual(body["error"], "invalid_grant")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_invalid_expired(self):
 | 
				
			||||||
 | 
					        """test invalid JWT"""
 | 
				
			||||||
 | 
					        token = self.provider.encode(
 | 
				
			||||||
 | 
					            {
 | 
				
			||||||
 | 
					                "sub": "foo",
 | 
				
			||||||
 | 
					                "exp": datetime.now() - timedelta(hours=2),
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        response = self.client.post(
 | 
				
			||||||
 | 
					            reverse("authentik_providers_oauth2:token"),
 | 
				
			||||||
 | 
					            {
 | 
				
			||||||
 | 
					                "grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
 | 
				
			||||||
 | 
					                "scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE}",
 | 
				
			||||||
 | 
					                "client_id": self.provider.client_id,
 | 
				
			||||||
 | 
					                "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
 | 
				
			||||||
 | 
					                "client_assertion": token,
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(response.status_code, 400)
 | 
				
			||||||
 | 
					        body = loads(response.content.decode())
 | 
				
			||||||
 | 
					        self.assertEqual(body["error"], "invalid_grant")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_invalid_no_app(self):
 | 
				
			||||||
 | 
					        """test invalid JWT"""
 | 
				
			||||||
 | 
					        self.app.provider = None
 | 
				
			||||||
 | 
					        self.app.save()
 | 
				
			||||||
 | 
					        token = self.provider.encode(
 | 
				
			||||||
 | 
					            {
 | 
				
			||||||
 | 
					                "sub": "foo",
 | 
				
			||||||
 | 
					                "exp": datetime.now() + timedelta(hours=2),
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        response = self.client.post(
 | 
				
			||||||
 | 
					            reverse("authentik_providers_oauth2:token"),
 | 
				
			||||||
 | 
					            {
 | 
				
			||||||
 | 
					                "grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
 | 
				
			||||||
 | 
					                "scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE}",
 | 
				
			||||||
 | 
					                "client_id": self.provider.client_id,
 | 
				
			||||||
 | 
					                "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
 | 
				
			||||||
 | 
					                "client_assertion": token,
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(response.status_code, 400)
 | 
				
			||||||
 | 
					        body = loads(response.content.decode())
 | 
				
			||||||
 | 
					        self.assertEqual(body["error"], "invalid_grant")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_invalid_access_denied(self):
 | 
				
			||||||
 | 
					        """test invalid JWT"""
 | 
				
			||||||
 | 
					        group = Group.objects.create(name="foo")
 | 
				
			||||||
 | 
					        PolicyBinding.objects.create(
 | 
				
			||||||
 | 
					            group=group,
 | 
				
			||||||
 | 
					            target=self.app,
 | 
				
			||||||
 | 
					            order=0,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        token = self.provider.encode(
 | 
				
			||||||
 | 
					            {
 | 
				
			||||||
 | 
					                "sub": "foo",
 | 
				
			||||||
 | 
					                "exp": datetime.now() + timedelta(hours=2),
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        response = self.client.post(
 | 
				
			||||||
 | 
					            reverse("authentik_providers_oauth2:token"),
 | 
				
			||||||
 | 
					            {
 | 
				
			||||||
 | 
					                "grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
 | 
				
			||||||
 | 
					                "scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE}",
 | 
				
			||||||
 | 
					                "client_id": self.provider.client_id,
 | 
				
			||||||
 | 
					                "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
 | 
				
			||||||
 | 
					                "client_assertion": token,
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(response.status_code, 400)
 | 
				
			||||||
 | 
					        body = loads(response.content.decode())
 | 
				
			||||||
 | 
					        self.assertEqual(body["error"], "invalid_grant")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_successful(self):
 | 
				
			||||||
 | 
					        """test successful"""
 | 
				
			||||||
 | 
					        token = self.provider.encode(
 | 
				
			||||||
 | 
					            {
 | 
				
			||||||
 | 
					                "sub": "foo",
 | 
				
			||||||
 | 
					                "exp": datetime.now() + timedelta(hours=2),
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        response = self.client.post(
 | 
				
			||||||
 | 
					            reverse("authentik_providers_oauth2:token"),
 | 
				
			||||||
 | 
					            {
 | 
				
			||||||
 | 
					                "grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
 | 
				
			||||||
 | 
					                "scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE}",
 | 
				
			||||||
 | 
					                "client_id": self.provider.client_id,
 | 
				
			||||||
 | 
					                "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
 | 
				
			||||||
 | 
					                "client_assertion": token,
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(response.status_code, 200)
 | 
				
			||||||
 | 
					        body = loads(response.content.decode())
 | 
				
			||||||
 | 
					        self.assertEqual(body["token_type"], "bearer")
 | 
				
			||||||
 | 
					        _, alg = self.provider.get_jwt_key()
 | 
				
			||||||
 | 
					        jwt = decode(
 | 
				
			||||||
 | 
					            body["access_token"],
 | 
				
			||||||
 | 
					            key=self.provider.signing_key.public_key,
 | 
				
			||||||
 | 
					            algorithms=[alg],
 | 
				
			||||||
 | 
					            audience=self.provider.client_id,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            jwt["given_name"], "Autogenerated user from application test (client credentials JWT)"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(jwt["preferred_username"], "test-foo")
 | 
				
			||||||
@ -69,7 +69,7 @@ from authentik.stages.user_login.stage import USER_LOGIN_AUTHENTICATED
 | 
				
			|||||||
LOGGER = get_logger()
 | 
					LOGGER = get_logger()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
PLAN_CONTEXT_PARAMS = "params"
 | 
					PLAN_CONTEXT_PARAMS = "params"
 | 
				
			||||||
SESSION_NEEDS_LOGIN = "authentik_oauth2_needs_login"
 | 
					SESSION_KEY_NEEDS_LOGIN = "authentik/providers/oauth2/needs_login"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
ALLOWED_PROMPT_PARAMS = {PROMPT_NONE, PROMPT_CONSENT, PROMPT_LOGIN}
 | 
					ALLOWED_PROMPT_PARAMS = {PROMPT_NONE, PROMPT_CONSENT, PROMPT_LOGIN}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -196,7 +196,7 @@ class OAuthAuthorizationParams:
 | 
				
			|||||||
                LOGGER.warning(
 | 
					                LOGGER.warning(
 | 
				
			||||||
                    "Invalid redirect uri",
 | 
					                    "Invalid redirect uri",
 | 
				
			||||||
                    redirect_uri=self.redirect_uri,
 | 
					                    redirect_uri=self.redirect_uri,
 | 
				
			||||||
                    excepted=allowed_redirect_urls,
 | 
					                    expected=allowed_redirect_urls,
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
                raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
 | 
					                raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
 | 
				
			||||||
        except RegexError as exc:
 | 
					        except RegexError as exc:
 | 
				
			||||||
@ -326,13 +326,13 @@ class AuthorizationFlowInitView(PolicyAccessView):
 | 
				
			|||||||
        # If prompt=login, we need to re-authenticate the user regardless
 | 
					        # If prompt=login, we need to re-authenticate the user regardless
 | 
				
			||||||
        if (
 | 
					        if (
 | 
				
			||||||
            PROMPT_LOGIN in self.params.prompt
 | 
					            PROMPT_LOGIN in self.params.prompt
 | 
				
			||||||
            and SESSION_NEEDS_LOGIN not in self.request.session
 | 
					            and SESSION_KEY_NEEDS_LOGIN not in self.request.session
 | 
				
			||||||
            # To prevent the user from having to double login when prompt is set to login
 | 
					            # To prevent the user from having to double login when prompt is set to login
 | 
				
			||||||
            # and the user has just signed it. This session variable is set in the UserLoginStage
 | 
					            # and the user has just signed it. This session variable is set in the UserLoginStage
 | 
				
			||||||
            # and is (quite hackily) removed from the session in applications's API's List method
 | 
					            # and is (quite hackily) removed from the session in applications's API's List method
 | 
				
			||||||
            and USER_LOGIN_AUTHENTICATED not in self.request.session
 | 
					            and USER_LOGIN_AUTHENTICATED not in self.request.session
 | 
				
			||||||
        ):
 | 
					        ):
 | 
				
			||||||
            self.request.session[SESSION_NEEDS_LOGIN] = True
 | 
					            self.request.session[SESSION_KEY_NEEDS_LOGIN] = True
 | 
				
			||||||
            return self.handle_no_permission()
 | 
					            return self.handle_no_permission()
 | 
				
			||||||
        # Regardless, we start the planner and return to it
 | 
					        # Regardless, we start the planner and return to it
 | 
				
			||||||
        planner = FlowPlanner(self.provider.authorization_flow)
 | 
					        planner = FlowPlanner(self.provider.authorization_flow)
 | 
				
			||||||
 | 
				
			|||||||
@ -1,5 +1,6 @@
 | 
				
			|||||||
"""authentik OAuth2 JWKS Views"""
 | 
					"""authentik OAuth2 JWKS Views"""
 | 
				
			||||||
from base64 import urlsafe_b64encode
 | 
					from base64 import urlsafe_b64encode
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from cryptography.hazmat.primitives.asymmetric.ec import (
 | 
					from cryptography.hazmat.primitives.asymmetric.ec import (
 | 
				
			||||||
    EllipticCurvePrivateKey,
 | 
					    EllipticCurvePrivateKey,
 | 
				
			||||||
@ -26,8 +27,37 @@ def b64_enc(number: int) -> str:
 | 
				
			|||||||
class JWKSView(View):
 | 
					class JWKSView(View):
 | 
				
			||||||
    """Show RSA Key data for Provider"""
 | 
					    """Show RSA Key data for Provider"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_jwk_for_key(self, key: CertificateKeyPair) -> Optional[dict]:
 | 
				
			||||||
 | 
					        """Convert a certificate-key pair into JWK"""
 | 
				
			||||||
 | 
					        private_key = key.private_key
 | 
				
			||||||
 | 
					        if not private_key:
 | 
				
			||||||
 | 
					            return None
 | 
				
			||||||
 | 
					        if isinstance(private_key, RSAPrivateKey):
 | 
				
			||||||
 | 
					            public_key: RSAPublicKey = private_key.public_key()
 | 
				
			||||||
 | 
					            public_numbers = public_key.public_numbers()
 | 
				
			||||||
 | 
					            return {
 | 
				
			||||||
 | 
					                "kty": "RSA",
 | 
				
			||||||
 | 
					                "alg": JWTAlgorithms.RS256,
 | 
				
			||||||
 | 
					                "use": "sig",
 | 
				
			||||||
 | 
					                "kid": key.kid,
 | 
				
			||||||
 | 
					                "n": b64_enc(public_numbers.n),
 | 
				
			||||||
 | 
					                "e": b64_enc(public_numbers.e),
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        if isinstance(private_key, EllipticCurvePrivateKey):
 | 
				
			||||||
 | 
					            public_key: EllipticCurvePublicKey = private_key.public_key()
 | 
				
			||||||
 | 
					            public_numbers = public_key.public_numbers()
 | 
				
			||||||
 | 
					            return {
 | 
				
			||||||
 | 
					                "kty": "EC",
 | 
				
			||||||
 | 
					                "alg": JWTAlgorithms.ES256,
 | 
				
			||||||
 | 
					                "use": "sig",
 | 
				
			||||||
 | 
					                "kid": key.kid,
 | 
				
			||||||
 | 
					                "n": b64_enc(public_numbers.n),
 | 
				
			||||||
 | 
					                "e": b64_enc(public_numbers.e),
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get(self, request: HttpRequest, application_slug: str) -> HttpResponse:
 | 
					    def get(self, request: HttpRequest, application_slug: str) -> HttpResponse:
 | 
				
			||||||
        """Show RSA Key data for Provider"""
 | 
					        """Show JWK Key data for Provider"""
 | 
				
			||||||
        application = get_object_or_404(Application, slug=application_slug)
 | 
					        application = get_object_or_404(Application, slug=application_slug)
 | 
				
			||||||
        provider: OAuth2Provider = get_object_or_404(OAuth2Provider, pk=application.provider_id)
 | 
					        provider: OAuth2Provider = get_object_or_404(OAuth2Provider, pk=application.provider_id)
 | 
				
			||||||
        signing_key: CertificateKeyPair = provider.signing_key
 | 
					        signing_key: CertificateKeyPair = provider.signing_key
 | 
				
			||||||
@ -35,33 +65,9 @@ class JWKSView(View):
 | 
				
			|||||||
        response_data = {}
 | 
					        response_data = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if signing_key:
 | 
					        if signing_key:
 | 
				
			||||||
            private_key = signing_key.private_key
 | 
					            jwk = self.get_jwk_for_key(signing_key)
 | 
				
			||||||
            if isinstance(private_key, RSAPrivateKey):
 | 
					            if jwk:
 | 
				
			||||||
                public_key: RSAPublicKey = private_key.public_key()
 | 
					                response_data["keys"] = [jwk]
 | 
				
			||||||
                public_numbers = public_key.public_numbers()
 | 
					 | 
				
			||||||
                response_data["keys"] = [
 | 
					 | 
				
			||||||
                    {
 | 
					 | 
				
			||||||
                        "kty": "RSA",
 | 
					 | 
				
			||||||
                        "alg": JWTAlgorithms.RS256,
 | 
					 | 
				
			||||||
                        "use": "sig",
 | 
					 | 
				
			||||||
                        "kid": signing_key.kid,
 | 
					 | 
				
			||||||
                        "n": b64_enc(public_numbers.n),
 | 
					 | 
				
			||||||
                        "e": b64_enc(public_numbers.e),
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                ]
 | 
					 | 
				
			||||||
            elif isinstance(private_key, EllipticCurvePrivateKey):
 | 
					 | 
				
			||||||
                public_key: EllipticCurvePublicKey = private_key.public_key()
 | 
					 | 
				
			||||||
                public_numbers = public_key.public_numbers()
 | 
					 | 
				
			||||||
                response_data["keys"] = [
 | 
					 | 
				
			||||||
                    {
 | 
					 | 
				
			||||||
                        "kty": "EC",
 | 
					 | 
				
			||||||
                        "alg": JWTAlgorithms.ES256,
 | 
					 | 
				
			||||||
                        "use": "sig",
 | 
					 | 
				
			||||||
                        "kid": signing_key.kid,
 | 
					 | 
				
			||||||
                        "n": b64_enc(public_numbers.n),
 | 
					 | 
				
			||||||
                        "e": b64_enc(public_numbers.e),
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                ]
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        response = JsonResponse(response_data)
 | 
					        response = JsonResponse(response_data)
 | 
				
			||||||
        response["Access-Control-Allow-Origin"] = "*"
 | 
					        response["Access-Control-Allow-Origin"] = "*"
 | 
				
			||||||
 | 
				
			|||||||
@ -9,7 +9,7 @@ from typing import Any, Optional
 | 
				
			|||||||
from django.http import HttpRequest, HttpResponse
 | 
					from django.http import HttpRequest, HttpResponse
 | 
				
			||||||
from django.utils.timezone import datetime, now
 | 
					from django.utils.timezone import datetime, now
 | 
				
			||||||
from django.views import View
 | 
					from django.views import View
 | 
				
			||||||
from jwt import InvalidTokenError, decode
 | 
					from jwt import PyJWK, PyJWTError, decode
 | 
				
			||||||
from sentry_sdk.hub import Hub
 | 
					from sentry_sdk.hub import Hub
 | 
				
			||||||
from structlog.stdlib import get_logger
 | 
					from structlog.stdlib import get_logger
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -43,6 +43,7 @@ from authentik.providers.oauth2.models import (
 | 
				
			|||||||
    RefreshToken,
 | 
					    RefreshToken,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from authentik.providers.oauth2.utils import TokenResponse, cors_allow, extract_client_auth
 | 
					from authentik.providers.oauth2.utils import TokenResponse, cors_allow, extract_client_auth
 | 
				
			||||||
 | 
					from authentik.sources.oauth.models import OAuthSource
 | 
				
			||||||
 | 
					
 | 
				
			||||||
LOGGER = get_logger()
 | 
					LOGGER = get_logger()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -127,7 +128,7 @@ class TokenParams:
 | 
				
			|||||||
            with Hub.current.start_span(
 | 
					            with Hub.current.start_span(
 | 
				
			||||||
                op="authentik.providers.oauth2.post.parse.code",
 | 
					                op="authentik.providers.oauth2.post.parse.code",
 | 
				
			||||||
            ):
 | 
					            ):
 | 
				
			||||||
                self.__post_init_code(raw_code)
 | 
					                self.__post_init_code(raw_code, request)
 | 
				
			||||||
        elif self.grant_type == GRANT_TYPE_REFRESH_TOKEN:
 | 
					        elif self.grant_type == GRANT_TYPE_REFRESH_TOKEN:
 | 
				
			||||||
            with Hub.current.start_span(
 | 
					            with Hub.current.start_span(
 | 
				
			||||||
                op="authentik.providers.oauth2.post.parse.refresh",
 | 
					                op="authentik.providers.oauth2.post.parse.refresh",
 | 
				
			||||||
@ -142,7 +143,7 @@ class TokenParams:
 | 
				
			|||||||
            LOGGER.warning("Invalid grant type", grant_type=self.grant_type)
 | 
					            LOGGER.warning("Invalid grant type", grant_type=self.grant_type)
 | 
				
			||||||
            raise TokenError("unsupported_grant_type")
 | 
					            raise TokenError("unsupported_grant_type")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __post_init_code(self, raw_code: str):
 | 
					    def __post_init_code(self, raw_code: str, request: HttpRequest):
 | 
				
			||||||
        if not raw_code:
 | 
					        if not raw_code:
 | 
				
			||||||
            LOGGER.warning("Missing authorization code")
 | 
					            LOGGER.warning("Missing authorization code")
 | 
				
			||||||
            raise TokenError("invalid_grant")
 | 
					            raise TokenError("invalid_grant")
 | 
				
			||||||
@ -155,11 +156,23 @@ class TokenParams:
 | 
				
			|||||||
                LOGGER.warning(
 | 
					                LOGGER.warning(
 | 
				
			||||||
                    "Invalid redirect uri",
 | 
					                    "Invalid redirect uri",
 | 
				
			||||||
                    redirect_uri=self.redirect_uri,
 | 
					                    redirect_uri=self.redirect_uri,
 | 
				
			||||||
                    excepted=allowed_redirect_urls,
 | 
					                    expected=allowed_redirect_urls,
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
 | 
					                Event.new(
 | 
				
			||||||
 | 
					                    EventAction.CONFIGURATION_ERROR,
 | 
				
			||||||
 | 
					                    message="Invalid redirect URI used by provider",
 | 
				
			||||||
 | 
					                    provider=self.provider,
 | 
				
			||||||
 | 
					                    redirect_uri=self.redirect_uri,
 | 
				
			||||||
 | 
					                    expected=allowed_redirect_urls,
 | 
				
			||||||
 | 
					                ).from_http(request)
 | 
				
			||||||
                raise TokenError("invalid_client")
 | 
					                raise TokenError("invalid_client")
 | 
				
			||||||
        except RegexError as exc:
 | 
					        except RegexError as exc:
 | 
				
			||||||
            LOGGER.warning("Invalid regular expression configured", exc=exc)
 | 
					            LOGGER.warning("Invalid regular expression configured", exc=exc)
 | 
				
			||||||
 | 
					            Event.new(
 | 
				
			||||||
 | 
					                EventAction.CONFIGURATION_ERROR,
 | 
				
			||||||
 | 
					                message="Invalid redirect_uri RegEx configured",
 | 
				
			||||||
 | 
					                provider=self.provider,
 | 
				
			||||||
 | 
					            ).from_http(request)
 | 
				
			||||||
            raise TokenError("invalid_client")
 | 
					            raise TokenError("invalid_client")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
@ -258,17 +271,22 @@ class TokenParams:
 | 
				
			|||||||
        ).from_http(request, user=user)
 | 
					        ).from_http(request, user=user)
 | 
				
			||||||
        return None
 | 
					        return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # pylint: disable=too-many-locals
 | 
				
			||||||
    def __post_init_client_credentials_jwt(self, request: HttpRequest):
 | 
					    def __post_init_client_credentials_jwt(self, request: HttpRequest):
 | 
				
			||||||
        assertion_type = request.POST.get(CLIENT_ASSERTION_TYPE, "")
 | 
					        assertion_type = request.POST.get(CLIENT_ASSERTION_TYPE, "")
 | 
				
			||||||
        if assertion_type != CLIENT_ASSERTION_TYPE_JWT:
 | 
					        if assertion_type != CLIENT_ASSERTION_TYPE_JWT:
 | 
				
			||||||
 | 
					            LOGGER.warning("Invalid assertion type", assertion_type=assertion_type)
 | 
				
			||||||
            raise TokenError("invalid_grant")
 | 
					            raise TokenError("invalid_grant")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        client_secret = request.POST.get("client_secret", None)
 | 
					        client_secret = request.POST.get("client_secret", None)
 | 
				
			||||||
        assertion = request.POST.get(CLIENT_ASSERTION, client_secret)
 | 
					        assertion = request.POST.get(CLIENT_ASSERTION, client_secret)
 | 
				
			||||||
        if not assertion:
 | 
					        if not assertion:
 | 
				
			||||||
 | 
					            LOGGER.warning("Missing client assertion")
 | 
				
			||||||
            raise TokenError("invalid_grant")
 | 
					            raise TokenError("invalid_grant")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        token = None
 | 
					        token = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # TODO: Remove in 2022.7, deprecated field `verification_keys``
 | 
				
			||||||
        for cert in self.provider.verification_keys.all():
 | 
					        for cert in self.provider.verification_keys.all():
 | 
				
			||||||
            LOGGER.debug("verifying jwt with key", key=cert.name)
 | 
					            LOGGER.debug("verifying jwt with key", key=cert.name)
 | 
				
			||||||
            cert: CertificateKeyPair
 | 
					            cert: CertificateKeyPair
 | 
				
			||||||
@ -284,9 +302,34 @@ class TokenParams:
 | 
				
			|||||||
                        "verify_aud": False,
 | 
					                        "verify_aud": False,
 | 
				
			||||||
                    },
 | 
					                    },
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
            except (InvalidTokenError, ValueError, TypeError) as last_exc:
 | 
					            except (PyJWTError, ValueError, TypeError) as exc:
 | 
				
			||||||
                LOGGER.warning("failed to validate jwt", last_exc=last_exc)
 | 
					                LOGGER.warning("failed to validate jwt", exc=exc)
 | 
				
			||||||
 | 
					        # TODO: End remove block
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        source: Optional[OAuthSource] = None
 | 
				
			||||||
 | 
					        parsed_key: Optional[PyJWK] = None
 | 
				
			||||||
 | 
					        for source in self.provider.jwks_sources.all():
 | 
				
			||||||
 | 
					            LOGGER.debug("verifying jwt with source", source=source.name)
 | 
				
			||||||
 | 
					            keys = source.oidc_jwks.get("keys", [])
 | 
				
			||||||
 | 
					            for key in keys:
 | 
				
			||||||
 | 
					                LOGGER.debug("verifying jwt with key", source=source.name, key=key.get("kid"))
 | 
				
			||||||
 | 
					                try:
 | 
				
			||||||
 | 
					                    parsed_key = PyJWK.from_dict(key)
 | 
				
			||||||
 | 
					                    token = decode(
 | 
				
			||||||
 | 
					                        assertion,
 | 
				
			||||||
 | 
					                        parsed_key.key,
 | 
				
			||||||
 | 
					                        algorithms=[key.get("alg")],
 | 
				
			||||||
 | 
					                        options={
 | 
				
			||||||
 | 
					                            "verify_aud": False,
 | 
				
			||||||
 | 
					                        },
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                # AttributeError is raised when the configured JWK is a private key
 | 
				
			||||||
 | 
					                # and not a public key
 | 
				
			||||||
 | 
					                except (PyJWTError, ValueError, TypeError, AttributeError) as exc:
 | 
				
			||||||
 | 
					                    LOGGER.warning("failed to validate jwt", exc=exc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if not token:
 | 
					        if not token:
 | 
				
			||||||
 | 
					            LOGGER.warning("No token could be verified")
 | 
				
			||||||
            raise TokenError("invalid_grant")
 | 
					            raise TokenError("invalid_grant")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if "exp" in token:
 | 
					        if "exp" in token:
 | 
				
			||||||
@ -304,12 +347,17 @@ class TokenParams:
 | 
				
			|||||||
        self.__check_policy_access(app, request, oauth_jwt=token)
 | 
					        self.__check_policy_access(app, request, oauth_jwt=token)
 | 
				
			||||||
        self.__create_user_from_jwt(token, app)
 | 
					        self.__create_user_from_jwt(token, app)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        method_args = {
 | 
				
			||||||
 | 
					            "jwt": token,
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        if source:
 | 
				
			||||||
 | 
					            method_args["source"] = source
 | 
				
			||||||
 | 
					        if parsed_key:
 | 
				
			||||||
 | 
					            method_args["jwk_id"] = parsed_key.key_id
 | 
				
			||||||
        Event.new(
 | 
					        Event.new(
 | 
				
			||||||
            action=EventAction.LOGIN,
 | 
					            action=EventAction.LOGIN,
 | 
				
			||||||
            PLAN_CONTEXT_METHOD="jwt",
 | 
					            PLAN_CONTEXT_METHOD="jwt",
 | 
				
			||||||
            PLAN_CONTEXT_METHOD_ARGS={
 | 
					            PLAN_CONTEXT_METHOD_ARGS=method_args,
 | 
				
			||||||
                "jwt": token,
 | 
					 | 
				
			||||||
            },
 | 
					 | 
				
			||||||
            PLAN_CONTEXT_APPLICATION=app,
 | 
					            PLAN_CONTEXT_APPLICATION=app,
 | 
				
			||||||
        ).from_http(request, user=self.user)
 | 
					        ).from_http(request, user=self.user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -12,8 +12,4 @@ class AuthentikProviderProxyConfig(AppConfig):
 | 
				
			|||||||
    verbose_name = "authentik Providers.Proxy"
 | 
					    verbose_name = "authentik Providers.Proxy"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def ready(self) -> None:
 | 
					    def ready(self) -> None:
 | 
				
			||||||
        from authentik.providers.proxy.tasks import proxy_set_defaults
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        import_module("authentik.providers.proxy.managed")
 | 
					        import_module("authentik.providers.proxy.managed")
 | 
				
			||||||
 | 
					 | 
				
			||||||
        proxy_set_defaults.delay()
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -52,7 +52,7 @@ class SAMLProvider(Provider):
 | 
				
			|||||||
        default=SAMLBindings.REDIRECT,
 | 
					        default=SAMLBindings.REDIRECT,
 | 
				
			||||||
        verbose_name=_("Service Provider Binding"),
 | 
					        verbose_name=_("Service Provider Binding"),
 | 
				
			||||||
        help_text=_(
 | 
					        help_text=_(
 | 
				
			||||||
            ("This determines how authentik sends the " "response back to the Service Provider.")
 | 
					            ("This determines how authentik sends the response back to the Service Provider.")
 | 
				
			||||||
        ),
 | 
					        ),
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -7,9 +7,9 @@ from xml.etree.ElementTree import ParseError  # nosec
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import xmlsec
 | 
					import xmlsec
 | 
				
			||||||
from defusedxml import ElementTree
 | 
					from defusedxml import ElementTree
 | 
				
			||||||
from lxml import etree  # nosec
 | 
					 | 
				
			||||||
from structlog.stdlib import get_logger
 | 
					from structlog.stdlib import get_logger
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from authentik.lib.xml import lxml_from_string
 | 
				
			||||||
from authentik.providers.saml.exceptions import CannotHandleAssertion
 | 
					from authentik.providers.saml.exceptions import CannotHandleAssertion
 | 
				
			||||||
from authentik.providers.saml.models import SAMLProvider
 | 
					from authentik.providers.saml.models import SAMLProvider
 | 
				
			||||||
from authentik.providers.saml.utils.encoding import decode_base64_and_inflate
 | 
					from authentik.providers.saml.utils.encoding import decode_base64_and_inflate
 | 
				
			||||||
@ -95,7 +95,7 @@ class AuthNRequestParser:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        verifier = self.provider.verification_kp
 | 
					        verifier = self.provider.verification_kp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        root = etree.fromstring(decoded_xml)  # nosec
 | 
					        root = lxml_from_string(decoded_xml)
 | 
				
			||||||
        xmlsec.tree.add_ids(root, ["ID"])
 | 
					        xmlsec.tree.add_ids(root, ["ID"])
 | 
				
			||||||
        signature_nodes = root.xpath("/samlp:AuthnRequest/ds:Signature", namespaces=NS_MAP)
 | 
					        signature_nodes = root.xpath("/samlp:AuthnRequest/ds:Signature", namespaces=NS_MAP)
 | 
				
			||||||
        # No signatures, no verifier configured -> decode xml directly
 | 
					        # No signatures, no verifier configured -> decode xml directly
 | 
				
			||||||
 | 
				
			|||||||
@ -19,7 +19,7 @@ from authentik.sources.saml.processors.constants import (
 | 
				
			|||||||
    SAML_NAME_ID_FORMAT_EMAIL,
 | 
					    SAML_NAME_ID_FORMAT_EMAIL,
 | 
				
			||||||
    SAML_NAME_ID_FORMAT_UNSPECIFIED,
 | 
					    SAML_NAME_ID_FORMAT_UNSPECIFIED,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from authentik.sources.saml.processors.request import SESSION_REQUEST_ID, RequestProcessor
 | 
					from authentik.sources.saml.processors.request import SESSION_KEY_REQUEST_ID, RequestProcessor
 | 
				
			||||||
from authentik.sources.saml.processors.response import ResponseProcessor
 | 
					from authentik.sources.saml.processors.response import ResponseProcessor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
POST_REQUEST = (
 | 
					POST_REQUEST = (
 | 
				
			||||||
@ -142,7 +142,7 @@ class TestAuthNRequest(TestCase):
 | 
				
			|||||||
        request = request_proc.build_auth_n()
 | 
					        request = request_proc.build_auth_n()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # change the request ID
 | 
					        # change the request ID
 | 
				
			||||||
        http_request.session[SESSION_REQUEST_ID] = "test"
 | 
					        http_request.session[SESSION_KEY_REQUEST_ID] = "test"
 | 
				
			||||||
        http_request.session.save()
 | 
					        http_request.session.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # To get an assertion we need a parsed request (parsed by provider)
 | 
					        # To get an assertion we need a parsed request (parsed by provider)
 | 
				
			||||||
 | 
				
			|||||||
@ -6,6 +6,7 @@ from lxml import etree  # nosec
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from authentik.core.tests.utils import create_test_cert, create_test_flow
 | 
					from authentik.core.tests.utils import create_test_cert, create_test_flow
 | 
				
			||||||
from authentik.lib.tests.utils import get_request
 | 
					from authentik.lib.tests.utils import get_request
 | 
				
			||||||
 | 
					from authentik.lib.xml import lxml_from_string
 | 
				
			||||||
from authentik.managed.manager import ObjectManager
 | 
					from authentik.managed.manager import ObjectManager
 | 
				
			||||||
from authentik.providers.saml.models import SAMLPropertyMapping, SAMLProvider
 | 
					from authentik.providers.saml.models import SAMLPropertyMapping, SAMLProvider
 | 
				
			||||||
from authentik.providers.saml.processors.assertion import AssertionProcessor
 | 
					from authentik.providers.saml.processors.assertion import AssertionProcessor
 | 
				
			||||||
@ -44,7 +45,7 @@ class TestSchema(TestCase):
 | 
				
			|||||||
        request_proc = RequestProcessor(self.source, http_request, "test_state")
 | 
					        request_proc = RequestProcessor(self.source, http_request, "test_state")
 | 
				
			||||||
        request = request_proc.build_auth_n()
 | 
					        request = request_proc.build_auth_n()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        metadata = etree.fromstring(request)  # nosec
 | 
					        metadata = lxml_from_string(request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        schema = etree.XMLSchema(etree.parse("xml/saml-schema-protocol-2.0.xsd"))  # nosec
 | 
					        schema = etree.XMLSchema(etree.parse("xml/saml-schema-protocol-2.0.xsd"))  # nosec
 | 
				
			||||||
        self.assertTrue(schema.validate(metadata))
 | 
					        self.assertTrue(schema.validate(metadata))
 | 
				
			||||||
@ -65,7 +66,7 @@ class TestSchema(TestCase):
 | 
				
			|||||||
        response_proc = AssertionProcessor(self.provider, http_request, parsed_request)
 | 
					        response_proc = AssertionProcessor(self.provider, http_request, parsed_request)
 | 
				
			||||||
        response = response_proc.build_response()
 | 
					        response = response_proc.build_response()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        metadata = etree.fromstring(response)  # nosec
 | 
					        metadata = lxml_from_string(response)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        schema = etree.XMLSchema(etree.parse("xml/saml-schema-protocol-2.0.xsd"))
 | 
					        schema = etree.XMLSchema(etree.parse("xml/saml-schema-protocol-2.0.xsd"))
 | 
				
			||||||
        self.assertTrue(schema.validate(metadata))
 | 
					        self.assertTrue(schema.validate(metadata))
 | 
				
			||||||
 | 
				
			|||||||
@ -34,7 +34,7 @@ REQUEST_KEY_SAML_SIG_ALG = "SigAlg"
 | 
				
			|||||||
REQUEST_KEY_SAML_RESPONSE = "SAMLResponse"
 | 
					REQUEST_KEY_SAML_RESPONSE = "SAMLResponse"
 | 
				
			||||||
REQUEST_KEY_RELAY_STATE = "RelayState"
 | 
					REQUEST_KEY_RELAY_STATE = "RelayState"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
SESSION_KEY_AUTH_N_REQUEST = "authn_request"
 | 
					SESSION_KEY_AUTH_N_REQUEST = "authentik/providers/saml/authn_request"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# This View doesn't have a URL on purpose, as its called by the FlowExecutor
 | 
					# This View doesn't have a URL on purpose, as its called by the FlowExecutor
 | 
				
			||||||
class SAMLFlowFinalView(ChallengeStageView):
 | 
					class SAMLFlowFinalView(ChallengeStageView):
 | 
				
			||||||
@ -106,3 +106,6 @@ class SAMLFlowFinalView(ChallengeStageView):
 | 
				
			|||||||
    def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:
 | 
					    def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:
 | 
				
			||||||
        # We'll never get here since the challenge redirects to the SP
 | 
					        # We'll never get here since the challenge redirects to the SP
 | 
				
			||||||
        return HttpResponseBadRequest()
 | 
					        return HttpResponseBadRequest()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def cleanup(self):
 | 
				
			||||||
 | 
					        self.request.session.pop(SESSION_KEY_AUTH_N_REQUEST, None)
 | 
				
			||||||
 | 
				
			|||||||
@ -10,8 +10,10 @@ from celery.signals import (
 | 
				
			|||||||
    task_internal_error,
 | 
					    task_internal_error,
 | 
				
			||||||
    task_postrun,
 | 
					    task_postrun,
 | 
				
			||||||
    task_prerun,
 | 
					    task_prerun,
 | 
				
			||||||
 | 
					    worker_ready,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from django.conf import settings
 | 
					from django.conf import settings
 | 
				
			||||||
 | 
					from django.db import ProgrammingError
 | 
				
			||||||
from structlog.stdlib import get_logger
 | 
					from structlog.stdlib import get_logger
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.core.middleware import LOCAL
 | 
					from authentik.core.middleware import LOCAL
 | 
				
			||||||
@ -74,6 +76,29 @@ def task_error_hook(task_id, exception: Exception, traceback, *args, **kwargs):
 | 
				
			|||||||
        Event.new(EventAction.SYSTEM_EXCEPTION, message=exception_to_string(exception)).save()
 | 
					        Event.new(EventAction.SYSTEM_EXCEPTION, message=exception_to_string(exception)).save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@worker_ready.connect
 | 
				
			||||||
 | 
					def worker_ready_hook(*args, **kwargs):
 | 
				
			||||||
 | 
					    """Run certain tasks on worker start"""
 | 
				
			||||||
 | 
					    from authentik.admin.tasks import clear_update_notifications
 | 
				
			||||||
 | 
					    from authentik.managed.tasks import managed_reconcile
 | 
				
			||||||
 | 
					    from authentik.outposts.tasks import outpost_controller_all, outpost_local_connection
 | 
				
			||||||
 | 
					    from authentik.providers.proxy.tasks import proxy_set_defaults
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    tasks = [
 | 
				
			||||||
 | 
					        clear_update_notifications,
 | 
				
			||||||
 | 
					        outpost_local_connection,
 | 
				
			||||||
 | 
					        outpost_controller_all,
 | 
				
			||||||
 | 
					        proxy_set_defaults,
 | 
				
			||||||
 | 
					        managed_reconcile,
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					    LOGGER.info("Dispatching startup tasks...")
 | 
				
			||||||
 | 
					    for task in tasks:
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            task.delay()
 | 
				
			||||||
 | 
					        except ProgrammingError as exc:
 | 
				
			||||||
 | 
					            LOGGER.warning("Startup task failed", task=task, exc=exc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Using a string here means the worker doesn't have to serialize
 | 
					# Using a string here means the worker doesn't have to serialize
 | 
				
			||||||
# the configuration object to child processes.
 | 
					# the configuration object to child processes.
 | 
				
			||||||
# - namespace='CELERY' means all celery-related configuration keys
 | 
					# - namespace='CELERY' means all celery-related configuration keys
 | 
				
			||||||
 | 
				
			|||||||
@ -1,4 +1,5 @@
 | 
				
			|||||||
"""Dynamically set SameSite depending if the upstream connection is TLS or not"""
 | 
					"""Dynamically set SameSite depending if the upstream connection is TLS or not"""
 | 
				
			||||||
 | 
					from hashlib import sha512
 | 
				
			||||||
from time import time
 | 
					from time import time
 | 
				
			||||||
from typing import Callable
 | 
					from typing import Callable
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -10,11 +11,14 @@ from django.http.request import HttpRequest
 | 
				
			|||||||
from django.http.response import HttpResponse
 | 
					from django.http.response import HttpResponse
 | 
				
			||||||
from django.utils.cache import patch_vary_headers
 | 
					from django.utils.cache import patch_vary_headers
 | 
				
			||||||
from django.utils.http import http_date
 | 
					from django.utils.http import http_date
 | 
				
			||||||
 | 
					from jwt import PyJWTError, decode, encode
 | 
				
			||||||
from structlog.stdlib import get_logger
 | 
					from structlog.stdlib import get_logger
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.lib.utils.http import get_client_ip
 | 
					from authentik.lib.utils.http import get_client_ip
 | 
				
			||||||
 | 
					
 | 
				
			||||||
LOGGER = get_logger("authentik.asgi")
 | 
					LOGGER = get_logger("authentik.asgi")
 | 
				
			||||||
 | 
					ACR_AUTHENTIK_SESSION = "goauthentik.io/core/default"
 | 
				
			||||||
 | 
					SIGNING_HASH = sha512(settings.SECRET_KEY.encode()).hexdigest()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class SessionMiddleware(UpstreamSessionMiddleware):
 | 
					class SessionMiddleware(UpstreamSessionMiddleware):
 | 
				
			||||||
@ -35,6 +39,18 @@ class SessionMiddleware(UpstreamSessionMiddleware):
 | 
				
			|||||||
            return True
 | 
					            return True
 | 
				
			||||||
        return False
 | 
					        return False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def process_request(self, request):
 | 
				
			||||||
 | 
					        session_jwt = request.COOKIES.get(settings.SESSION_COOKIE_NAME)
 | 
				
			||||||
 | 
					        # We need to support the standard django format of just a session key
 | 
				
			||||||
 | 
					        # for testing setups, where the session is directly set
 | 
				
			||||||
 | 
					        session_key = session_jwt if settings.TEST else None
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            session_payload = decode(session_jwt, SIGNING_HASH, algorithms=["HS256"])
 | 
				
			||||||
 | 
					            session_key = session_payload["sid"]
 | 
				
			||||||
 | 
					        except (KeyError, PyJWTError):
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					        request.session = self.SessionStore(session_key)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
 | 
					    def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        If request.session was modified, or if the configuration is to save the
 | 
					        If request.session was modified, or if the configuration is to save the
 | 
				
			||||||
@ -82,9 +98,21 @@ class SessionMiddleware(UpstreamSessionMiddleware):
 | 
				
			|||||||
                            "request completed. The user may have logged "
 | 
					                            "request completed. The user may have logged "
 | 
				
			||||||
                            "out in a concurrent request, for example."
 | 
					                            "out in a concurrent request, for example."
 | 
				
			||||||
                        )
 | 
					                        )
 | 
				
			||||||
 | 
					                    payload = {
 | 
				
			||||||
 | 
					                        "sid": request.session.session_key,
 | 
				
			||||||
 | 
					                        "iss": "authentik",
 | 
				
			||||||
 | 
					                        "sub": "anonymous",
 | 
				
			||||||
 | 
					                        "authenticated": request.user.is_authenticated,
 | 
				
			||||||
 | 
					                        "acr": ACR_AUTHENTIK_SESSION,
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					                    if request.user.is_authenticated:
 | 
				
			||||||
 | 
					                        payload["sub"] = request.user.uid
 | 
				
			||||||
 | 
					                    value = encode(payload=payload, key=SIGNING_HASH)
 | 
				
			||||||
 | 
					                    if settings.TEST:
 | 
				
			||||||
 | 
					                        value = request.session.session_key
 | 
				
			||||||
                    response.set_cookie(
 | 
					                    response.set_cookie(
 | 
				
			||||||
                        settings.SESSION_COOKIE_NAME,
 | 
					                        settings.SESSION_COOKIE_NAME,
 | 
				
			||||||
                        request.session.session_key,
 | 
					                        value,
 | 
				
			||||||
                        max_age=max_age,
 | 
					                        max_age=max_age,
 | 
				
			||||||
                        expires=expires,
 | 
					                        expires=expires,
 | 
				
			||||||
                        domain=settings.SESSION_COOKIE_DOMAIN,
 | 
					                        domain=settings.SESSION_COOKIE_DOMAIN,
 | 
				
			||||||
 | 
				
			|||||||
@ -147,12 +147,12 @@ SPECTACULAR_SETTINGS = {
 | 
				
			|||||||
        },
 | 
					        },
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
    "CONTACT": {
 | 
					    "CONTACT": {
 | 
				
			||||||
        "email": "hello@beryju.org",
 | 
					        "email": "hello@goauthentik.io",
 | 
				
			||||||
    },
 | 
					    },
 | 
				
			||||||
    "AUTHENTICATION_WHITELIST": ["authentik.api.authentication.TokenAuthentication"],
 | 
					    "AUTHENTICATION_WHITELIST": ["authentik.api.authentication.TokenAuthentication"],
 | 
				
			||||||
    "LICENSE": {
 | 
					    "LICENSE": {
 | 
				
			||||||
        "name": "GNU GPLv3",
 | 
					        "name": "GNU GPLv3",
 | 
				
			||||||
        "url": "https://github.com/goauthentik/authentik/blob/master/LICENSE",
 | 
					        "url": "https://github.com/goauthentik/authentik/blob/main/LICENSE",
 | 
				
			||||||
    },
 | 
					    },
 | 
				
			||||||
    "ENUM_NAME_OVERRIDES": {
 | 
					    "ENUM_NAME_OVERRIDES": {
 | 
				
			||||||
        "EventActions": "authentik.events.models.EventAction",
 | 
					        "EventActions": "authentik.events.models.EventAction",
 | 
				
			||||||
@ -217,12 +217,12 @@ DJANGO_REDIS_SCAN_ITERSIZE = 1000
 | 
				
			|||||||
DJANGO_REDIS_IGNORE_EXCEPTIONS = True
 | 
					DJANGO_REDIS_IGNORE_EXCEPTIONS = True
 | 
				
			||||||
DJANGO_REDIS_LOG_IGNORED_EXCEPTIONS = True
 | 
					DJANGO_REDIS_LOG_IGNORED_EXCEPTIONS = True
 | 
				
			||||||
SESSION_ENGINE = "django.contrib.sessions.backends.cache"
 | 
					SESSION_ENGINE = "django.contrib.sessions.backends.cache"
 | 
				
			||||||
 | 
					SESSION_SERIALIZER = "django.contrib.sessions.serializers.PickleSerializer"
 | 
				
			||||||
SESSION_CACHE_ALIAS = "default"
 | 
					SESSION_CACHE_ALIAS = "default"
 | 
				
			||||||
# Configured via custom SessionMiddleware
 | 
					# Configured via custom SessionMiddleware
 | 
				
			||||||
# SESSION_COOKIE_SAMESITE = "None"
 | 
					# SESSION_COOKIE_SAMESITE = "None"
 | 
				
			||||||
# SESSION_COOKIE_SECURE = True
 | 
					# SESSION_COOKIE_SECURE = True
 | 
				
			||||||
SESSION_EXPIRE_AT_BROWSER_CLOSE = True
 | 
					SESSION_EXPIRE_AT_BROWSER_CLOSE = True
 | 
				
			||||||
SESSION_SAVE_EVERY_REQUEST = True
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
MESSAGE_STORAGE = "authentik.root.messages.storage.ChannelsStorage"
 | 
					MESSAGE_STORAGE = "authentik.root.messages.storage.ChannelsStorage"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -408,12 +408,12 @@ LOGGING = {
 | 
				
			|||||||
    "version": 1,
 | 
					    "version": 1,
 | 
				
			||||||
    "disable_existing_loggers": False,
 | 
					    "disable_existing_loggers": False,
 | 
				
			||||||
    "formatters": {
 | 
					    "formatters": {
 | 
				
			||||||
        "plain": {
 | 
					        "json": {
 | 
				
			||||||
            "()": structlog.stdlib.ProcessorFormatter,
 | 
					            "()": structlog.stdlib.ProcessorFormatter,
 | 
				
			||||||
            "processor": structlog.processors.JSONRenderer(sort_keys=True),
 | 
					            "processor": structlog.processors.JSONRenderer(sort_keys=True),
 | 
				
			||||||
            "foreign_pre_chain": LOG_PRE_CHAIN,
 | 
					            "foreign_pre_chain": LOG_PRE_CHAIN,
 | 
				
			||||||
        },
 | 
					        },
 | 
				
			||||||
        "colored": {
 | 
					        "console": {
 | 
				
			||||||
            "()": structlog.stdlib.ProcessorFormatter,
 | 
					            "()": structlog.stdlib.ProcessorFormatter,
 | 
				
			||||||
            "processor": structlog.dev.ConsoleRenderer(colors=DEBUG),
 | 
					            "processor": structlog.dev.ConsoleRenderer(colors=DEBUG),
 | 
				
			||||||
            "foreign_pre_chain": LOG_PRE_CHAIN,
 | 
					            "foreign_pre_chain": LOG_PRE_CHAIN,
 | 
				
			||||||
@ -423,7 +423,7 @@ LOGGING = {
 | 
				
			|||||||
        "console": {
 | 
					        "console": {
 | 
				
			||||||
            "level": "DEBUG",
 | 
					            "level": "DEBUG",
 | 
				
			||||||
            "class": "logging.StreamHandler",
 | 
					            "class": "logging.StreamHandler",
 | 
				
			||||||
            "formatter": "colored" if DEBUG else "plain",
 | 
					            "formatter": "console" if DEBUG else "json",
 | 
				
			||||||
        },
 | 
					        },
 | 
				
			||||||
    },
 | 
					    },
 | 
				
			||||||
    "loggers": {},
 | 
					    "loggers": {},
 | 
				
			||||||
 | 
				
			|||||||
@ -2,6 +2,7 @@
 | 
				
			|||||||
from django.urls.base import reverse_lazy
 | 
					from django.urls.base import reverse_lazy
 | 
				
			||||||
from drf_spectacular.types import OpenApiTypes
 | 
					from drf_spectacular.types import OpenApiTypes
 | 
				
			||||||
from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_field
 | 
					from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_field
 | 
				
			||||||
 | 
					from requests import RequestException
 | 
				
			||||||
from rest_framework.decorators import action
 | 
					from rest_framework.decorators import action
 | 
				
			||||||
from rest_framework.fields import BooleanField, CharField, ChoiceField, SerializerMethodField
 | 
					from rest_framework.fields import BooleanField, CharField, ChoiceField, SerializerMethodField
 | 
				
			||||||
from rest_framework.request import Request
 | 
					from rest_framework.request import Request
 | 
				
			||||||
@ -12,6 +13,7 @@ from rest_framework.viewsets import ModelViewSet
 | 
				
			|||||||
from authentik.core.api.sources import SourceSerializer
 | 
					from authentik.core.api.sources import SourceSerializer
 | 
				
			||||||
from authentik.core.api.used_by import UsedByMixin
 | 
					from authentik.core.api.used_by import UsedByMixin
 | 
				
			||||||
from authentik.core.api.utils import PassiveSerializer
 | 
					from authentik.core.api.utils import PassiveSerializer
 | 
				
			||||||
 | 
					from authentik.lib.utils.http import get_http_session
 | 
				
			||||||
from authentik.sources.oauth.models import OAuthSource
 | 
					from authentik.sources.oauth.models import OAuthSource
 | 
				
			||||||
from authentik.sources.oauth.types.manager import MANAGER, SourceType
 | 
					from authentik.sources.oauth.types.manager import MANAGER, SourceType
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -52,6 +54,33 @@ class OAuthSourceSerializer(SourceSerializer):
 | 
				
			|||||||
        return SourceTypeSerializer(instance.type).data
 | 
					        return SourceTypeSerializer(instance.type).data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def validate(self, attrs: dict) -> dict:
 | 
					    def validate(self, attrs: dict) -> dict:
 | 
				
			||||||
 | 
					        session = get_http_session()
 | 
				
			||||||
 | 
					        well_known = attrs.get("oidc_well_known_url")
 | 
				
			||||||
 | 
					        if well_known and well_known != "":
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                well_known_config = session.get(well_known)
 | 
				
			||||||
 | 
					                well_known_config.raise_for_status()
 | 
				
			||||||
 | 
					            except RequestException as exc:
 | 
				
			||||||
 | 
					                raise ValidationError(exc.response.text)
 | 
				
			||||||
 | 
					            config = well_known_config.json()
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                attrs["authorization_url"] = config["authorization_endpoint"]
 | 
				
			||||||
 | 
					                attrs["access_token_url"] = config["token_endpoint"]
 | 
				
			||||||
 | 
					                attrs["profile_url"] = config["userinfo_endpoint"]
 | 
				
			||||||
 | 
					                attrs["oidc_jwks_url"] = config["jwks_uri"]
 | 
				
			||||||
 | 
					            except (IndexError, KeyError) as exc:
 | 
				
			||||||
 | 
					                raise ValidationError(f"Invalid well-known configuration: {exc}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        jwks_url = attrs.get("oidc_jwks_url")
 | 
				
			||||||
 | 
					        if jwks_url and jwks_url != "":
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                jwks_config = session.get(jwks_url)
 | 
				
			||||||
 | 
					                jwks_config.raise_for_status()
 | 
				
			||||||
 | 
					            except RequestException as exc:
 | 
				
			||||||
 | 
					                raise ValidationError(exc.response.text)
 | 
				
			||||||
 | 
					            config = jwks_config.json()
 | 
				
			||||||
 | 
					            attrs["oidc_jwks"] = config
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        provider_type = MANAGER.find_type(attrs.get("provider_type", ""))
 | 
					        provider_type = MANAGER.find_type(attrs.get("provider_type", ""))
 | 
				
			||||||
        for url in [
 | 
					        for url in [
 | 
				
			||||||
            "authorization_url",
 | 
					            "authorization_url",
 | 
				
			||||||
@ -76,6 +105,9 @@ class OAuthSourceSerializer(SourceSerializer):
 | 
				
			|||||||
            "callback_url",
 | 
					            "callback_url",
 | 
				
			||||||
            "additional_scopes",
 | 
					            "additional_scopes",
 | 
				
			||||||
            "type",
 | 
					            "type",
 | 
				
			||||||
 | 
					            "oidc_well_known_url",
 | 
				
			||||||
 | 
					            "oidc_jwks_url",
 | 
				
			||||||
 | 
					            "oidc_jwks",
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
        extra_kwargs = {"consumer_secret": {"write_only": True}}
 | 
					        extra_kwargs = {"consumer_secret": {"write_only": True}}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -34,6 +34,5 @@ class AuthentikSourceOAuthConfig(AppConfig):
 | 
				
			|||||||
        for source_type in AUTHENTIK_SOURCES_OAUTH_TYPES:
 | 
					        for source_type in AUTHENTIK_SOURCES_OAUTH_TYPES:
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
                import_module(source_type)
 | 
					                import_module(source_type)
 | 
				
			||||||
                LOGGER.debug("Loaded OAuth Source Type", type=source_type)
 | 
					 | 
				
			||||||
            except ImportError as exc:
 | 
					            except ImportError as exc:
 | 
				
			||||||
                LOGGER.debug(str(exc))
 | 
					                LOGGER.warning("Failed to load OAuth Source", exc=exc)
 | 
				
			||||||
 | 
				
			|||||||
@ -11,7 +11,7 @@ from structlog.stdlib import get_logger
 | 
				
			|||||||
from authentik.sources.oauth.clients.base import BaseOAuthClient
 | 
					from authentik.sources.oauth.clients.base import BaseOAuthClient
 | 
				
			||||||
 | 
					
 | 
				
			||||||
LOGGER = get_logger()
 | 
					LOGGER = get_logger()
 | 
				
			||||||
SESSION_OAUTH_PKCE = "oauth_pkce"
 | 
					SESSION_KEY_OAUTH_PKCE = "authentik/sources/oauth/pkce"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class OAuth2Client(BaseOAuthClient):
 | 
					class OAuth2Client(BaseOAuthClient):
 | 
				
			||||||
@ -70,17 +70,14 @@ class OAuth2Client(BaseOAuthClient):
 | 
				
			|||||||
            "code": code,
 | 
					            "code": code,
 | 
				
			||||||
            "grant_type": "authorization_code",
 | 
					            "grant_type": "authorization_code",
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        if SESSION_OAUTH_PKCE in self.request.session:
 | 
					        if SESSION_KEY_OAUTH_PKCE in self.request.session:
 | 
				
			||||||
            args["code_verifier"] = self.request.session[SESSION_OAUTH_PKCE]
 | 
					            args["code_verifier"] = self.request.session[SESSION_KEY_OAUTH_PKCE]
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            access_token_url = self.source.type.access_token_url or ""
 | 
					            access_token_url = self.source.type.access_token_url or ""
 | 
				
			||||||
            if self.source.type.urls_customizable and self.source.access_token_url:
 | 
					            if self.source.type.urls_customizable and self.source.access_token_url:
 | 
				
			||||||
                access_token_url = self.source.access_token_url
 | 
					                access_token_url = self.source.access_token_url
 | 
				
			||||||
            response = self.session.request(
 | 
					            response = self.session.request(
 | 
				
			||||||
                "post",
 | 
					                "post", access_token_url, data=args, headers=self._default_headers, **request_kwargs
 | 
				
			||||||
                access_token_url,
 | 
					 | 
				
			||||||
                data=args,
 | 
					 | 
				
			||||||
                headers=self._default_headers,
 | 
					 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            response.raise_for_status()
 | 
					            response.raise_for_status()
 | 
				
			||||||
        except RequestException as exc:
 | 
					        except RequestException as exc:
 | 
				
			||||||
 | 
				
			|||||||
@ -0,0 +1,28 @@
 | 
				
			|||||||
 | 
					# Generated by Django 4.0.4 on 2022-05-23 20:17
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django.db import migrations, models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Migration(migrations.Migration):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    dependencies = [
 | 
				
			||||||
 | 
					        ("authentik_sources_oauth", "0006_oauthsource_additional_scopes"),
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    operations = [
 | 
				
			||||||
 | 
					        migrations.AddField(
 | 
				
			||||||
 | 
					            model_name="oauthsource",
 | 
				
			||||||
 | 
					            name="oidc_jwks",
 | 
				
			||||||
 | 
					            field=models.JSONField(blank=True, default=dict),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        migrations.AddField(
 | 
				
			||||||
 | 
					            model_name="oauthsource",
 | 
				
			||||||
 | 
					            name="oidc_jwks_url",
 | 
				
			||||||
 | 
					            field=models.TextField(blank=True, default=""),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        migrations.AddField(
 | 
				
			||||||
 | 
					            model_name="oauthsource",
 | 
				
			||||||
 | 
					            name="oidc_well_known_url",
 | 
				
			||||||
 | 
					            field=models.TextField(blank=True, default=""),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
@ -50,6 +50,10 @@ class OAuthSource(Source):
 | 
				
			|||||||
    consumer_key = models.TextField()
 | 
					    consumer_key = models.TextField()
 | 
				
			||||||
    consumer_secret = models.TextField()
 | 
					    consumer_secret = models.TextField()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    oidc_well_known_url = models.TextField(default="", blank=True)
 | 
				
			||||||
 | 
					    oidc_jwks_url = models.TextField(default="", blank=True)
 | 
				
			||||||
 | 
					    oidc_jwks = models.JSONField(default=dict, blank=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def type(self) -> type["SourceType"]:
 | 
					    def type(self) -> type["SourceType"]:
 | 
				
			||||||
        """Return the provider instance for this source"""
 | 
					        """Return the provider instance for this source"""
 | 
				
			||||||
 | 
				
			|||||||
@ -15,12 +15,12 @@ AAD_USER = {
 | 
				
			|||||||
    "displayName": "foo bar",
 | 
					    "displayName": "foo bar",
 | 
				
			||||||
    "givenName": "foo",
 | 
					    "givenName": "foo",
 | 
				
			||||||
    "jobTitle": None,
 | 
					    "jobTitle": None,
 | 
				
			||||||
    "mail": "foo@beryju.org",
 | 
					    "mail": "foo@goauthentik.io",
 | 
				
			||||||
    "mobilePhone": None,
 | 
					    "mobilePhone": None,
 | 
				
			||||||
    "officeLocation": None,
 | 
					    "officeLocation": None,
 | 
				
			||||||
    "preferredLanguage": None,
 | 
					    "preferredLanguage": None,
 | 
				
			||||||
    "surname": "bar",
 | 
					    "surname": "bar",
 | 
				
			||||||
    "userPrincipalName": "foo@beryju.org",
 | 
					    "userPrincipalName": "foo@goauthentik.io",
 | 
				
			||||||
    "id": "018b0aff-8aff-473e-bf9c-b50e27f52208",
 | 
					    "id": "018b0aff-8aff-473e-bf9c-b50e27f52208",
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -1,6 +1,7 @@
 | 
				
			|||||||
"""OAuth Source tests"""
 | 
					"""OAuth Source tests"""
 | 
				
			||||||
from django.test import TestCase
 | 
					from django.test import TestCase
 | 
				
			||||||
from django.urls import reverse
 | 
					from django.urls import reverse
 | 
				
			||||||
 | 
					from requests_mock import Mocker
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.sources.oauth.api.source import OAuthSourceSerializer
 | 
					from authentik.sources.oauth.api.source import OAuthSourceSerializer
 | 
				
			||||||
from authentik.sources.oauth.models import OAuthSource
 | 
					from authentik.sources.oauth.models import OAuthSource
 | 
				
			||||||
@ -29,6 +30,8 @@ class TestOAuthSource(TestCase):
 | 
				
			|||||||
                    "provider_type": "google",
 | 
					                    "provider_type": "google",
 | 
				
			||||||
                    "consumer_key": "foo",
 | 
					                    "consumer_key": "foo",
 | 
				
			||||||
                    "consumer_secret": "foo",
 | 
					                    "consumer_secret": "foo",
 | 
				
			||||||
 | 
					                    "oidc_well_known_url": "",
 | 
				
			||||||
 | 
					                    "oidc_jwks_url": "",
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
            ).is_valid()
 | 
					            ).is_valid()
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
@ -44,6 +47,70 @@ class TestOAuthSource(TestCase):
 | 
				
			|||||||
            ).is_valid()
 | 
					            ).is_valid()
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_api_validate_openid_connect(self):
 | 
				
			||||||
 | 
					        """Test API validation (with OIDC endpoints)"""
 | 
				
			||||||
 | 
					        openid_config = {
 | 
				
			||||||
 | 
					            "authorization_endpoint": "http://mock/oauth/authorize",
 | 
				
			||||||
 | 
					            "token_endpoint": "http://mock/oauth/token",
 | 
				
			||||||
 | 
					            "userinfo_endpoint": "http://mock/oauth/userinfo",
 | 
				
			||||||
 | 
					            "jwks_uri": "http://mock/oauth/discovery/keys",
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        jwks_config = {"keys": []}
 | 
				
			||||||
 | 
					        with Mocker() as mocker:
 | 
				
			||||||
 | 
					            url = "http://mock/.well-known/openid-configuration"
 | 
				
			||||||
 | 
					            mocker.get(url, json=openid_config)
 | 
				
			||||||
 | 
					            mocker.get(openid_config["jwks_uri"], json=jwks_config)
 | 
				
			||||||
 | 
					            serializer = OAuthSourceSerializer(
 | 
				
			||||||
 | 
					                instance=self.source,
 | 
				
			||||||
 | 
					                data={
 | 
				
			||||||
 | 
					                    "name": "foo",
 | 
				
			||||||
 | 
					                    "slug": "bar",
 | 
				
			||||||
 | 
					                    "provider_type": "openidconnect",
 | 
				
			||||||
 | 
					                    "consumer_key": "foo",
 | 
				
			||||||
 | 
					                    "consumer_secret": "foo",
 | 
				
			||||||
 | 
					                    "authorization_url": "http://foo",
 | 
				
			||||||
 | 
					                    "access_token_url": "http://foo",
 | 
				
			||||||
 | 
					                    "profile_url": "http://foo",
 | 
				
			||||||
 | 
					                    "oidc_well_known_url": url,
 | 
				
			||||||
 | 
					                    "oidc_jwks_url": "",
 | 
				
			||||||
 | 
					                },
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            self.assertTrue(serializer.is_valid())
 | 
				
			||||||
 | 
					            self.assertEqual(
 | 
				
			||||||
 | 
					                serializer.validated_data["authorization_url"], "http://mock/oauth/authorize"
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            self.assertEqual(
 | 
				
			||||||
 | 
					                serializer.validated_data["access_token_url"], "http://mock/oauth/token"
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            self.assertEqual(serializer.validated_data["profile_url"], "http://mock/oauth/userinfo")
 | 
				
			||||||
 | 
					            self.assertEqual(
 | 
				
			||||||
 | 
					                serializer.validated_data["oidc_jwks_url"], "http://mock/oauth/discovery/keys"
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            self.assertEqual(serializer.validated_data["oidc_jwks"], jwks_config)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_api_validate_openid_connect_invalid(self):
 | 
				
			||||||
 | 
					        """Test API validation (with OIDC endpoints)"""
 | 
				
			||||||
 | 
					        openid_config = {}
 | 
				
			||||||
 | 
					        with Mocker() as mocker:
 | 
				
			||||||
 | 
					            url = "http://mock/.well-known/openid-configuration"
 | 
				
			||||||
 | 
					            mocker.get(url, json=openid_config)
 | 
				
			||||||
 | 
					            serializer = OAuthSourceSerializer(
 | 
				
			||||||
 | 
					                instance=self.source,
 | 
				
			||||||
 | 
					                data={
 | 
				
			||||||
 | 
					                    "name": "foo",
 | 
				
			||||||
 | 
					                    "slug": "bar",
 | 
				
			||||||
 | 
					                    "provider_type": "openidconnect",
 | 
				
			||||||
 | 
					                    "consumer_key": "foo",
 | 
				
			||||||
 | 
					                    "consumer_secret": "foo",
 | 
				
			||||||
 | 
					                    "authorization_url": "http://foo",
 | 
				
			||||||
 | 
					                    "access_token_url": "http://foo",
 | 
				
			||||||
 | 
					                    "profile_url": "http://foo",
 | 
				
			||||||
 | 
					                    "oidc_well_known_url": url,
 | 
				
			||||||
 | 
					                    "oidc_jwks_url": "",
 | 
				
			||||||
 | 
					                },
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            self.assertFalse(serializer.is_valid())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_source_redirect(self):
 | 
					    def test_source_redirect(self):
 | 
				
			||||||
        """test redirect view"""
 | 
					        """test redirect view"""
 | 
				
			||||||
        self.client.get(
 | 
					        self.client.get(
 | 
				
			||||||
 | 
				
			|||||||
@ -1,22 +1,39 @@
 | 
				
			|||||||
"""Twitter OAuth Views"""
 | 
					"""Twitter OAuth Views"""
 | 
				
			||||||
from typing import Any
 | 
					from typing import Any, Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.lib.generators import generate_id
 | 
					from authentik.lib.generators import generate_id
 | 
				
			||||||
from authentik.sources.oauth.clients.oauth2 import SESSION_OAUTH_PKCE
 | 
					from authentik.sources.oauth.clients.oauth2 import SESSION_KEY_OAUTH_PKCE
 | 
				
			||||||
from authentik.sources.oauth.types.azure_ad import AzureADClient
 | 
					from authentik.sources.oauth.types.azure_ad import AzureADClient
 | 
				
			||||||
from authentik.sources.oauth.types.manager import MANAGER, SourceType
 | 
					from authentik.sources.oauth.types.manager import MANAGER, SourceType
 | 
				
			||||||
from authentik.sources.oauth.views.callback import OAuthCallback
 | 
					from authentik.sources.oauth.views.callback import OAuthCallback
 | 
				
			||||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
 | 
					from authentik.sources.oauth.views.redirect import OAuthRedirect
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TwitterClient(AzureADClient):
 | 
				
			||||||
 | 
					    """Twitter has similar quirks to Azure AD, and additionally requires Basic auth on
 | 
				
			||||||
 | 
					    the access token endpoint for some reason."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Twitter has the same quirk as azure and throws an error if the access token
 | 
				
			||||||
 | 
					    # is set via query parameter, so we re-use the azure client
 | 
				
			||||||
 | 
					    # see https://github.com/goauthentik/authentik/issues/1910
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
 | 
				
			||||||
 | 
					        return super().get_access_token(
 | 
				
			||||||
 | 
					            auth=(
 | 
				
			||||||
 | 
					                self.source.consumer_key,
 | 
				
			||||||
 | 
					                self.source.consumer_secret,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TwitterOAuthRedirect(OAuthRedirect):
 | 
					class TwitterOAuthRedirect(OAuthRedirect):
 | 
				
			||||||
    """Twitter OAuth2 Redirect"""
 | 
					    """Twitter OAuth2 Redirect"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_additional_parameters(self, source):  # pragma: no cover
 | 
					    def get_additional_parameters(self, source):  # pragma: no cover
 | 
				
			||||||
        self.request.session[SESSION_OAUTH_PKCE] = generate_id()
 | 
					        self.request.session[SESSION_KEY_OAUTH_PKCE] = generate_id()
 | 
				
			||||||
        return {
 | 
					        return {
 | 
				
			||||||
            "scope": ["users.read", "tweet.read"],
 | 
					            "scope": ["users.read", "tweet.read"],
 | 
				
			||||||
            "code_challenge": self.request.session[SESSION_OAUTH_PKCE],
 | 
					            "code_challenge": self.request.session[SESSION_KEY_OAUTH_PKCE],
 | 
				
			||||||
            "code_challenge_method": "plain",
 | 
					            "code_challenge_method": "plain",
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -24,10 +41,7 @@ class TwitterOAuthRedirect(OAuthRedirect):
 | 
				
			|||||||
class TwitterOAuthCallback(OAuthCallback):
 | 
					class TwitterOAuthCallback(OAuthCallback):
 | 
				
			||||||
    """Twitter OAuth2 Callback"""
 | 
					    """Twitter OAuth2 Callback"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Twitter has the same quirk as azure and throws an error if the access token
 | 
					    client_class = TwitterClient
 | 
				
			||||||
    # is set via query parameter, so we re-use the azure client
 | 
					 | 
				
			||||||
    # see https://github.com/goauthentik/authentik/issues/1910
 | 
					 | 
				
			||||||
    client_class = AzureADClient
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_user_id(self, info: dict[str, str]) -> str:
 | 
					    def get_user_id(self, info: dict[str, str]) -> str:
 | 
				
			||||||
        return info.get("data", {}).get("id", "")
 | 
					        return info.get("data", {}).get("id", "")
 | 
				
			||||||
 | 
				
			|||||||
@ -11,8 +11,6 @@ from authentik.lib.utils.http import get_http_session
 | 
				
			|||||||
from authentik.sources.plex.models import PlexSource, PlexSourceConnection
 | 
					from authentik.sources.plex.models import PlexSource, PlexSourceConnection
 | 
				
			||||||
 | 
					
 | 
				
			||||||
LOGGER = get_logger()
 | 
					LOGGER = get_logger()
 | 
				
			||||||
SESSION_ID_KEY = "PLEX_ID"
 | 
					 | 
				
			||||||
SESSION_CODE_KEY = "PLEX_CODE"
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class PlexAuth:
 | 
					class PlexAuth:
 | 
				
			||||||
 | 
				
			|||||||
@ -19,7 +19,7 @@ from authentik.sources.saml.processors.constants import (
 | 
				
			|||||||
    SIGN_ALGORITHM_TRANSFORM_MAP,
 | 
					    SIGN_ALGORITHM_TRANSFORM_MAP,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
SESSION_REQUEST_ID = "authentik_source_saml_request_id"
 | 
					SESSION_KEY_REQUEST_ID = "authentik/sources/saml/request_id"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class RequestProcessor:
 | 
					class RequestProcessor:
 | 
				
			||||||
@ -38,7 +38,7 @@ class RequestProcessor:
 | 
				
			|||||||
        self.http_request = request
 | 
					        self.http_request = request
 | 
				
			||||||
        self.relay_state = relay_state
 | 
					        self.relay_state = relay_state
 | 
				
			||||||
        self.request_id = get_random_id()
 | 
					        self.request_id = get_random_id()
 | 
				
			||||||
        self.http_request.session[SESSION_REQUEST_ID] = self.request_id
 | 
					        self.http_request.session[SESSION_KEY_REQUEST_ID] = self.request_id
 | 
				
			||||||
        self.issue_instant = get_time_string()
 | 
					        self.issue_instant = get_time_string()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_issuer(self) -> Element:
 | 
					    def get_issuer(self) -> Element:
 | 
				
			||||||
 | 
				
			|||||||
@ -45,7 +45,7 @@ from authentik.sources.saml.processors.constants import (
 | 
				
			|||||||
    SAML_NAME_ID_FORMAT_WINDOWS,
 | 
					    SAML_NAME_ID_FORMAT_WINDOWS,
 | 
				
			||||||
    SAML_NAME_ID_FORMAT_X509,
 | 
					    SAML_NAME_ID_FORMAT_X509,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from authentik.sources.saml.processors.request import SESSION_REQUEST_ID
 | 
					from authentik.sources.saml.processors.request import SESSION_KEY_REQUEST_ID
 | 
				
			||||||
from authentik.stages.password.stage import PLAN_CONTEXT_AUTHENTICATION_BACKEND
 | 
					from authentik.stages.password.stage import PLAN_CONTEXT_AUTHENTICATION_BACKEND
 | 
				
			||||||
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
 | 
					from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
 | 
				
			||||||
from authentik.stages.user_login.stage import BACKEND_INBUILT
 | 
					from authentik.stages.user_login.stage import BACKEND_INBUILT
 | 
				
			||||||
@ -119,11 +119,11 @@ class ResponseProcessor:
 | 
				
			|||||||
            seen_ids.append(self._root.attrib["ID"])
 | 
					            seen_ids.append(self._root.attrib["ID"])
 | 
				
			||||||
            cache.set(CACHE_SEEN_REQUEST_ID % self._source.pk, seen_ids)
 | 
					            cache.set(CACHE_SEEN_REQUEST_ID % self._source.pk, seen_ids)
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
        if SESSION_REQUEST_ID not in request.session or "InResponseTo" not in self._root.attrib:
 | 
					        if SESSION_KEY_REQUEST_ID not in request.session or "InResponseTo" not in self._root.attrib:
 | 
				
			||||||
            raise MismatchedRequestID(
 | 
					            raise MismatchedRequestID(
 | 
				
			||||||
                "Missing InResponseTo and IdP-initiated Logins are not allowed"
 | 
					                "Missing InResponseTo and IdP-initiated Logins are not allowed"
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
        if request.session[SESSION_REQUEST_ID] != self._root.attrib["InResponseTo"]:
 | 
					        if request.session[SESSION_KEY_REQUEST_ID] != self._root.attrib["InResponseTo"]:
 | 
				
			||||||
            raise MismatchedRequestID("Mismatched request ID")
 | 
					            raise MismatchedRequestID("Mismatched request ID")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _handle_name_id_transient(self, request: HttpRequest) -> HttpResponse:
 | 
					    def _handle_name_id_transient(self, request: HttpRequest) -> HttpResponse:
 | 
				
			||||||
 | 
				
			|||||||
@ -4,6 +4,7 @@ from django.test import RequestFactory, TestCase
 | 
				
			|||||||
from lxml import etree  # nosec
 | 
					from lxml import etree  # nosec
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.core.tests.utils import create_test_cert, create_test_flow
 | 
					from authentik.core.tests.utils import create_test_cert, create_test_flow
 | 
				
			||||||
 | 
					from authentik.lib.xml import lxml_from_string
 | 
				
			||||||
from authentik.sources.saml.models import SAMLSource
 | 
					from authentik.sources.saml.models import SAMLSource
 | 
				
			||||||
from authentik.sources.saml.processors.metadata import MetadataProcessor
 | 
					from authentik.sources.saml.processors.metadata import MetadataProcessor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -24,7 +25,7 @@ class TestMetadataProcessor(TestCase):
 | 
				
			|||||||
        )
 | 
					        )
 | 
				
			||||||
        request = self.factory.get("/")
 | 
					        request = self.factory.get("/")
 | 
				
			||||||
        xml = MetadataProcessor(source, request).build_entity_descriptor()
 | 
					        xml = MetadataProcessor(source, request).build_entity_descriptor()
 | 
				
			||||||
        metadata = etree.fromstring(xml)  # nosec
 | 
					        metadata = lxml_from_string(xml)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        schema = etree.XMLSchema(etree.parse("xml/saml-schema-metadata-2.0.xsd"))  # nosec
 | 
					        schema = etree.XMLSchema(etree.parse("xml/saml-schema-metadata-2.0.xsd"))  # nosec
 | 
				
			||||||
        self.assertTrue(schema.validate(metadata))
 | 
					        self.assertTrue(schema.validate(metadata))
 | 
				
			||||||
 | 
				
			|||||||
@ -2,7 +2,6 @@
 | 
				
			|||||||
from django.http import HttpRequest, HttpResponse
 | 
					from django.http import HttpRequest, HttpResponse
 | 
				
			||||||
from django.utils.timezone import now
 | 
					from django.utils.timezone import now
 | 
				
			||||||
from rest_framework.fields import CharField
 | 
					from rest_framework.fields import CharField
 | 
				
			||||||
from structlog.stdlib import get_logger
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.events.models import Event, EventAction
 | 
					from authentik.events.models import Event, EventAction
 | 
				
			||||||
from authentik.flows.challenge import (
 | 
					from authentik.flows.challenge import (
 | 
				
			||||||
@ -16,10 +15,8 @@ from authentik.flows.stage import ChallengeStageView
 | 
				
			|||||||
from authentik.flows.views.executor import InvalidStageError
 | 
					from authentik.flows.views.executor import InvalidStageError
 | 
				
			||||||
from authentik.stages.authenticator_duo.models import AuthenticatorDuoStage, DuoDevice
 | 
					from authentik.stages.authenticator_duo.models import AuthenticatorDuoStage, DuoDevice
 | 
				
			||||||
 | 
					
 | 
				
			||||||
LOGGER = get_logger()
 | 
					SESSION_KEY_DUO_USER_ID = "authentik/stages/authenticator_duo/user_id"
 | 
				
			||||||
 | 
					SESSION_KEY_DUO_ACTIVATION_CODE = "authentik/stages/authenticator_duo/activation_code"
 | 
				
			||||||
SESSION_KEY_DUO_USER_ID = "authentik_stages_authenticator_duo_user_id"
 | 
					 | 
				
			||||||
SESSION_KEY_DUO_ACTIVATION_CODE = "authentik_stages_authenticator_duo_activation_code"
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class AuthenticatorDuoChallenge(WithUserInfoChallenge):
 | 
					class AuthenticatorDuoChallenge(WithUserInfoChallenge):
 | 
				
			||||||
@ -69,7 +66,7 @@ class AuthenticatorDuoStageView(ChallengeStageView):
 | 
				
			|||||||
    def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
 | 
					    def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
 | 
				
			||||||
        user = self.executor.plan.context.get(PLAN_CONTEXT_PENDING_USER)
 | 
					        user = self.executor.plan.context.get(PLAN_CONTEXT_PENDING_USER)
 | 
				
			||||||
        if not user:
 | 
					        if not user:
 | 
				
			||||||
            LOGGER.debug("No pending user, continuing")
 | 
					            self.logger.debug("No pending user, continuing")
 | 
				
			||||||
            return self.executor.stage_ok()
 | 
					            return self.executor.stage_ok()
 | 
				
			||||||
        return super().get(request, *args, **kwargs)
 | 
					        return super().get(request, *args, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -95,3 +92,7 @@ class AuthenticatorDuoStageView(ChallengeStageView):
 | 
				
			|||||||
        else:
 | 
					        else:
 | 
				
			||||||
            return self.executor.stage_invalid("Device with Credential ID already exists.")
 | 
					            return self.executor.stage_invalid("Device with Credential ID already exists.")
 | 
				
			||||||
        return self.executor.stage_ok()
 | 
					        return self.executor.stage_ok()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def cleanup(self):
 | 
				
			||||||
 | 
					        self.request.session.pop(SESSION_KEY_DUO_USER_ID)
 | 
				
			||||||
 | 
					        self.request.session.pop(SESSION_KEY_DUO_ACTIVATION_CODE)
 | 
				
			||||||
 | 
				
			|||||||
@ -26,6 +26,7 @@ class AuthenticatorSMSStageSerializer(StageSerializer):
 | 
				
			|||||||
            "auth",
 | 
					            "auth",
 | 
				
			||||||
            "auth_password",
 | 
					            "auth_password",
 | 
				
			||||||
            "auth_type",
 | 
					            "auth_type",
 | 
				
			||||||
 | 
					            "verify_only",
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -0,0 +1,25 @@
 | 
				
			|||||||
 | 
					# Generated by Django 4.0.4 on 2022-05-24 19:08
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django.db import migrations, models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Migration(migrations.Migration):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    dependencies = [
 | 
				
			||||||
 | 
					        ("authentik_stages_authenticator_sms", "0003_smsdevice_last_used_on"),
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    operations = [
 | 
				
			||||||
 | 
					        migrations.AddField(
 | 
				
			||||||
 | 
					            model_name="authenticatorsmsstage",
 | 
				
			||||||
 | 
					            name="verify_only",
 | 
				
			||||||
 | 
					            field=models.BooleanField(
 | 
				
			||||||
 | 
					                default=False,
 | 
				
			||||||
 | 
					                help_text="When enabled, the Phone number is only used during enrollment to verify the users authenticity. Only a hash of the phone number is saved to ensure it is not re-used in the future.",
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        migrations.AlterUniqueTogether(
 | 
				
			||||||
 | 
					            name="smsdevice",
 | 
				
			||||||
 | 
					            unique_together={("stage", "phone_number")},
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
@ -1,4 +1,5 @@
 | 
				
			|||||||
"""OTP Time-based models"""
 | 
					"""SMS Authenticator models"""
 | 
				
			||||||
 | 
					from hashlib import sha256
 | 
				
			||||||
from typing import Optional
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.contrib.auth import get_user_model
 | 
					from django.contrib.auth import get_user_model
 | 
				
			||||||
@ -46,6 +47,15 @@ class AuthenticatorSMSStage(ConfigurableStage, Stage):
 | 
				
			|||||||
    auth_password = models.TextField(default="", blank=True)
 | 
					    auth_password = models.TextField(default="", blank=True)
 | 
				
			||||||
    auth_type = models.TextField(choices=SMSAuthTypes.choices, default=SMSAuthTypes.BASIC)
 | 
					    auth_type = models.TextField(choices=SMSAuthTypes.choices, default=SMSAuthTypes.BASIC)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    verify_only = models.BooleanField(
 | 
				
			||||||
 | 
					        default=False,
 | 
				
			||||||
 | 
					        help_text=_(
 | 
				
			||||||
 | 
					            "When enabled, the Phone number is only used during enrollment to verify the "
 | 
				
			||||||
 | 
					            "users authenticity. Only a hash of the phone number is saved to ensure it is "
 | 
				
			||||||
 | 
					            "not re-used in the future."
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def send(self, token: str, device: "SMSDevice"):
 | 
					    def send(self, token: str, device: "SMSDevice"):
 | 
				
			||||||
        """Send message via selected provider"""
 | 
					        """Send message via selected provider"""
 | 
				
			||||||
        if self.provider == SMSProviders.TWILIO:
 | 
					        if self.provider == SMSProviders.TWILIO:
 | 
				
			||||||
@ -158,6 +168,11 @@ class AuthenticatorSMSStage(ConfigurableStage, Stage):
 | 
				
			|||||||
        verbose_name_plural = _("SMS Authenticator Setup Stages")
 | 
					        verbose_name_plural = _("SMS Authenticator Setup Stages")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def hash_phone_number(phone_number: str) -> str:
 | 
				
			||||||
 | 
					    """Hash phone number with prefix"""
 | 
				
			||||||
 | 
					    return "hash:" + sha256(phone_number.encode()).hexdigest()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class SMSDevice(SideChannelDevice):
 | 
					class SMSDevice(SideChannelDevice):
 | 
				
			||||||
    """SMS Device"""
 | 
					    """SMS Device"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -170,6 +185,15 @@ class SMSDevice(SideChannelDevice):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    last_t = models.DateTimeField(auto_now=True)
 | 
					    last_t = models.DateTimeField(auto_now=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def set_hashed_number(self):
 | 
				
			||||||
 | 
					        """Set phone_number to hashed number"""
 | 
				
			||||||
 | 
					        self.phone_number = hash_phone_number(self.phone_number)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def is_hashed(self) -> bool:
 | 
				
			||||||
 | 
					        """Check if the phone number is hashed"""
 | 
				
			||||||
 | 
					        return self.phone_number.startswith("hash:")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def verify_token(self, token):
 | 
					    def verify_token(self, token):
 | 
				
			||||||
        valid = super().verify_token(token)
 | 
					        valid = super().verify_token(token)
 | 
				
			||||||
        if valid:
 | 
					        if valid:
 | 
				
			||||||
@ -182,3 +206,4 @@ class SMSDevice(SideChannelDevice):
 | 
				
			|||||||
    class Meta:
 | 
					    class Meta:
 | 
				
			||||||
        verbose_name = _("SMS Device")
 | 
					        verbose_name = _("SMS Device")
 | 
				
			||||||
        verbose_name_plural = _("SMS Devices")
 | 
					        verbose_name_plural = _("SMS Devices")
 | 
				
			||||||
 | 
					        unique_together = (("stage", "phone_number"),)
 | 
				
			||||||
 | 
				
			|||||||
@ -1,12 +1,12 @@
 | 
				
			|||||||
"""SMS Setup stage"""
 | 
					"""SMS Setup stage"""
 | 
				
			||||||
from typing import Optional
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django.db.models import Q
 | 
				
			||||||
from django.http import HttpRequest, HttpResponse
 | 
					from django.http import HttpRequest, HttpResponse
 | 
				
			||||||
from django.http.request import QueryDict
 | 
					from django.http.request import QueryDict
 | 
				
			||||||
from django.utils.translation import gettext_lazy as _
 | 
					from django.utils.translation import gettext_lazy as _
 | 
				
			||||||
from rest_framework.exceptions import ValidationError
 | 
					from rest_framework.exceptions import ValidationError
 | 
				
			||||||
from rest_framework.fields import BooleanField, CharField, IntegerField
 | 
					from rest_framework.fields import BooleanField, CharField, IntegerField
 | 
				
			||||||
from structlog.stdlib import get_logger
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.flows.challenge import (
 | 
					from authentik.flows.challenge import (
 | 
				
			||||||
    Challenge,
 | 
					    Challenge,
 | 
				
			||||||
@ -16,11 +16,14 @@ from authentik.flows.challenge import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER
 | 
					from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER
 | 
				
			||||||
from authentik.flows.stage import ChallengeStageView
 | 
					from authentik.flows.stage import ChallengeStageView
 | 
				
			||||||
from authentik.stages.authenticator_sms.models import AuthenticatorSMSStage, SMSDevice
 | 
					from authentik.stages.authenticator_sms.models import (
 | 
				
			||||||
 | 
					    AuthenticatorSMSStage,
 | 
				
			||||||
 | 
					    SMSDevice,
 | 
				
			||||||
 | 
					    hash_phone_number,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
 | 
					from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
 | 
				
			||||||
 | 
					
 | 
				
			||||||
LOGGER = get_logger()
 | 
					SESSION_KEY_SMS_DEVICE = "authentik/stages/authenticator_sms/sms_device"
 | 
				
			||||||
SESSION_SMS_DEVICE = "sms_device"
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class AuthenticatorSMSChallenge(WithUserInfoChallenge):
 | 
					class AuthenticatorSMSChallenge(WithUserInfoChallenge):
 | 
				
			||||||
@ -47,6 +50,10 @@ class AuthenticatorSMSChallengeResponse(ChallengeResponse):
 | 
				
			|||||||
        stage: AuthenticatorSMSStage = self.device.stage
 | 
					        stage: AuthenticatorSMSStage = self.device.stage
 | 
				
			||||||
        if "code" not in attrs:
 | 
					        if "code" not in attrs:
 | 
				
			||||||
            self.device.phone_number = attrs["phone_number"]
 | 
					            self.device.phone_number = attrs["phone_number"]
 | 
				
			||||||
 | 
					            hashed_number = hash_phone_number(self.device.phone_number)
 | 
				
			||||||
 | 
					            query = Q(phone_number=hashed_number) | Q(phone_number=self.device.phone_number)
 | 
				
			||||||
 | 
					            if SMSDevice.objects.filter(query, stage=self.stage.executor.current_stage.pk).exists():
 | 
				
			||||||
 | 
					                raise ValidationError(_("Invalid phone number"))
 | 
				
			||||||
            # No code yet, but we have a phone number, so send a verification message
 | 
					            # No code yet, but we have a phone number, so send a verification message
 | 
				
			||||||
            stage.send(self.device.token, self.device)
 | 
					            stage.send(self.device.token, self.device)
 | 
				
			||||||
            return super().validate(attrs)
 | 
					            return super().validate(attrs)
 | 
				
			||||||
@ -64,11 +71,11 @@ class AuthenticatorSMSStageView(ChallengeStageView):
 | 
				
			|||||||
    def _has_phone_number(self) -> Optional[str]:
 | 
					    def _has_phone_number(self) -> Optional[str]:
 | 
				
			||||||
        context = self.executor.plan.context
 | 
					        context = self.executor.plan.context
 | 
				
			||||||
        if "phone" in context.get(PLAN_CONTEXT_PROMPT, {}):
 | 
					        if "phone" in context.get(PLAN_CONTEXT_PROMPT, {}):
 | 
				
			||||||
            LOGGER.debug("got phone number from plan context")
 | 
					            self.logger.debug("got phone number from plan context")
 | 
				
			||||||
            return context.get(PLAN_CONTEXT_PROMPT, {}).get("phone")
 | 
					            return context.get(PLAN_CONTEXT_PROMPT, {}).get("phone")
 | 
				
			||||||
        if SESSION_SMS_DEVICE in self.request.session:
 | 
					        if SESSION_KEY_SMS_DEVICE in self.request.session:
 | 
				
			||||||
            LOGGER.debug("got phone number from device in session")
 | 
					            self.logger.debug("got phone number from device in session")
 | 
				
			||||||
            device: SMSDevice = self.request.session[SESSION_SMS_DEVICE]
 | 
					            device: SMSDevice = self.request.session[SESSION_KEY_SMS_DEVICE]
 | 
				
			||||||
            if device.phone_number == "":
 | 
					            if device.phone_number == "":
 | 
				
			||||||
                return None
 | 
					                return None
 | 
				
			||||||
            return device.phone_number
 | 
					            return device.phone_number
 | 
				
			||||||
@ -84,13 +91,13 @@ class AuthenticatorSMSStageView(ChallengeStageView):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def get_response_instance(self, data: QueryDict) -> ChallengeResponse:
 | 
					    def get_response_instance(self, data: QueryDict) -> ChallengeResponse:
 | 
				
			||||||
        response = super().get_response_instance(data)
 | 
					        response = super().get_response_instance(data)
 | 
				
			||||||
        response.device = self.request.session[SESSION_SMS_DEVICE]
 | 
					        response.device = self.request.session[SESSION_KEY_SMS_DEVICE]
 | 
				
			||||||
        return response
 | 
					        return response
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
 | 
					    def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
 | 
				
			||||||
        user = self.executor.plan.context.get(PLAN_CONTEXT_PENDING_USER)
 | 
					        user = self.executor.plan.context.get(PLAN_CONTEXT_PENDING_USER)
 | 
				
			||||||
        if not user:
 | 
					        if not user:
 | 
				
			||||||
            LOGGER.debug("No pending user, continuing")
 | 
					            self.logger.debug("No pending user, continuing")
 | 
				
			||||||
            return self.executor.stage_ok()
 | 
					            return self.executor.stage_ok()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Currently, this stage only supports one device per user. If the user already
 | 
					        # Currently, this stage only supports one device per user. If the user already
 | 
				
			||||||
@ -100,19 +107,23 @@ class AuthenticatorSMSStageView(ChallengeStageView):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        stage: AuthenticatorSMSStage = self.executor.current_stage
 | 
					        stage: AuthenticatorSMSStage = self.executor.current_stage
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if SESSION_SMS_DEVICE not in self.request.session:
 | 
					        if SESSION_KEY_SMS_DEVICE not in self.request.session:
 | 
				
			||||||
            device = SMSDevice(user=user, confirmed=False, stage=stage, name="SMS Device")
 | 
					            device = SMSDevice(user=user, confirmed=False, stage=stage, name="SMS Device")
 | 
				
			||||||
            device.generate_token(commit=False)
 | 
					            device.generate_token(commit=False)
 | 
				
			||||||
            if phone_number := self._has_phone_number():
 | 
					            if phone_number := self._has_phone_number():
 | 
				
			||||||
                device.phone_number = phone_number
 | 
					                device.phone_number = phone_number
 | 
				
			||||||
            self.request.session[SESSION_SMS_DEVICE] = device
 | 
					            self.request.session[SESSION_KEY_SMS_DEVICE] = device
 | 
				
			||||||
        return super().get(request, *args, **kwargs)
 | 
					        return super().get(request, *args, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:
 | 
					    def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:
 | 
				
			||||||
        """SMS Token is validated by challenge"""
 | 
					        """SMS Token is validated by challenge"""
 | 
				
			||||||
        device: SMSDevice = self.request.session[SESSION_SMS_DEVICE]
 | 
					        device: SMSDevice = self.request.session[SESSION_KEY_SMS_DEVICE]
 | 
				
			||||||
        if not device.confirmed:
 | 
					        if not device.confirmed:
 | 
				
			||||||
            return self.challenge_invalid(response)
 | 
					            return self.challenge_invalid(response)
 | 
				
			||||||
 | 
					        stage: AuthenticatorSMSStage = self.executor.current_stage
 | 
				
			||||||
 | 
					        if stage.verify_only:
 | 
				
			||||||
 | 
					            self.logger.debug("Hashing number on device")
 | 
				
			||||||
 | 
					            device.set_hashed_number()
 | 
				
			||||||
        device.save()
 | 
					        device.save()
 | 
				
			||||||
        del self.request.session[SESSION_SMS_DEVICE]
 | 
					        del self.request.session[SESSION_KEY_SMS_DEVICE]
 | 
				
			||||||
        return self.executor.stage_ok()
 | 
					        return self.executor.stage_ok()
 | 
				
			||||||
 | 
				
			|||||||
@ -2,32 +2,31 @@
 | 
				
			|||||||
from unittest.mock import MagicMock, patch
 | 
					from unittest.mock import MagicMock, patch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.urls import reverse
 | 
					from django.urls import reverse
 | 
				
			||||||
from rest_framework.test import APITestCase
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.core.models import User
 | 
					from authentik.core.tests.utils import create_test_admin_user, create_test_flow
 | 
				
			||||||
from authentik.flows.challenge import ChallengeTypes
 | 
					from authentik.flows.models import FlowStageBinding
 | 
				
			||||||
from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding
 | 
					from authentik.flows.tests import FlowTestCase
 | 
				
			||||||
from authentik.stages.authenticator_sms.models import AuthenticatorSMSStage, SMSProviders
 | 
					from authentik.stages.authenticator_sms.models import (
 | 
				
			||||||
from authentik.stages.authenticator_sms.stage import SESSION_SMS_DEVICE
 | 
					    AuthenticatorSMSStage,
 | 
				
			||||||
 | 
					    SMSDevice,
 | 
				
			||||||
 | 
					    SMSProviders,
 | 
				
			||||||
 | 
					    hash_phone_number,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class AuthenticatorSMSStageTests(APITestCase):
 | 
					class AuthenticatorSMSStageTests(FlowTestCase):
 | 
				
			||||||
    """Test SMS API"""
 | 
					    """Test SMS API"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def setUp(self) -> None:
 | 
					    def setUp(self) -> None:
 | 
				
			||||||
        super().setUp()
 | 
					        super().setUp()
 | 
				
			||||||
        self.flow = Flow.objects.create(
 | 
					        self.flow = create_test_flow()
 | 
				
			||||||
            name="foo",
 | 
					        self.stage: AuthenticatorSMSStage = AuthenticatorSMSStage.objects.create(
 | 
				
			||||||
            slug="foo",
 | 
					 | 
				
			||||||
            designation=FlowDesignation.STAGE_CONFIGURATION,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.stage = AuthenticatorSMSStage.objects.create(
 | 
					 | 
				
			||||||
            name="foo",
 | 
					            name="foo",
 | 
				
			||||||
            provider=SMSProviders.TWILIO,
 | 
					            provider=SMSProviders.TWILIO,
 | 
				
			||||||
            configure_flow=self.flow,
 | 
					            configure_flow=self.flow,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        FlowStageBinding.objects.create(target=self.flow, stage=self.stage, order=0)
 | 
					        FlowStageBinding.objects.create(target=self.flow, stage=self.stage, order=0)
 | 
				
			||||||
        self.user = User.objects.create(username="foo")
 | 
					        self.user = create_test_admin_user()
 | 
				
			||||||
        self.client.force_login(self.user)
 | 
					        self.client.force_login(self.user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_stage_no_prefill(self):
 | 
					    def test_stage_no_prefill(self):
 | 
				
			||||||
@ -38,27 +37,29 @@ class AuthenticatorSMSStageTests(APITestCase):
 | 
				
			|||||||
        response = self.client.get(
 | 
					        response = self.client.get(
 | 
				
			||||||
            reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
 | 
					            reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        self.assertJSONEqual(
 | 
					        self.assertStageResponse(
 | 
				
			||||||
            response.content,
 | 
					            response,
 | 
				
			||||||
            {
 | 
					            self.flow,
 | 
				
			||||||
                "component": "ak-stage-authenticator-sms",
 | 
					            self.user,
 | 
				
			||||||
                "flow_info": {
 | 
					            component="ak-stage-authenticator-sms",
 | 
				
			||||||
                    "background": self.flow.background_url,
 | 
					            phone_number_required=True,
 | 
				
			||||||
                    "cancel_url": reverse("authentik_flows:cancel"),
 | 
					 | 
				
			||||||
                    "title": "",
 | 
					 | 
				
			||||||
                    "layout": "stacked",
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
                "pending_user": "foo",
 | 
					 | 
				
			||||||
                "pending_user_avatar": "/static/dist/assets/images/user_default.png",
 | 
					 | 
				
			||||||
                "phone_number_required": True,
 | 
					 | 
				
			||||||
                "type": ChallengeTypes.NATIVE.value,
 | 
					 | 
				
			||||||
            },
 | 
					 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_stage_submit(self):
 | 
					    def test_stage_submit(self):
 | 
				
			||||||
        """test stage (submit)"""
 | 
					        """test stage (submit)"""
 | 
				
			||||||
        # Prepares session etc
 | 
					        self.client.get(
 | 
				
			||||||
        self.test_stage_no_prefill()
 | 
					            reverse("authentik_flows:configure", kwargs={"stage_uuid": self.stage.stage_uuid}),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        response = self.client.get(
 | 
				
			||||||
 | 
					            reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertStageResponse(
 | 
				
			||||||
 | 
					            response,
 | 
				
			||||||
 | 
					            self.flow,
 | 
				
			||||||
 | 
					            self.user,
 | 
				
			||||||
 | 
					            component="ak-stage-authenticator-sms",
 | 
				
			||||||
 | 
					            phone_number_required=True,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        sms_send_mock = MagicMock()
 | 
					        sms_send_mock = MagicMock()
 | 
				
			||||||
        with patch(
 | 
					        with patch(
 | 
				
			||||||
            "authentik.stages.authenticator_sms.models.AuthenticatorSMSStage.send",
 | 
					            "authentik.stages.authenticator_sms.models.AuthenticatorSMSStage.send",
 | 
				
			||||||
@ -70,23 +71,156 @@ class AuthenticatorSMSStageTests(APITestCase):
 | 
				
			|||||||
            )
 | 
					            )
 | 
				
			||||||
            self.assertEqual(response.status_code, 200)
 | 
					            self.assertEqual(response.status_code, 200)
 | 
				
			||||||
            sms_send_mock.assert_called_once()
 | 
					            sms_send_mock.assert_called_once()
 | 
				
			||||||
 | 
					        self.assertStageResponse(
 | 
				
			||||||
 | 
					            response,
 | 
				
			||||||
 | 
					            self.flow,
 | 
				
			||||||
 | 
					            self.user,
 | 
				
			||||||
 | 
					            component="ak-stage-authenticator-sms",
 | 
				
			||||||
 | 
					            response_errors={},
 | 
				
			||||||
 | 
					            phone_number_required=False,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_stage_submit_full(self):
 | 
					    def test_stage_submit_full(self):
 | 
				
			||||||
        """test stage (submit)"""
 | 
					        """test stage (submit)"""
 | 
				
			||||||
        # Prepares session etc
 | 
					        self.client.get(
 | 
				
			||||||
        self.test_stage_submit()
 | 
					            reverse("authentik_flows:configure", kwargs={"stage_uuid": self.stage.stage_uuid}),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        response = self.client.get(
 | 
				
			||||||
 | 
					            reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertStageResponse(
 | 
				
			||||||
 | 
					            response,
 | 
				
			||||||
 | 
					            self.flow,
 | 
				
			||||||
 | 
					            self.user,
 | 
				
			||||||
 | 
					            component="ak-stage-authenticator-sms",
 | 
				
			||||||
 | 
					            phone_number_required=True,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        sms_send_mock = MagicMock()
 | 
					        sms_send_mock = MagicMock()
 | 
				
			||||||
        with patch(
 | 
					        with patch(
 | 
				
			||||||
            "authentik.stages.authenticator_sms.models.AuthenticatorSMSStage.send",
 | 
					            "authentik.stages.authenticator_sms.models.AuthenticatorSMSStage.send",
 | 
				
			||||||
            sms_send_mock,
 | 
					            sms_send_mock,
 | 
				
			||||||
 | 
					        ):
 | 
				
			||||||
 | 
					            response = self.client.post(
 | 
				
			||||||
 | 
					                reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
 | 
				
			||||||
 | 
					                data={"component": "ak-stage-authenticator-sms", "phone_number": "foo"},
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            self.assertEqual(response.status_code, 200)
 | 
				
			||||||
 | 
					            sms_send_mock.assert_called_once()
 | 
				
			||||||
 | 
					        self.assertStageResponse(
 | 
				
			||||||
 | 
					            response,
 | 
				
			||||||
 | 
					            self.flow,
 | 
				
			||||||
 | 
					            self.user,
 | 
				
			||||||
 | 
					            component="ak-stage-authenticator-sms",
 | 
				
			||||||
 | 
					            response_errors={},
 | 
				
			||||||
 | 
					            phone_number_required=False,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        with patch(
 | 
				
			||||||
 | 
					            "authentik.stages.authenticator_sms.models.SMSDevice.verify_token",
 | 
				
			||||||
 | 
					            MagicMock(return_value=True),
 | 
				
			||||||
        ):
 | 
					        ):
 | 
				
			||||||
            response = self.client.post(
 | 
					            response = self.client.post(
 | 
				
			||||||
                reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
 | 
					                reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
 | 
				
			||||||
                data={
 | 
					                data={
 | 
				
			||||||
                    "component": "ak-stage-authenticator-sms",
 | 
					                    "component": "ak-stage-authenticator-sms",
 | 
				
			||||||
                    "phone_number": "foo",
 | 
					                    "phone_number": "foo",
 | 
				
			||||||
                    "code": int(self.client.session[SESSION_SMS_DEVICE].token),
 | 
					                    "code": "123456",
 | 
				
			||||||
                },
 | 
					                },
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					        self.assertEqual(response.status_code, 200)
 | 
				
			||||||
 | 
					        self.assertStageRedirects(response, reverse("authentik_core:root-redirect"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_stage_hash(self):
 | 
				
			||||||
 | 
					        """test stage (verify_only)"""
 | 
				
			||||||
 | 
					        self.stage.verify_only = True
 | 
				
			||||||
 | 
					        self.stage.save()
 | 
				
			||||||
 | 
					        self.client.get(
 | 
				
			||||||
 | 
					            reverse("authentik_flows:configure", kwargs={"stage_uuid": self.stage.stage_uuid}),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        response = self.client.get(
 | 
				
			||||||
 | 
					            reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertStageResponse(
 | 
				
			||||||
 | 
					            response,
 | 
				
			||||||
 | 
					            self.flow,
 | 
				
			||||||
 | 
					            self.user,
 | 
				
			||||||
 | 
					            component="ak-stage-authenticator-sms",
 | 
				
			||||||
 | 
					            phone_number_required=True,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        sms_send_mock = MagicMock()
 | 
				
			||||||
 | 
					        with patch(
 | 
				
			||||||
 | 
					            "authentik.stages.authenticator_sms.models.AuthenticatorSMSStage.send",
 | 
				
			||||||
 | 
					            sms_send_mock,
 | 
				
			||||||
 | 
					        ):
 | 
				
			||||||
 | 
					            response = self.client.post(
 | 
				
			||||||
 | 
					                reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
 | 
				
			||||||
 | 
					                data={"component": "ak-stage-authenticator-sms", "phone_number": "foo"},
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
            self.assertEqual(response.status_code, 200)
 | 
					            self.assertEqual(response.status_code, 200)
 | 
				
			||||||
            sms_send_mock.assert_not_called()
 | 
					            sms_send_mock.assert_called_once()
 | 
				
			||||||
 | 
					        self.assertStageResponse(
 | 
				
			||||||
 | 
					            response,
 | 
				
			||||||
 | 
					            self.flow,
 | 
				
			||||||
 | 
					            self.user,
 | 
				
			||||||
 | 
					            component="ak-stage-authenticator-sms",
 | 
				
			||||||
 | 
					            response_errors={},
 | 
				
			||||||
 | 
					            phone_number_required=False,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        with patch(
 | 
				
			||||||
 | 
					            "authentik.stages.authenticator_sms.models.SMSDevice.verify_token",
 | 
				
			||||||
 | 
					            MagicMock(return_value=True),
 | 
				
			||||||
 | 
					        ):
 | 
				
			||||||
 | 
					            response = self.client.post(
 | 
				
			||||||
 | 
					                reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
 | 
				
			||||||
 | 
					                data={
 | 
				
			||||||
 | 
					                    "component": "ak-stage-authenticator-sms",
 | 
				
			||||||
 | 
					                    "phone_number": "foo",
 | 
				
			||||||
 | 
					                    "code": "123456",
 | 
				
			||||||
 | 
					                },
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        self.assertEqual(response.status_code, 200)
 | 
				
			||||||
 | 
					        self.assertStageRedirects(response, reverse("authentik_core:root-redirect"))
 | 
				
			||||||
 | 
					        device: SMSDevice = SMSDevice.objects.filter(user=self.user).first()
 | 
				
			||||||
 | 
					        self.assertTrue(device.is_hashed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_stage_hash_twice(self):
 | 
				
			||||||
 | 
					        """test stage (hash + duplicate)"""
 | 
				
			||||||
 | 
					        SMSDevice.objects.create(
 | 
				
			||||||
 | 
					            user=create_test_admin_user(),
 | 
				
			||||||
 | 
					            stage=self.stage,
 | 
				
			||||||
 | 
					            phone_number=hash_phone_number("foo"),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.stage.verify_only = True
 | 
				
			||||||
 | 
					        self.stage.save()
 | 
				
			||||||
 | 
					        self.client.get(
 | 
				
			||||||
 | 
					            reverse("authentik_flows:configure", kwargs={"stage_uuid": self.stage.stage_uuid}),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        response = self.client.get(
 | 
				
			||||||
 | 
					            reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertStageResponse(
 | 
				
			||||||
 | 
					            response,
 | 
				
			||||||
 | 
					            self.flow,
 | 
				
			||||||
 | 
					            self.user,
 | 
				
			||||||
 | 
					            component="ak-stage-authenticator-sms",
 | 
				
			||||||
 | 
					            phone_number_required=True,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        sms_send_mock = MagicMock()
 | 
				
			||||||
 | 
					        with patch(
 | 
				
			||||||
 | 
					            "authentik.stages.authenticator_sms.models.AuthenticatorSMSStage.send",
 | 
				
			||||||
 | 
					            sms_send_mock,
 | 
				
			||||||
 | 
					        ):
 | 
				
			||||||
 | 
					            response = self.client.post(
 | 
				
			||||||
 | 
					                reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
 | 
				
			||||||
 | 
					                data={"component": "ak-stage-authenticator-sms", "phone_number": "foo"},
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            self.assertEqual(response.status_code, 200)
 | 
				
			||||||
 | 
					        self.assertStageResponse(
 | 
				
			||||||
 | 
					            response,
 | 
				
			||||||
 | 
					            self.flow,
 | 
				
			||||||
 | 
					            self.user,
 | 
				
			||||||
 | 
					            component="ak-stage-authenticator-sms",
 | 
				
			||||||
 | 
					            response_errors={
 | 
				
			||||||
 | 
					                "non_field_errors": [{"code": "invalid", "string": "Invalid phone number"}]
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					            phone_number_required=False,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
				
			|||||||
@ -2,14 +2,11 @@
 | 
				
			|||||||
from django.http import HttpRequest, HttpResponse
 | 
					from django.http import HttpRequest, HttpResponse
 | 
				
			||||||
from django_otp.plugins.otp_static.models import StaticDevice, StaticToken
 | 
					from django_otp.plugins.otp_static.models import StaticDevice, StaticToken
 | 
				
			||||||
from rest_framework.fields import CharField, ListField
 | 
					from rest_framework.fields import CharField, ListField
 | 
				
			||||||
from structlog.stdlib import get_logger
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.flows.challenge import ChallengeResponse, ChallengeTypes, WithUserInfoChallenge
 | 
					from authentik.flows.challenge import ChallengeResponse, ChallengeTypes, WithUserInfoChallenge
 | 
				
			||||||
from authentik.flows.stage import ChallengeStageView
 | 
					from authentik.flows.stage import ChallengeStageView
 | 
				
			||||||
from authentik.stages.authenticator_static.models import AuthenticatorStaticStage
 | 
					from authentik.stages.authenticator_static.models import AuthenticatorStaticStage
 | 
				
			||||||
 | 
					
 | 
				
			||||||
LOGGER = get_logger()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
class AuthenticatorStaticChallenge(WithUserInfoChallenge):
 | 
					class AuthenticatorStaticChallenge(WithUserInfoChallenge):
 | 
				
			||||||
    """Static authenticator challenge"""
 | 
					    """Static authenticator challenge"""
 | 
				
			||||||
@ -42,7 +39,7 @@ class AuthenticatorStaticStageView(ChallengeStageView):
 | 
				
			|||||||
    def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
 | 
					    def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
 | 
				
			||||||
        user = self.get_pending_user()
 | 
					        user = self.get_pending_user()
 | 
				
			||||||
        if not user.is_authenticated:
 | 
					        if not user.is_authenticated:
 | 
				
			||||||
            LOGGER.debug("No pending user, continuing")
 | 
					            self.logger.debug("No pending user, continuing")
 | 
				
			||||||
            return self.executor.stage_ok()
 | 
					            return self.executor.stage_ok()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        stage: AuthenticatorStaticStage = self.executor.current_stage
 | 
					        stage: AuthenticatorStaticStage = self.executor.current_stage
 | 
				
			||||||
 | 
				
			|||||||
@ -6,7 +6,6 @@ from django.utils.translation import gettext_lazy as _
 | 
				
			|||||||
from django_otp.plugins.otp_totp.models import TOTPDevice
 | 
					from django_otp.plugins.otp_totp.models import TOTPDevice
 | 
				
			||||||
from rest_framework.fields import CharField, IntegerField
 | 
					from rest_framework.fields import CharField, IntegerField
 | 
				
			||||||
from rest_framework.serializers import ValidationError
 | 
					from rest_framework.serializers import ValidationError
 | 
				
			||||||
from structlog.stdlib import get_logger
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.flows.challenge import (
 | 
					from authentik.flows.challenge import (
 | 
				
			||||||
    Challenge,
 | 
					    Challenge,
 | 
				
			||||||
@ -18,8 +17,6 @@ from authentik.flows.stage import ChallengeStageView
 | 
				
			|||||||
from authentik.stages.authenticator_totp.models import AuthenticatorTOTPStage
 | 
					from authentik.stages.authenticator_totp.models import AuthenticatorTOTPStage
 | 
				
			||||||
from authentik.stages.authenticator_totp.settings import OTP_TOTP_ISSUER
 | 
					from authentik.stages.authenticator_totp.settings import OTP_TOTP_ISSUER
 | 
				
			||||||
 | 
					
 | 
				
			||||||
LOGGER = get_logger()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
class AuthenticatorTOTPChallenge(WithUserInfoChallenge):
 | 
					class AuthenticatorTOTPChallenge(WithUserInfoChallenge):
 | 
				
			||||||
    """TOTP Setup challenge"""
 | 
					    """TOTP Setup challenge"""
 | 
				
			||||||
@ -72,7 +69,7 @@ class AuthenticatorTOTPStageView(ChallengeStageView):
 | 
				
			|||||||
    def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
 | 
					    def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
 | 
				
			||||||
        user = self.get_pending_user()
 | 
					        user = self.get_pending_user()
 | 
				
			||||||
        if not user.is_authenticated:
 | 
					        if not user.is_authenticated:
 | 
				
			||||||
            LOGGER.debug("No pending user, continuing")
 | 
					            self.logger.debug("No pending user, continuing")
 | 
				
			||||||
            return self.executor.stage_ok()
 | 
					            return self.executor.stage_ok()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        stage: AuthenticatorTOTPStage = self.executor.current_stage
 | 
					        stage: AuthenticatorTOTPStage = self.executor.current_stage
 | 
				
			||||||
 | 
				
			|||||||
@ -18,10 +18,14 @@ from webauthn.helpers.structs import AuthenticationCredential
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from authentik.core.api.utils import PassiveSerializer
 | 
					from authentik.core.api.utils import PassiveSerializer
 | 
				
			||||||
from authentik.core.models import User
 | 
					from authentik.core.models import User
 | 
				
			||||||
 | 
					from authentik.core.signals import login_failed
 | 
				
			||||||
 | 
					from authentik.flows.stage import StageView
 | 
				
			||||||
from authentik.lib.utils.http import get_client_ip
 | 
					from authentik.lib.utils.http import get_client_ip
 | 
				
			||||||
from authentik.stages.authenticator_duo.models import AuthenticatorDuoStage, DuoDevice
 | 
					from authentik.stages.authenticator_duo.models import AuthenticatorDuoStage, DuoDevice
 | 
				
			||||||
from authentik.stages.authenticator_sms.models import SMSDevice
 | 
					from authentik.stages.authenticator_sms.models import SMSDevice
 | 
				
			||||||
 | 
					from authentik.stages.authenticator_validate.models import DeviceClasses
 | 
				
			||||||
from authentik.stages.authenticator_webauthn.models import WebAuthnDevice
 | 
					from authentik.stages.authenticator_webauthn.models import WebAuthnDevice
 | 
				
			||||||
 | 
					from authentik.stages.authenticator_webauthn.stage import SESSION_KEY_WEBAUTHN_CHALLENGE
 | 
				
			||||||
from authentik.stages.authenticator_webauthn.utils import get_origin, get_rp_id
 | 
					from authentik.stages.authenticator_webauthn.utils import get_origin, get_rp_id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
LOGGER = get_logger()
 | 
					LOGGER = get_logger()
 | 
				
			||||||
@ -43,23 +47,23 @@ def get_challenge_for_device(request: HttpRequest, device: Device) -> dict:
 | 
				
			|||||||
    return {}
 | 
					    return {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_webauthn_challenge_userless(request: HttpRequest) -> dict:
 | 
					def get_webauthn_challenge_without_user(request: HttpRequest) -> dict:
 | 
				
			||||||
    """Same as `get_webauthn_challenge`, but allows any client device. We can then later check
 | 
					    """Same as `get_webauthn_challenge`, but allows any client device. We can then later check
 | 
				
			||||||
    who the device belongs to."""
 | 
					    who the device belongs to."""
 | 
				
			||||||
    request.session.pop("challenge", None)
 | 
					    request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None)
 | 
				
			||||||
    authentication_options = generate_authentication_options(
 | 
					    authentication_options = generate_authentication_options(
 | 
				
			||||||
        rp_id=get_rp_id(request),
 | 
					        rp_id=get_rp_id(request),
 | 
				
			||||||
        allow_credentials=[],
 | 
					        allow_credentials=[],
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    request.session["challenge"] = authentication_options.challenge
 | 
					    request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = authentication_options.challenge
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return loads(options_to_json(authentication_options))
 | 
					    return loads(options_to_json(authentication_options))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_webauthn_challenge(request: HttpRequest, device: Optional[WebAuthnDevice] = None) -> dict:
 | 
					def get_webauthn_challenge(request: HttpRequest, device: Optional[WebAuthnDevice] = None) -> dict:
 | 
				
			||||||
    """Send the client a challenge that we'll check later"""
 | 
					    """Send the client a challenge that we'll check later"""
 | 
				
			||||||
    request.session.pop("challenge", None)
 | 
					    request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    allowed_credentials = []
 | 
					    allowed_credentials = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -74,7 +78,7 @@ def get_webauthn_challenge(request: HttpRequest, device: Optional[WebAuthnDevice
 | 
				
			|||||||
        allow_credentials=allowed_credentials,
 | 
					        allow_credentials=allowed_credentials,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    request.session["challenge"] = authentication_options.challenge
 | 
					    request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = authentication_options.challenge
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return loads(options_to_json(authentication_options))
 | 
					    return loads(options_to_json(authentication_options))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -91,24 +95,32 @@ def select_challenge_sms(request: HttpRequest, device: SMSDevice):
 | 
				
			|||||||
    device.stage.send(device.token, device)
 | 
					    device.stage.send(device.token, device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def validate_challenge_code(code: str, request: HttpRequest, user: User) -> str:
 | 
					def validate_challenge_code(code: str, stage_view: StageView, user: User) -> Device:
 | 
				
			||||||
    """Validate code-based challenges. We test against every device, on purpose, as
 | 
					    """Validate code-based challenges. We test against every device, on purpose, as
 | 
				
			||||||
    the user mustn't choose between totp and static devices."""
 | 
					    the user mustn't choose between totp and static devices."""
 | 
				
			||||||
    device = match_token(user, code)
 | 
					    device = match_token(user, code)
 | 
				
			||||||
    if not device:
 | 
					    if not device:
 | 
				
			||||||
 | 
					        login_failed.send(
 | 
				
			||||||
 | 
					            sender=__name__,
 | 
				
			||||||
 | 
					            credentials={"username": user.username},
 | 
				
			||||||
 | 
					            request=stage_view.request,
 | 
				
			||||||
 | 
					            stage=stage_view.executor.current_stage,
 | 
				
			||||||
 | 
					            device_class=DeviceClasses.TOTP.value,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        raise ValidationError(_("Invalid Token"))
 | 
					        raise ValidationError(_("Invalid Token"))
 | 
				
			||||||
    return code
 | 
					    return device
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# pylint: disable=unused-argument
 | 
					# pylint: disable=unused-argument
 | 
				
			||||||
def validate_challenge_webauthn(data: dict, request: HttpRequest, user: User) -> Device:
 | 
					def validate_challenge_webauthn(data: dict, stage_view: StageView, user: User) -> Device:
 | 
				
			||||||
    """Validate WebAuthn Challenge"""
 | 
					    """Validate WebAuthn Challenge"""
 | 
				
			||||||
    challenge = request.session.get("challenge")
 | 
					    request = stage_view.request
 | 
				
			||||||
 | 
					    challenge = request.session.get(SESSION_KEY_WEBAUTHN_CHALLENGE)
 | 
				
			||||||
    credential_id = data.get("id")
 | 
					    credential_id = data.get("id")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    device = WebAuthnDevice.objects.filter(credential_id=credential_id).first()
 | 
					    device = WebAuthnDevice.objects.filter(credential_id=credential_id).first()
 | 
				
			||||||
    if not device:
 | 
					    if not device:
 | 
				
			||||||
        raise ValidationError("Device does not exist.")
 | 
					        raise Http404()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        authentication_verification = verify_authentication_response(
 | 
					        authentication_verification = verify_authentication_response(
 | 
				
			||||||
@ -120,16 +132,23 @@ def validate_challenge_webauthn(data: dict, request: HttpRequest, user: User) ->
 | 
				
			|||||||
            credential_current_sign_count=device.sign_count,
 | 
					            credential_current_sign_count=device.sign_count,
 | 
				
			||||||
            require_user_verification=False,
 | 
					            require_user_verification=False,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					 | 
				
			||||||
    except InvalidAuthenticationResponse as exc:
 | 
					    except InvalidAuthenticationResponse as exc:
 | 
				
			||||||
        LOGGER.warning("Assertion failed", exc=exc)
 | 
					        LOGGER.warning("Assertion failed", exc=exc)
 | 
				
			||||||
 | 
					        login_failed.send(
 | 
				
			||||||
 | 
					            sender=__name__,
 | 
				
			||||||
 | 
					            credentials={"username": user.username},
 | 
				
			||||||
 | 
					            request=stage_view.request,
 | 
				
			||||||
 | 
					            stage=stage_view.executor.current_stage,
 | 
				
			||||||
 | 
					            device=device,
 | 
				
			||||||
 | 
					            device_class=DeviceClasses.WEBAUTHN.value,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        raise ValidationError("Assertion failed") from exc
 | 
					        raise ValidationError("Assertion failed") from exc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    device.set_sign_count(authentication_verification.new_sign_count)
 | 
					    device.set_sign_count(authentication_verification.new_sign_count)
 | 
				
			||||||
    return device
 | 
					    return device
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def validate_challenge_duo(device_pk: int, request: HttpRequest, user: User) -> int:
 | 
					def validate_challenge_duo(device_pk: int, stage_view: StageView, user: User) -> Device:
 | 
				
			||||||
    """Duo authentication"""
 | 
					    """Duo authentication"""
 | 
				
			||||||
    device = get_object_or_404(DuoDevice, pk=device_pk)
 | 
					    device = get_object_or_404(DuoDevice, pk=device_pk)
 | 
				
			||||||
    if device.user != user:
 | 
					    if device.user != user:
 | 
				
			||||||
@ -139,13 +158,20 @@ def validate_challenge_duo(device_pk: int, request: HttpRequest, user: User) ->
 | 
				
			|||||||
    response = stage.client.auth(
 | 
					    response = stage.client.auth(
 | 
				
			||||||
        "auto",
 | 
					        "auto",
 | 
				
			||||||
        user_id=device.duo_user_id,
 | 
					        user_id=device.duo_user_id,
 | 
				
			||||||
        ipaddr=get_client_ip(request),
 | 
					        ipaddr=get_client_ip(stage_view.request),
 | 
				
			||||||
        type="authentik Login request",
 | 
					        type="authentik Login request",
 | 
				
			||||||
        display_username=user.username,
 | 
					        display_username=user.username,
 | 
				
			||||||
        device="auto",
 | 
					        device="auto",
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    # {'result': 'allow', 'status': 'allow', 'status_msg': 'Success. Logging you in...'}
 | 
					    # {'result': 'allow', 'status': 'allow', 'status_msg': 'Success. Logging you in...'}
 | 
				
			||||||
    if response["result"] == "deny":
 | 
					    if response["result"] == "deny":
 | 
				
			||||||
 | 
					        login_failed.send(
 | 
				
			||||||
 | 
					            sender=__name__,
 | 
				
			||||||
 | 
					            credentials={"username": user.username},
 | 
				
			||||||
 | 
					            request=stage_view.request,
 | 
				
			||||||
 | 
					            stage=stage_view.executor.current_stage,
 | 
				
			||||||
 | 
					            device_class=DeviceClasses.DUO.value,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        raise ValidationError("Duo denied access")
 | 
					        raise ValidationError("Duo denied access")
 | 
				
			||||||
    device.save()
 | 
					    device.save()
 | 
				
			||||||
    return device_pk
 | 
					    return device
 | 
				
			||||||
 | 
				
			|||||||
@ -14,7 +14,7 @@ class DeviceClasses(models.TextChoices):
 | 
				
			|||||||
    """Device classes this stage can validate"""
 | 
					    """Device classes this stage can validate"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # device class must match Device's class name so StaticDevice -> static
 | 
					    # device class must match Device's class name so StaticDevice -> static
 | 
				
			||||||
    STATIC = "static"
 | 
					    STATIC = "static", _("Static")
 | 
				
			||||||
    TOTP = "totp", _("TOTP")
 | 
					    TOTP = "totp", _("TOTP")
 | 
				
			||||||
    WEBAUTHN = "webauthn", _("WebAuthn")
 | 
					    WEBAUTHN = "webauthn", _("WebAuthn")
 | 
				
			||||||
    DUO = "duo", _("Duo")
 | 
					    DUO = "duo", _("Duo")
 | 
				
			||||||
 | 
				
			|||||||
@ -1,10 +1,13 @@
 | 
				
			|||||||
"""Authenticator Validation"""
 | 
					"""Authenticator Validation"""
 | 
				
			||||||
from datetime import timezone
 | 
					from datetime import datetime
 | 
				
			||||||
 | 
					from hashlib import sha256
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django.conf import settings
 | 
				
			||||||
from django.http import HttpRequest, HttpResponse
 | 
					from django.http import HttpRequest, HttpResponse
 | 
				
			||||||
from django.utils.timezone import datetime, now
 | 
					 | 
				
			||||||
from django_otp import devices_for_user
 | 
					from django_otp import devices_for_user
 | 
				
			||||||
from django_otp.models import Device
 | 
					from django_otp.models import Device
 | 
				
			||||||
 | 
					from jwt import PyJWTError, decode, encode
 | 
				
			||||||
from rest_framework.fields import CharField, IntegerField, JSONField, ListField, UUIDField
 | 
					from rest_framework.fields import CharField, IntegerField, JSONField, ListField, UUIDField
 | 
				
			||||||
from rest_framework.serializers import ValidationError
 | 
					from rest_framework.serializers import ValidationError
 | 
				
			||||||
from structlog.stdlib import get_logger
 | 
					from structlog.stdlib import get_logger
 | 
				
			||||||
@ -23,7 +26,7 @@ from authentik.stages.authenticator_sms.models import SMSDevice
 | 
				
			|||||||
from authentik.stages.authenticator_validate.challenge import (
 | 
					from authentik.stages.authenticator_validate.challenge import (
 | 
				
			||||||
    DeviceChallenge,
 | 
					    DeviceChallenge,
 | 
				
			||||||
    get_challenge_for_device,
 | 
					    get_challenge_for_device,
 | 
				
			||||||
    get_webauthn_challenge_userless,
 | 
					    get_webauthn_challenge_without_user,
 | 
				
			||||||
    select_challenge,
 | 
					    select_challenge,
 | 
				
			||||||
    validate_challenge_code,
 | 
					    validate_challenge_code,
 | 
				
			||||||
    validate_challenge_duo,
 | 
					    validate_challenge_duo,
 | 
				
			||||||
@ -34,9 +37,12 @@ from authentik.stages.authenticator_webauthn.models import WebAuthnDevice
 | 
				
			|||||||
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS
 | 
					from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS
 | 
				
			||||||
 | 
					
 | 
				
			||||||
LOGGER = get_logger()
 | 
					LOGGER = get_logger()
 | 
				
			||||||
SESSION_STAGES = "goauthentik.io/stages/authenticator_validate/stages"
 | 
					
 | 
				
			||||||
SESSION_SELECTED_STAGE = "goauthentik.io/stages/authenticator_validate/selected_stage"
 | 
					COOKIE_NAME_MFA = "authentik_mfa"
 | 
				
			||||||
SESSION_DEVICE_CHALLENGES = "goauthentik.io/stages/authenticator_validate/device_challenges"
 | 
					
 | 
				
			||||||
 | 
					SESSION_KEY_STAGES = "authentik/stages/authenticator_validate/stages"
 | 
				
			||||||
 | 
					SESSION_KEY_SELECTED_STAGE = "authentik/stages/authenticator_validate/selected_stage"
 | 
				
			||||||
 | 
					SESSION_KEY_DEVICE_CHALLENGES = "authentik/stages/authenticator_validate/device_challenges"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class SelectableStageSerializer(PassiveSerializer):
 | 
					class SelectableStageSerializer(PassiveSerializer):
 | 
				
			||||||
@ -59,6 +65,8 @@ class AuthenticatorValidationChallenge(WithUserInfoChallenge):
 | 
				
			|||||||
class AuthenticatorValidationChallengeResponse(ChallengeResponse):
 | 
					class AuthenticatorValidationChallengeResponse(ChallengeResponse):
 | 
				
			||||||
    """Challenge used for Code-based and WebAuthn authenticators"""
 | 
					    """Challenge used for Code-based and WebAuthn authenticators"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    device: Optional[Device]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    selected_challenge = DeviceChallenge(required=False)
 | 
					    selected_challenge = DeviceChallenge(required=False)
 | 
				
			||||||
    selected_stage = CharField(required=False)
 | 
					    selected_stage = CharField(required=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -68,33 +76,38 @@ class AuthenticatorValidationChallengeResponse(ChallengeResponse):
 | 
				
			|||||||
    component = CharField(default="ak-stage-authenticator-validate")
 | 
					    component = CharField(default="ak-stage-authenticator-validate")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _challenge_allowed(self, classes: list):
 | 
					    def _challenge_allowed(self, classes: list):
 | 
				
			||||||
        device_challenges: list[dict] = self.stage.request.session.get(SESSION_DEVICE_CHALLENGES)
 | 
					        device_challenges: list[dict] = self.stage.request.session.get(
 | 
				
			||||||
 | 
					            SESSION_KEY_DEVICE_CHALLENGES, []
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        if not any(x["device_class"] in classes for x in device_challenges):
 | 
					        if not any(x["device_class"] in classes for x in device_challenges):
 | 
				
			||||||
            raise ValidationError("No compatible device class allowed")
 | 
					            raise ValidationError("No compatible device class allowed")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def validate_code(self, code: str) -> str:
 | 
					    def validate_code(self, code: str) -> str:
 | 
				
			||||||
        """Validate code-based response, raise error if code isn't allowed"""
 | 
					        """Validate code-based response, raise error if code isn't allowed"""
 | 
				
			||||||
        self._challenge_allowed([DeviceClasses.TOTP, DeviceClasses.STATIC, DeviceClasses.SMS])
 | 
					        self._challenge_allowed([DeviceClasses.TOTP, DeviceClasses.STATIC, DeviceClasses.SMS])
 | 
				
			||||||
        return validate_challenge_code(code, self.stage.request, self.stage.get_pending_user())
 | 
					        self.device = validate_challenge_code(code, self.stage, self.stage.get_pending_user())
 | 
				
			||||||
 | 
					        return code
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def validate_webauthn(self, webauthn: dict) -> dict:
 | 
					    def validate_webauthn(self, webauthn: dict) -> dict:
 | 
				
			||||||
        """Validate webauthn response, raise error if webauthn wasn't allowed
 | 
					        """Validate webauthn response, raise error if webauthn wasn't allowed
 | 
				
			||||||
        or response is invalid"""
 | 
					        or response is invalid"""
 | 
				
			||||||
        self._challenge_allowed([DeviceClasses.WEBAUTHN])
 | 
					        self._challenge_allowed([DeviceClasses.WEBAUTHN])
 | 
				
			||||||
        return validate_challenge_webauthn(
 | 
					        self.device = validate_challenge_webauthn(
 | 
				
			||||||
            webauthn, self.stage.request, self.stage.get_pending_user()
 | 
					            webauthn, self.stage, self.stage.get_pending_user()
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					        return webauthn
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def validate_duo(self, duo: int) -> int:
 | 
					    def validate_duo(self, duo: int) -> int:
 | 
				
			||||||
        """Initiate Duo authentication"""
 | 
					        """Initiate Duo authentication"""
 | 
				
			||||||
        self._challenge_allowed([DeviceClasses.DUO])
 | 
					        self._challenge_allowed([DeviceClasses.DUO])
 | 
				
			||||||
        return validate_challenge_duo(duo, self.stage.request, self.stage.get_pending_user())
 | 
					        self.device = validate_challenge_duo(duo, self.stage, self.stage.get_pending_user())
 | 
				
			||||||
 | 
					        return duo
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def validate_selected_challenge(self, challenge: dict) -> dict:
 | 
					    def validate_selected_challenge(self, challenge: dict) -> dict:
 | 
				
			||||||
        """Check which challenge the user has selected. Actual logic only used for SMS stage."""
 | 
					        """Check which challenge the user has selected. Actual logic only used for SMS stage."""
 | 
				
			||||||
        # First check if the challenge is valid
 | 
					        # First check if the challenge is valid
 | 
				
			||||||
        allowed = False
 | 
					        allowed = False
 | 
				
			||||||
        for device_challenge in self.stage.request.session.get(SESSION_DEVICE_CHALLENGES):
 | 
					        for device_challenge in self.stage.request.session.get(SESSION_KEY_DEVICE_CHALLENGES, []):
 | 
				
			||||||
            if device_challenge.get("device_class", "") == challenge.get(
 | 
					            if device_challenge.get("device_class", "") == challenge.get(
 | 
				
			||||||
                "device_class", ""
 | 
					                "device_class", ""
 | 
				
			||||||
            ) and device_challenge.get("device_uid", "") == challenge.get("device_uid", ""):
 | 
					            ) and device_challenge.get("device_uid", "") == challenge.get("device_uid", ""):
 | 
				
			||||||
@ -112,11 +125,11 @@ class AuthenticatorValidationChallengeResponse(ChallengeResponse):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def validate_selected_stage(self, stage_pk: str) -> str:
 | 
					    def validate_selected_stage(self, stage_pk: str) -> str:
 | 
				
			||||||
        """Check that the selected stage is valid"""
 | 
					        """Check that the selected stage is valid"""
 | 
				
			||||||
        stages = self.stage.request.session.get(SESSION_STAGES, [])
 | 
					        stages = self.stage.request.session.get(SESSION_KEY_STAGES, [])
 | 
				
			||||||
        if not any(str(stage.pk) == stage_pk for stage in stages):
 | 
					        if not any(str(stage.pk) == stage_pk for stage in stages):
 | 
				
			||||||
            raise ValidationError("Selected stage is invalid")
 | 
					            raise ValidationError("Selected stage is invalid")
 | 
				
			||||||
        LOGGER.debug("Setting selected stage to ", stage=stage_pk)
 | 
					        LOGGER.debug("Setting selected stage to ", stage=stage_pk)
 | 
				
			||||||
        self.stage.request.session[SESSION_SELECTED_STAGE] = stage_pk
 | 
					        self.stage.request.session[SESSION_KEY_SELECTED_STAGE] = stage_pk
 | 
				
			||||||
        return stage_pk
 | 
					        return stage_pk
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def validate(self, attrs: dict):
 | 
					    def validate(self, attrs: dict):
 | 
				
			||||||
@ -127,15 +140,6 @@ class AuthenticatorValidationChallengeResponse(ChallengeResponse):
 | 
				
			|||||||
        return attrs
 | 
					        return attrs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_device_last_usage(device: Device) -> datetime:
 | 
					 | 
				
			||||||
    """Get a datetime object from last_t"""
 | 
					 | 
				
			||||||
    if not hasattr(device, "last_t"):
 | 
					 | 
				
			||||||
        return datetime.fromtimestamp(0, tz=timezone.utc)
 | 
					 | 
				
			||||||
    if isinstance(device.last_t, datetime):
 | 
					 | 
				
			||||||
        return device.last_t
 | 
					 | 
				
			||||||
    return datetime.fromtimestamp(device.last_t * device.step, tz=timezone.utc)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class AuthenticatorValidateStageView(ChallengeStageView):
 | 
					class AuthenticatorValidateStageView(ChallengeStageView):
 | 
				
			||||||
    """Authenticator Validation"""
 | 
					    """Authenticator Validation"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -146,31 +150,30 @@ class AuthenticatorValidateStageView(ChallengeStageView):
 | 
				
			|||||||
        challenges = []
 | 
					        challenges = []
 | 
				
			||||||
        # Convert to a list to have usable log output instead of just <generator ...>
 | 
					        # Convert to a list to have usable log output instead of just <generator ...>
 | 
				
			||||||
        user_devices = list(devices_for_user(self.get_pending_user()))
 | 
					        user_devices = list(devices_for_user(self.get_pending_user()))
 | 
				
			||||||
        LOGGER.debug("Got devices for user", devices=user_devices)
 | 
					        self.logger.debug("Got devices for user", devices=user_devices)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # static and totp are only shown once
 | 
					        # static and totp are only shown once
 | 
				
			||||||
        # since their challenges are device-independant
 | 
					        # since their challenges are device-independent
 | 
				
			||||||
        seen_classes = []
 | 
					        seen_classes = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        stage: AuthenticatorValidateStage = self.executor.current_stage
 | 
					        stage: AuthenticatorValidateStage = self.executor.current_stage
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        _now = now()
 | 
					 | 
				
			||||||
        threshold = timedelta_from_string(stage.last_auth_threshold)
 | 
					        threshold = timedelta_from_string(stage.last_auth_threshold)
 | 
				
			||||||
 | 
					        allowed_devices = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        for device in user_devices:
 | 
					        for device in user_devices:
 | 
				
			||||||
            device_class = device.__class__.__name__.lower().replace("device", "")
 | 
					            device_class = device.__class__.__name__.lower().replace("device", "")
 | 
				
			||||||
            if device_class not in stage.device_classes:
 | 
					            if device_class not in stage.device_classes:
 | 
				
			||||||
                LOGGER.debug("device class not allowed", device_class=device_class)
 | 
					                self.logger.debug("device class not allowed", device_class=device_class)
 | 
				
			||||||
                continue
 | 
					                continue
 | 
				
			||||||
 | 
					            if isinstance(device, SMSDevice) and device.is_hashed:
 | 
				
			||||||
 | 
					                LOGGER.debug("Hashed SMS device, skipping")
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					            allowed_devices.append(device)
 | 
				
			||||||
            # Ensure only one challenge per device class
 | 
					            # Ensure only one challenge per device class
 | 
				
			||||||
            # WebAuthn does another device loop to find all webuahtn devices
 | 
					            # WebAuthn does another device loop to find all WebAuthn devices
 | 
				
			||||||
            if device_class in seen_classes:
 | 
					            if device_class in seen_classes:
 | 
				
			||||||
                continue
 | 
					                continue
 | 
				
			||||||
            # check if device has been used within threshold and skip this stage if so
 | 
					 | 
				
			||||||
            if threshold.total_seconds() > 0:
 | 
					 | 
				
			||||||
                if _now - get_device_last_usage(device) <= threshold:
 | 
					 | 
				
			||||||
                    LOGGER.info("Device has been used within threshold", device=device)
 | 
					 | 
				
			||||||
                    raise FlowSkipStageException()
 | 
					 | 
				
			||||||
            if device_class not in seen_classes:
 | 
					            if device_class not in seen_classes:
 | 
				
			||||||
                seen_classes.append(device_class)
 | 
					                seen_classes.append(device_class)
 | 
				
			||||||
            challenge = DeviceChallenge(
 | 
					            challenge = DeviceChallenge(
 | 
				
			||||||
@ -182,16 +185,19 @@ class AuthenticatorValidateStageView(ChallengeStageView):
 | 
				
			|||||||
            )
 | 
					            )
 | 
				
			||||||
            challenge.is_valid()
 | 
					            challenge.is_valid()
 | 
				
			||||||
            challenges.append(challenge.data)
 | 
					            challenges.append(challenge.data)
 | 
				
			||||||
            LOGGER.debug("adding challenge for device", challenge=challenge)
 | 
					            self.logger.debug("adding challenge for device", challenge=challenge)
 | 
				
			||||||
 | 
					        # check if we have an MFA cookie and if it's valid
 | 
				
			||||||
 | 
					        if threshold.total_seconds() > 0:
 | 
				
			||||||
 | 
					            self.check_mfa_cookie(allowed_devices)
 | 
				
			||||||
        return challenges
 | 
					        return challenges
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_userless_webauthn_challenge(self) -> list[dict]:
 | 
					    def get_webauthn_challenge_without_user(self) -> list[dict]:
 | 
				
			||||||
        """Get a WebAuthn challenge when no pending user is set."""
 | 
					        """Get a WebAuthn challenge when no pending user is set."""
 | 
				
			||||||
        challenge = DeviceChallenge(
 | 
					        challenge = DeviceChallenge(
 | 
				
			||||||
            data={
 | 
					            data={
 | 
				
			||||||
                "device_class": DeviceClasses.WEBAUTHN,
 | 
					                "device_class": DeviceClasses.WEBAUTHN,
 | 
				
			||||||
                "device_uid": -1,
 | 
					                "device_uid": -1,
 | 
				
			||||||
                "challenge": get_webauthn_challenge_userless(self.request),
 | 
					                "challenge": get_webauthn_challenge_without_user(self.request),
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        challenge.is_valid()
 | 
					        challenge.is_valid()
 | 
				
			||||||
@ -210,27 +216,27 @@ class AuthenticatorValidateStageView(ChallengeStageView):
 | 
				
			|||||||
                return self.executor.stage_ok()
 | 
					                return self.executor.stage_ok()
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            if self.executor.flow.designation != FlowDesignation.AUTHENTICATION:
 | 
					            if self.executor.flow.designation != FlowDesignation.AUTHENTICATION:
 | 
				
			||||||
                LOGGER.debug("Refusing passwordless flow in non-authentication flow")
 | 
					                self.logger.debug("Refusing passwordless flow in non-authentication flow")
 | 
				
			||||||
                return self.executor.stage_ok()
 | 
					                return self.executor.stage_ok()
 | 
				
			||||||
            # Passwordless auth, with just webauthn
 | 
					            # Passwordless auth, with just webauthn
 | 
				
			||||||
            if DeviceClasses.WEBAUTHN in stage.device_classes:
 | 
					            if DeviceClasses.WEBAUTHN in stage.device_classes:
 | 
				
			||||||
                LOGGER.debug("Userless flow, getting generic webauthn challenge")
 | 
					                self.logger.debug("Flow without user, getting generic webauthn challenge")
 | 
				
			||||||
                challenges = self.get_userless_webauthn_challenge()
 | 
					                challenges = self.get_webauthn_challenge_without_user()
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                LOGGER.debug("No pending user, continuing")
 | 
					                self.logger.debug("No pending user, continuing")
 | 
				
			||||||
                return self.executor.stage_ok()
 | 
					                return self.executor.stage_ok()
 | 
				
			||||||
        self.request.session[SESSION_DEVICE_CHALLENGES] = challenges
 | 
					        self.request.session[SESSION_KEY_DEVICE_CHALLENGES] = challenges
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # No allowed devices
 | 
					        # No allowed devices
 | 
				
			||||||
        if len(challenges) < 1:
 | 
					        if len(challenges) < 1:
 | 
				
			||||||
            if stage.not_configured_action == NotConfiguredAction.SKIP:
 | 
					            if stage.not_configured_action == NotConfiguredAction.SKIP:
 | 
				
			||||||
                LOGGER.debug("Authenticator not configured, skipping stage")
 | 
					                self.logger.debug("Authenticator not configured, skipping stage")
 | 
				
			||||||
                return self.executor.stage_ok()
 | 
					                return self.executor.stage_ok()
 | 
				
			||||||
            if stage.not_configured_action == NotConfiguredAction.DENY:
 | 
					            if stage.not_configured_action == NotConfiguredAction.DENY:
 | 
				
			||||||
                LOGGER.debug("Authenticator not configured, denying")
 | 
					                self.logger.debug("Authenticator not configured, denying")
 | 
				
			||||||
                return self.executor.stage_invalid()
 | 
					                return self.executor.stage_invalid()
 | 
				
			||||||
            if stage.not_configured_action == NotConfiguredAction.CONFIGURE:
 | 
					            if stage.not_configured_action == NotConfiguredAction.CONFIGURE:
 | 
				
			||||||
                LOGGER.debug("Authenticator not configured, forcing configure")
 | 
					                self.logger.debug("Authenticator not configured, forcing configure")
 | 
				
			||||||
                return self.prepare_stages(user)
 | 
					                return self.prepare_stages(user)
 | 
				
			||||||
        return super().get(request, *args, **kwargs)
 | 
					        return super().get(request, *args, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -251,24 +257,24 @@ class AuthenticatorValidateStageView(ChallengeStageView):
 | 
				
			|||||||
            return self.executor.stage_invalid()
 | 
					            return self.executor.stage_invalid()
 | 
				
			||||||
        if stage.configuration_stages.count() == 1:
 | 
					        if stage.configuration_stages.count() == 1:
 | 
				
			||||||
            next_stage = Stage.objects.get_subclass(pk=stage.configuration_stages.first().pk)
 | 
					            next_stage = Stage.objects.get_subclass(pk=stage.configuration_stages.first().pk)
 | 
				
			||||||
            LOGGER.debug("Single stage configured, auto-selecting", stage=next_stage)
 | 
					            self.logger.debug("Single stage configured, auto-selecting", stage=next_stage)
 | 
				
			||||||
            self.request.session[SESSION_SELECTED_STAGE] = next_stage
 | 
					            self.request.session[SESSION_KEY_SELECTED_STAGE] = next_stage
 | 
				
			||||||
            # Because that normal insetion only happens on post, we directly inject it here and
 | 
					            # Because that normal execution only happens on post, we directly inject it here and
 | 
				
			||||||
            # return it
 | 
					            # return it
 | 
				
			||||||
            self.executor.plan.insert_stage(next_stage)
 | 
					            self.executor.plan.insert_stage(next_stage)
 | 
				
			||||||
            return self.executor.stage_ok()
 | 
					            return self.executor.stage_ok()
 | 
				
			||||||
        stages = Stage.objects.filter(pk__in=stage.configuration_stages.all()).select_subclasses()
 | 
					        stages = Stage.objects.filter(pk__in=stage.configuration_stages.all()).select_subclasses()
 | 
				
			||||||
        self.request.session[SESSION_STAGES] = stages
 | 
					        self.request.session[SESSION_KEY_STAGES] = stages
 | 
				
			||||||
        return super().get(self.request, *args, **kwargs)
 | 
					        return super().get(self.request, *args, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
 | 
					    def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
 | 
				
			||||||
        res = super().post(request, *args, **kwargs)
 | 
					        res = super().post(request, *args, **kwargs)
 | 
				
			||||||
        if (
 | 
					        if (
 | 
				
			||||||
            SESSION_SELECTED_STAGE in self.request.session
 | 
					            SESSION_KEY_SELECTED_STAGE in self.request.session
 | 
				
			||||||
            and self.executor.current_stage.not_configured_action == NotConfiguredAction.CONFIGURE
 | 
					            and self.executor.current_stage.not_configured_action == NotConfiguredAction.CONFIGURE
 | 
				
			||||||
        ):
 | 
					        ):
 | 
				
			||||||
            LOGGER.debug("Got selected stage in session, running that")
 | 
					            self.logger.debug("Got selected stage in session, running that")
 | 
				
			||||||
            stage_pk = self.request.session.get(SESSION_SELECTED_STAGE)
 | 
					            stage_pk = self.request.session.get(SESSION_KEY_SELECTED_STAGE)
 | 
				
			||||||
            # Because the foreign key to stage.configuration_stage points to
 | 
					            # Because the foreign key to stage.configuration_stage points to
 | 
				
			||||||
            # a base stage class, we need to do another lookup
 | 
					            # a base stage class, we need to do another lookup
 | 
				
			||||||
            stage = Stage.objects.get_subclass(pk=stage_pk)
 | 
					            stage = Stage.objects.get_subclass(pk=stage_pk)
 | 
				
			||||||
@ -279,8 +285,8 @@ class AuthenticatorValidateStageView(ChallengeStageView):
 | 
				
			|||||||
        return res
 | 
					        return res
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_challenge(self) -> AuthenticatorValidationChallenge:
 | 
					    def get_challenge(self) -> AuthenticatorValidationChallenge:
 | 
				
			||||||
        challenges = self.request.session.get(SESSION_DEVICE_CHALLENGES, [])
 | 
					        challenges = self.request.session.get(SESSION_KEY_DEVICE_CHALLENGES, [])
 | 
				
			||||||
        stages = self.request.session.get(SESSION_STAGES, [])
 | 
					        stages = self.request.session.get(SESSION_KEY_STAGES, [])
 | 
				
			||||||
        stage_challenges = []
 | 
					        stage_challenges = []
 | 
				
			||||||
        for stage in stages:
 | 
					        for stage in stages:
 | 
				
			||||||
            serializer = SelectableStageSerializer(
 | 
					            serializer = SelectableStageSerializer(
 | 
				
			||||||
@ -301,15 +307,77 @@ class AuthenticatorValidateStageView(ChallengeStageView):
 | 
				
			|||||||
            }
 | 
					            }
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def cookie_jwt_key(self) -> str:
 | 
				
			||||||
 | 
					        """Signing key for MFA Cookie for this stage"""
 | 
				
			||||||
 | 
					        return sha256(
 | 
				
			||||||
 | 
					            f"{settings.SECRET_KEY}:{self.executor.current_stage.pk.hex}".encode("ascii")
 | 
				
			||||||
 | 
					        ).hexdigest()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def check_mfa_cookie(self, allowed_devices: list[Device]):
 | 
				
			||||||
 | 
					        """Check if an MFA cookie has been set, whether it's valid and applies
 | 
				
			||||||
 | 
					        to the current stage and device.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        The list of devices passed to this function must only contain devices for the
 | 
				
			||||||
 | 
					        correct user and with an allowed class"""
 | 
				
			||||||
 | 
					        if COOKIE_NAME_MFA not in self.request.COOKIES:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					        stage: AuthenticatorValidateStage = self.executor.current_stage
 | 
				
			||||||
 | 
					        threshold = timedelta_from_string(stage.last_auth_threshold)
 | 
				
			||||||
 | 
					        latest_allowed = datetime.now() + threshold
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            payload = decode(self.request.COOKIES[COOKIE_NAME_MFA], self.cookie_jwt_key, ["HS256"])
 | 
				
			||||||
 | 
					            if payload["stage"] != stage.pk.hex:
 | 
				
			||||||
 | 
					                self.logger.warning("Invalid stage PK")
 | 
				
			||||||
 | 
					                return
 | 
				
			||||||
 | 
					            if datetime.fromtimestamp(payload["exp"]) > latest_allowed:
 | 
				
			||||||
 | 
					                self.logger.warning("Expired MFA cookie")
 | 
				
			||||||
 | 
					                return
 | 
				
			||||||
 | 
					            if not any(device.pk == payload["device"] for device in allowed_devices):
 | 
				
			||||||
 | 
					                self.logger.warning("Invalid device PK")
 | 
				
			||||||
 | 
					                return
 | 
				
			||||||
 | 
					            self.logger.info("MFA has been used within threshold")
 | 
				
			||||||
 | 
					            raise FlowSkipStageException()
 | 
				
			||||||
 | 
					        except (PyJWTError, ValueError, TypeError) as exc:
 | 
				
			||||||
 | 
					            self.logger.info("Invalid mfa cookie for device", exc=exc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def set_valid_mfa_cookie(self, device: Device) -> HttpResponse:
 | 
				
			||||||
 | 
					        """Set an MFA cookie to allow users to skip MFA validation in this context (browser)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        The cookie is JWT which is signed with a hash of the secret key and the UID of the stage"""
 | 
				
			||||||
 | 
					        stage: AuthenticatorValidateStage = self.executor.current_stage
 | 
				
			||||||
 | 
					        delta = timedelta_from_string(stage.last_auth_threshold)
 | 
				
			||||||
 | 
					        if delta.total_seconds() < 1:
 | 
				
			||||||
 | 
					            self.logger.info("Not setting MFA cookie since threshold is not set.")
 | 
				
			||||||
 | 
					            return self.executor.stage_ok()
 | 
				
			||||||
 | 
					        expiry = datetime.now() + delta
 | 
				
			||||||
 | 
					        cookie_payload = {
 | 
				
			||||||
 | 
					            "device": device.pk,
 | 
				
			||||||
 | 
					            "stage": stage.pk.hex,
 | 
				
			||||||
 | 
					            "exp": expiry.timestamp(),
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        response = self.executor.stage_ok()
 | 
				
			||||||
 | 
					        cookie = encode(cookie_payload, self.cookie_jwt_key)
 | 
				
			||||||
 | 
					        response.set_cookie(
 | 
				
			||||||
 | 
					            COOKIE_NAME_MFA,
 | 
				
			||||||
 | 
					            cookie,
 | 
				
			||||||
 | 
					            expires=expiry,
 | 
				
			||||||
 | 
					            path="/",
 | 
				
			||||||
 | 
					            max_age=delta,
 | 
				
			||||||
 | 
					            domain=settings.SESSION_COOKIE_DOMAIN,
 | 
				
			||||||
 | 
					            samesite="Lax",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        return response
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # pylint: disable=unused-argument
 | 
					    # pylint: disable=unused-argument
 | 
				
			||||||
    def challenge_valid(self, response: AuthenticatorValidationChallengeResponse) -> HttpResponse:
 | 
					    def challenge_valid(self, response: AuthenticatorValidationChallengeResponse) -> HttpResponse:
 | 
				
			||||||
        # All validation is done by the serializer
 | 
					        # All validation is done by the serializer
 | 
				
			||||||
        user = self.executor.plan.context.get(PLAN_CONTEXT_PENDING_USER)
 | 
					        user = self.executor.plan.context.get(PLAN_CONTEXT_PENDING_USER)
 | 
				
			||||||
        if not user:
 | 
					        if not user:
 | 
				
			||||||
            webauthn_device: WebAuthnDevice = response.data.get("webauthn", None)
 | 
					            if "webauthn" not in response.data:
 | 
				
			||||||
            if not webauthn_device:
 | 
					                return self.executor.stage_invalid()
 | 
				
			||||||
                return self.executor.stage_ok()
 | 
					            webauthn_device: WebAuthnDevice = response.device
 | 
				
			||||||
            LOGGER.debug("Set user from userless flow", user=webauthn_device.user)
 | 
					            self.logger.debug("Set user from user-less flow", user=webauthn_device.user)
 | 
				
			||||||
            self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] = webauthn_device.user
 | 
					            self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] = webauthn_device.user
 | 
				
			||||||
            self.executor.plan.context[PLAN_CONTEXT_METHOD] = "auth_webauthn_pwl"
 | 
					            self.executor.plan.context[PLAN_CONTEXT_METHOD] = "auth_webauthn_pwl"
 | 
				
			||||||
            self.executor.plan.context[PLAN_CONTEXT_METHOD_ARGS] = cleanse_dict(
 | 
					            self.executor.plan.context[PLAN_CONTEXT_METHOD_ARGS] = cleanse_dict(
 | 
				
			||||||
@ -319,4 +387,9 @@ class AuthenticatorValidateStageView(ChallengeStageView):
 | 
				
			|||||||
                    }
 | 
					                    }
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
        return self.executor.stage_ok()
 | 
					        return self.set_valid_mfa_cookie(response.device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def cleanup(self):
 | 
				
			||||||
 | 
					        self.request.session.pop(SESSION_KEY_STAGES, None)
 | 
				
			||||||
 | 
					        self.request.session.pop(SESSION_KEY_SELECTED_STAGE, None)
 | 
				
			||||||
 | 
					        self.request.session.pop(SESSION_KEY_DEVICE_CHALLENGES, None)
 | 
				
			||||||
 | 
				
			|||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user