Compare commits
	
		
			2 Commits
		
	
	
		
			version/20
			...
			docs-vmwar
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 7a6d44d0df | |||
| 63c3da169a | 
| @ -1,5 +1,5 @@ | ||||
| [bumpversion] | ||||
| current_version = 2024.10.0-rc1 | ||||
| current_version = 2024.8.3 | ||||
| tag = True | ||||
| commit = True | ||||
| parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? | ||||
|  | ||||
							
								
								
									
										2
									
								
								.github/actions/setup/action.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/actions/setup/action.yml
									
									
									
									
										vendored
									
									
								
							| @ -14,7 +14,7 @@ runs: | ||||
|       run: | | ||||
|         pipx install poetry || true | ||||
|         sudo apt-get update | ||||
|         sudo apt-get install --no-install-recommends -y libpq-dev openssl libxmlsec1-dev pkg-config gettext libkrb5-dev krb5-kdc krb5-user krb5-admin-server | ||||
|         sudo apt-get install --no-install-recommends -y libpq-dev openssl libxmlsec1-dev pkg-config gettext | ||||
|     - name: Setup python and restore poetry | ||||
|       uses: actions/setup-python@v5 | ||||
|       with: | ||||
|  | ||||
							
								
								
									
										1
									
								
								.github/dependabot.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/dependabot.yml
									
									
									
									
										vendored
									
									
								
							| @ -23,6 +23,7 @@ updates: | ||||
|   - package-ecosystem: npm | ||||
|     directories: | ||||
|       - "/web" | ||||
|       - "/tests/wdio" | ||||
|       - "/web/sfe" | ||||
|     schedule: | ||||
|       interval: daily | ||||
|  | ||||
							
								
								
									
										2
									
								
								.github/pull_request_template.md
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/pull_request_template.md
									
									
									
									
										vendored
									
									
								
							| @ -1,7 +1,7 @@ | ||||
| <!-- | ||||
| 👋 Hi there! Welcome. | ||||
|  | ||||
| Please check the Contributing guidelines: https://docs.goauthentik.io/docs/developer-docs/#how-can-i-contribute | ||||
| Please check the Contributing guidelines: https://goauthentik.io/developer-docs/#how-can-i-contribute | ||||
| --> | ||||
|  | ||||
| ## Details | ||||
|  | ||||
							
								
								
									
										2
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							| @ -180,7 +180,7 @@ jobs: | ||||
|         uses: ./.github/actions/setup | ||||
|       - name: Setup e2e env (chrome, etc) | ||||
|         run: | | ||||
|           docker compose -f tests/e2e/docker-compose.yml up -d --quiet-pull | ||||
|           docker compose -f tests/e2e/docker-compose.yml up -d | ||||
|       - id: cache-web | ||||
|         uses: actions/cache@v4 | ||||
|         with: | ||||
|  | ||||
							
								
								
									
										21
									
								
								.github/workflows/ci-web.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										21
									
								
								.github/workflows/ci-web.yml
									
									
									
									
										vendored
									
									
								
							| @ -24,11 +24,17 @@ jobs: | ||||
|           - prettier-check | ||||
|         project: | ||||
|           - web | ||||
|           - tests/wdio | ||||
|         include: | ||||
|           - command: tsc | ||||
|             project: web | ||||
|           - command: lit-analyse | ||||
|             project: web | ||||
|         exclude: | ||||
|           - command: lint:lockfile | ||||
|             project: tests/wdio | ||||
|           - command: tsc | ||||
|             project: tests/wdio | ||||
|     steps: | ||||
|       - uses: actions/checkout@v4 | ||||
|       - uses: actions/setup-node@v4 | ||||
| @ -44,7 +50,15 @@ jobs: | ||||
|       - name: Lint | ||||
|         working-directory: ${{ matrix.project }}/ | ||||
|         run: npm run ${{ matrix.command }} | ||||
|   ci-web-mark: | ||||
|     needs: | ||||
|       - lint | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
|       - run: echo mark | ||||
|   build: | ||||
|     needs: | ||||
|       - ci-web-mark | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
|       - uses: actions/checkout@v4 | ||||
| @ -60,13 +74,6 @@ jobs: | ||||
|       - name: build | ||||
|         working-directory: web/ | ||||
|         run: npm run build | ||||
|   ci-web-mark: | ||||
|     needs: | ||||
|       - build | ||||
|       - lint | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
|       - run: echo mark | ||||
|   test: | ||||
|     needs: | ||||
|       - ci-web-mark | ||||
|  | ||||
							
								
								
									
										1
									
								
								.vscode/settings.json
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.vscode/settings.json
									
									
									
									
										vendored
									
									
								
							| @ -6,7 +6,6 @@ | ||||
|         "authn", | ||||
|         "entra", | ||||
|         "goauthentik", | ||||
|         "jwe", | ||||
|         "jwks", | ||||
|         "kubernetes", | ||||
|         "oidc", | ||||
|  | ||||
| @ -94,7 +94,7 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \ | ||||
|     /bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0" | ||||
|  | ||||
| # Stage 5: Python dependencies | ||||
| FROM ghcr.io/goauthentik/fips-python:3.12.7-slim-bookworm-fips-full AS python-deps | ||||
| FROM ghcr.io/goauthentik/fips-python:3.12.6-slim-bookworm-fips-full AS python-deps | ||||
|  | ||||
| ARG TARGETARCH | ||||
| ARG TARGETVARIANT | ||||
| @ -110,7 +110,7 @@ RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloa | ||||
| RUN --mount=type=cache,id=apt-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/var/cache/apt \ | ||||
|     apt-get update && \ | ||||
|     # Required for installing pip packages | ||||
|     apt-get install -y --no-install-recommends build-essential pkg-config libpq-dev libkrb5-dev | ||||
|     apt-get install -y --no-install-recommends build-essential pkg-config libpq-dev | ||||
|  | ||||
| RUN --mount=type=bind,target=./pyproject.toml,src=./pyproject.toml \ | ||||
|     --mount=type=bind,target=./poetry.lock,src=./poetry.lock \ | ||||
| @ -124,7 +124,7 @@ RUN --mount=type=bind,target=./pyproject.toml,src=./pyproject.toml \ | ||||
|     pip install --force-reinstall /wheels/*" | ||||
|  | ||||
| # Stage 6: Run | ||||
| FROM ghcr.io/goauthentik/fips-python:3.12.7-slim-bookworm-fips-full AS final-image | ||||
| FROM ghcr.io/goauthentik/fips-python:3.12.6-slim-bookworm-fips-full AS final-image | ||||
|  | ||||
| ARG VERSION | ||||
| ARG GIT_BUILD_HASH | ||||
| @ -141,7 +141,7 @@ WORKDIR / | ||||
| # We cannot cache this layer otherwise we'll end up with a bigger image | ||||
| RUN apt-get update && \ | ||||
|     # Required for runtime | ||||
|     apt-get install -y --no-install-recommends libpq5 libmaxminddb0 ca-certificates libkrb5-3 libkadm5clnt-mit12 libkdb5-10 && \ | ||||
|     apt-get install -y --no-install-recommends libpq5 libmaxminddb0 ca-certificates && \ | ||||
|     # Required for bootstrap & healtcheck | ||||
|     apt-get install -y --no-install-recommends runit && \ | ||||
|     apt-get clean && \ | ||||
| @ -161,7 +161,6 @@ COPY ./tests /tests | ||||
| COPY ./manage.py / | ||||
| COPY ./blueprints /blueprints | ||||
| COPY ./lifecycle/ /lifecycle | ||||
| COPY ./authentik/sources/kerberos/krb5.conf /etc/krb5.conf | ||||
| COPY --from=go-builder /go/authentik /bin/authentik | ||||
| COPY --from=python-deps /ak-root/venv /ak-root/venv | ||||
| COPY --from=web-builder /work/web/dist/ /web/dist/ | ||||
|  | ||||
							
								
								
									
										3
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								Makefile
									
									
									
									
									
								
							| @ -19,13 +19,14 @@ pg_name := $(shell python -m authentik.lib.config postgresql.name 2>/dev/null) | ||||
| CODESPELL_ARGS = -D - -D .github/codespell-dictionary.txt \ | ||||
| 		-I .github/codespell-words.txt \ | ||||
| 		-S 'web/src/locales/**' \ | ||||
| 		-S 'website/docs/developer-docs/api/reference/**' \ | ||||
| 		-S 'website/developer-docs/api/reference/**' \ | ||||
| 		authentik \ | ||||
| 		internal \ | ||||
| 		cmd \ | ||||
| 		web/src \ | ||||
| 		website/src \ | ||||
| 		website/blog \ | ||||
| 		website/developer-docs \ | ||||
| 		website/docs \ | ||||
| 		website/integrations \ | ||||
| 		website/src | ||||
|  | ||||
| @ -34,7 +34,7 @@ For bigger setups, there is a Helm Chart [here](https://github.com/goauthentik/h | ||||
|  | ||||
| ## Development | ||||
|  | ||||
| See [Developer Documentation](https://docs.goauthentik.io/docs/developer-docs/?utm_source=github) | ||||
| See [Developer Documentation](https://goauthentik.io/developer-docs/?utm_source=github) | ||||
|  | ||||
| ## Security | ||||
|  | ||||
|  | ||||
| @ -2,7 +2,7 @@ | ||||
|  | ||||
| from os import environ | ||||
|  | ||||
| __version__ = "2024.10.0" | ||||
| __version__ = "2024.8.3" | ||||
| ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -1,33 +0,0 @@ | ||||
| from rest_framework.permissions import IsAdminUser | ||||
| from rest_framework.viewsets import ReadOnlyModelViewSet | ||||
|  | ||||
| from authentik.admin.models import VersionHistory | ||||
| from authentik.core.api.utils import ModelSerializer | ||||
|  | ||||
|  | ||||
| class VersionHistorySerializer(ModelSerializer): | ||||
|     """VersionHistory Serializer""" | ||||
|  | ||||
|     class Meta: | ||||
|         model = VersionHistory | ||||
|         fields = [ | ||||
|             "id", | ||||
|             "timestamp", | ||||
|             "version", | ||||
|             "build", | ||||
|         ] | ||||
|  | ||||
|  | ||||
| class VersionHistoryViewSet(ReadOnlyModelViewSet): | ||||
|     """VersionHistory Viewset""" | ||||
|  | ||||
|     queryset = VersionHistory.objects.all() | ||||
|     serializer_class = VersionHistorySerializer | ||||
|     permission_classes = [IsAdminUser] | ||||
|     filterset_fields = [ | ||||
|         "version", | ||||
|         "build", | ||||
|     ] | ||||
|     search_fields = ["version", "build"] | ||||
|     ordering = ["-timestamp"] | ||||
|     pagination_class = None | ||||
| @ -1,22 +0,0 @@ | ||||
| """authentik admin models""" | ||||
|  | ||||
| from django.db import models | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
|  | ||||
|  | ||||
| class VersionHistory(models.Model): | ||||
|     id = models.BigAutoField(primary_key=True) | ||||
|     timestamp = models.DateTimeField() | ||||
|     version = models.TextField() | ||||
|     build = models.TextField() | ||||
|  | ||||
|     class Meta: | ||||
|         managed = False | ||||
|         db_table = "authentik_version_history" | ||||
|         ordering = ("-timestamp",) | ||||
|         verbose_name = _("Version history") | ||||
|         verbose_name_plural = _("Version history") | ||||
|         default_permissions = [] | ||||
|  | ||||
|     def __str__(self): | ||||
|         return f"{self.version}.{self.build} ({self.timestamp})" | ||||
| @ -6,7 +6,6 @@ from authentik.admin.api.meta import AppsViewSet, ModelViewSet | ||||
| from authentik.admin.api.metrics import AdministrationMetricsViewSet | ||||
| from authentik.admin.api.system import SystemView | ||||
| from authentik.admin.api.version import VersionView | ||||
| from authentik.admin.api.version_history import VersionHistoryViewSet | ||||
| from authentik.admin.api.workers import WorkerView | ||||
|  | ||||
| api_urlpatterns = [ | ||||
| @ -18,7 +17,6 @@ api_urlpatterns = [ | ||||
|         name="admin_metrics", | ||||
|     ), | ||||
|     path("admin/version/", VersionView.as_view(), name="admin_version"), | ||||
|     ("admin/version/history", VersionHistoryViewSet, "version_history"), | ||||
|     path("admin/workers/", WorkerView.as_view(), name="admin_workers"), | ||||
|     path("admin/system/", SystemView.as_view(), name="admin_system"), | ||||
| ] | ||||
|  | ||||
| @ -51,11 +51,9 @@ class BlueprintInstanceSerializer(ModelSerializer): | ||||
|         context = self.instance.context if self.instance else {} | ||||
|         valid, logs = Importer.from_string(content, context).validate() | ||||
|         if not valid: | ||||
|             text_logs = "\n".join([x["event"] for x in logs]) | ||||
|             raise ValidationError( | ||||
|                 [ | ||||
|                     _("Failed to validate blueprint"), | ||||
|                     *[f"- {x.event}" for x in logs], | ||||
|                 ] | ||||
|                 _("Failed to validate blueprint: {logs}".format_map({"logs": text_logs})) | ||||
|             ) | ||||
|         return content | ||||
|  | ||||
|  | ||||
| @ -29,7 +29,9 @@ def check_blueprint_v1_file(BlueprintInstance: type, db_alias, path: Path): | ||||
|         if version != 1: | ||||
|             return | ||||
|         blueprint_file.seek(0) | ||||
|     instance = BlueprintInstance.objects.using(db_alias).filter(path=path).first() | ||||
|     instance: BlueprintInstance = ( | ||||
|         BlueprintInstance.objects.using(db_alias).filter(path=path).first() | ||||
|     ) | ||||
|     rel_path = path.relative_to(Path(CONFIG.get("blueprints_dir"))) | ||||
|     meta = None | ||||
|     if metadata: | ||||
|  | ||||
| @ -78,5 +78,5 @@ class TestBlueprintsV1API(APITestCase): | ||||
|         self.assertEqual(res.status_code, 400) | ||||
|         self.assertJSONEqual( | ||||
|             res.content.decode(), | ||||
|             {"content": ["Failed to validate blueprint", "- Invalid blueprint version"]}, | ||||
|             {"content": ["Failed to validate blueprint: Invalid blueprint version"]}, | ||||
|         ) | ||||
|  | ||||
| @ -51,10 +51,6 @@ from authentik.enterprise.providers.microsoft_entra.models import ( | ||||
|     MicrosoftEntraProviderUser, | ||||
| ) | ||||
| from authentik.enterprise.providers.rac.models import ConnectionToken | ||||
| from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import ( | ||||
|     EndpointDevice, | ||||
|     EndpointDeviceConnection, | ||||
| ) | ||||
| from authentik.events.logs import LogEvent, capture_logs | ||||
| from authentik.events.models import SystemTask | ||||
| from authentik.events.utils import cleanse_dict | ||||
| @ -73,7 +69,7 @@ from authentik.stages.authenticator_webauthn.models import WebAuthnDeviceType | ||||
| from authentik.tenants.models import Tenant | ||||
|  | ||||
| # Context set when the serializer is created in a blueprint context | ||||
| # Update website/docs/customize/blueprints/v1/models.md when used | ||||
| # Update website/developer-docs/blueprints/v1/models.md when used | ||||
| SERIALIZER_CONTEXT_BLUEPRINT = "blueprint_entry" | ||||
|  | ||||
|  | ||||
| @ -123,8 +119,6 @@ def excluded_models() -> list[type[Model]]: | ||||
|         GoogleWorkspaceProviderGroup, | ||||
|         MicrosoftEntraProviderUser, | ||||
|         MicrosoftEntraProviderGroup, | ||||
|         EndpointDevice, | ||||
|         EndpointDeviceConnection, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| @ -435,7 +429,7 @@ class Importer: | ||||
|         orig_import = deepcopy(self._import) | ||||
|         if self._import.version != 1: | ||||
|             self.logger.warning("Invalid blueprint version") | ||||
|             return False, [LogEvent("Invalid blueprint version", log_level="warning", logger=None)] | ||||
|             return False, [{"event": "Invalid blueprint version"}] | ||||
|         with ( | ||||
|             transaction_rollback(), | ||||
|             capture_logs() as logs, | ||||
|  | ||||
| @ -6,45 +6,34 @@ from rest_framework.fields import ( | ||||
|     BooleanField, | ||||
|     CharField, | ||||
|     DateTimeField, | ||||
|     IntegerField, | ||||
|     SerializerMethodField, | ||||
| ) | ||||
| from rest_framework.permissions import IsAuthenticated | ||||
| from rest_framework.permissions import IsAdminUser, IsAuthenticated | ||||
| from rest_framework.request import Request | ||||
| from rest_framework.response import Response | ||||
| from rest_framework.viewsets import ViewSet | ||||
|  | ||||
| from authentik.core.api.utils import MetaNameSerializer | ||||
| from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import EndpointDevice | ||||
| from authentik.rbac.decorators import permission_required | ||||
| from authentik.stages.authenticator import device_classes, devices_for_user | ||||
| from authentik.stages.authenticator.models import Device | ||||
| from authentik.stages.authenticator_webauthn.models import WebAuthnDevice | ||||
|  | ||||
|  | ||||
| class DeviceSerializer(MetaNameSerializer): | ||||
|     """Serializer for Duo authenticator devices""" | ||||
|  | ||||
|     pk = CharField() | ||||
|     pk = IntegerField() | ||||
|     name = CharField() | ||||
|     type = SerializerMethodField() | ||||
|     confirmed = BooleanField() | ||||
|     created = DateTimeField(read_only=True) | ||||
|     last_updated = DateTimeField(read_only=True) | ||||
|     last_used = DateTimeField(read_only=True, allow_null=True) | ||||
|     extra_description = SerializerMethodField() | ||||
|  | ||||
|     def get_type(self, instance: Device) -> str: | ||||
|         """Get type of device""" | ||||
|         return instance._meta.label | ||||
|  | ||||
|     def get_extra_description(self, instance: Device) -> str: | ||||
|         """Get extra description""" | ||||
|         if isinstance(instance, WebAuthnDevice): | ||||
|             return instance.device_type.description | ||||
|         if isinstance(instance, EndpointDevice): | ||||
|             return instance.data.get("deviceSignals", {}).get("deviceModel") | ||||
|         return "" | ||||
|  | ||||
|  | ||||
| class DeviceViewSet(ViewSet): | ||||
|     """Viewset for authenticator devices""" | ||||
| @ -63,7 +52,7 @@ class AdminDeviceViewSet(ViewSet): | ||||
|     """Viewset for authenticator devices""" | ||||
|  | ||||
|     serializer_class = DeviceSerializer | ||||
|     permission_classes = [] | ||||
|     permission_classes = [IsAdminUser] | ||||
|  | ||||
|     def get_devices(self, **kwargs): | ||||
|         """Get all devices in all child classes""" | ||||
| @ -81,10 +70,6 @@ class AdminDeviceViewSet(ViewSet): | ||||
|         ], | ||||
|         responses={200: DeviceSerializer(many=True)}, | ||||
|     ) | ||||
|     @permission_required( | ||||
|         None, | ||||
|         [f"{model._meta.app_label}.view_{model._meta.model_name}" for model in device_classes()], | ||||
|     ) | ||||
|     def list(self, request: Request) -> Response: | ||||
|         """Get all devices for current user""" | ||||
|         kwargs = {} | ||||
|  | ||||
| @ -38,7 +38,6 @@ class ProviderSerializer(ModelSerializer, MetaNameSerializer): | ||||
|             "name", | ||||
|             "authentication_flow", | ||||
|             "authorization_flow", | ||||
|             "invalidation_flow", | ||||
|             "property_mappings", | ||||
|             "component", | ||||
|             "assigned_application_slug", | ||||
| @ -51,7 +50,6 @@ class ProviderSerializer(ModelSerializer, MetaNameSerializer): | ||||
|         ] | ||||
|         extra_kwargs = { | ||||
|             "authorization_flow": {"required": True, "allow_null": False}, | ||||
|             "invalidation_flow": {"required": True, "allow_null": False}, | ||||
|         } | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -679,10 +679,7 @@ class UserViewSet(UsedByMixin, ModelViewSet): | ||||
|             LOGGER.debug("User attempted to impersonate", user=request.user) | ||||
|             return Response(status=401) | ||||
|         user_to_be = self.get_object() | ||||
|         # Check both object-level perms and global perms | ||||
|         if not request.user.has_perm( | ||||
|             "authentik_core.impersonate", user_to_be | ||||
|         ) and not request.user.has_perm("authentik_core.impersonate"): | ||||
|         if not request.user.has_perm("impersonate", user_to_be): | ||||
|             LOGGER.debug("User attempted to impersonate without permissions", user=request.user) | ||||
|             return Response(status=401) | ||||
|         if user_to_be.pk == self.request.user.pk: | ||||
|  | ||||
| @ -4,7 +4,6 @@ import code | ||||
| import platform | ||||
| import sys | ||||
| import traceback | ||||
| from pprint import pprint | ||||
|  | ||||
| from django.apps import apps | ||||
| from django.core.management.base import BaseCommand | ||||
| @ -35,9 +34,7 @@ class Command(BaseCommand): | ||||
|  | ||||
|     def get_namespace(self): | ||||
|         """Prepare namespace with all models""" | ||||
|         namespace = { | ||||
|             "pprint": pprint, | ||||
|         } | ||||
|         namespace = {} | ||||
|  | ||||
|         # Gather Django models and constants from each app | ||||
|         for app in apps.get_app_configs(): | ||||
|  | ||||
| @ -1,55 +0,0 @@ | ||||
| # Generated by Django 5.0.9 on 2024-10-02 11:35 | ||||
|  | ||||
| import django.db.models.deletion | ||||
| from django.db import migrations, models | ||||
|  | ||||
| from django.apps.registry import Apps | ||||
| from django.db import migrations, models | ||||
| from django.db.backends.base.schema import BaseDatabaseSchemaEditor | ||||
|  | ||||
|  | ||||
| def migrate_invalidation_flow_default(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||
|     from authentik.flows.models import FlowDesignation, FlowAuthenticationRequirement | ||||
|  | ||||
|     db_alias = schema_editor.connection.alias | ||||
|  | ||||
|     Flow = apps.get_model("authentik_flows", "Flow") | ||||
|     Provider = apps.get_model("authentik_core", "Provider") | ||||
|  | ||||
|     # So this flow is managed via a blueprint, bue we're in a migration so we don't want to rely on that | ||||
|     # since the blueprint is just an empty flow we can just create it here | ||||
|     # and let it be managed by the blueprint later | ||||
|     flow, _ = Flow.objects.using(db_alias).update_or_create( | ||||
|         slug="default-provider-invalidation-flow", | ||||
|         defaults={ | ||||
|             "name": "Logged out of application", | ||||
|             "title": "You've logged out of %(app)s.", | ||||
|             "authentication": FlowAuthenticationRequirement.NONE, | ||||
|             "designation": FlowDesignation.INVALIDATION, | ||||
|         }, | ||||
|     ) | ||||
|     Provider.objects.using(db_alias).filter(invalidation_flow=None).update(invalidation_flow=flow) | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_core", "0039_source_group_matching_mode_alter_group_name_and_more"), | ||||
|         ("authentik_flows", "0027_auto_20231028_1424"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.AddField( | ||||
|             model_name="provider", | ||||
|             name="invalidation_flow", | ||||
|             field=models.ForeignKey( | ||||
|                 default=None, | ||||
|                 help_text="Flow used ending the session from a provider.", | ||||
|                 null=True, | ||||
|                 on_delete=django.db.models.deletion.SET_DEFAULT, | ||||
|                 related_name="provider_invalidation", | ||||
|                 to="authentik_flows.flow", | ||||
|             ), | ||||
|         ), | ||||
|         migrations.RunPython(migrate_invalidation_flow_default), | ||||
|     ] | ||||
| @ -330,13 +330,11 @@ class User(SerializerModel, GuardianUserMixin, AttributesMixin, AbstractUser): | ||||
|         """superuser == staff user""" | ||||
|         return self.is_superuser  # type: ignore | ||||
|  | ||||
|     def set_password(self, raw_password, signal=True, sender=None): | ||||
|     def set_password(self, raw_password, signal=True): | ||||
|         if self.pk and signal: | ||||
|             from authentik.core.signals import password_changed | ||||
|  | ||||
|             if not sender: | ||||
|                 sender = self | ||||
|             password_changed.send(sender=sender, user=self, password=raw_password) | ||||
|             password_changed.send(sender=self, user=self, password=raw_password) | ||||
|         self.password_change_date = now() | ||||
|         return super().set_password(raw_password) | ||||
|  | ||||
| @ -393,23 +391,14 @@ class Provider(SerializerModel): | ||||
|         ), | ||||
|         related_name="provider_authentication", | ||||
|     ) | ||||
|  | ||||
|     authorization_flow = models.ForeignKey( | ||||
|         "authentik_flows.Flow", | ||||
|         # Set to cascade even though null is allowed, since most providers | ||||
|         # still require an authorization flow set | ||||
|         on_delete=models.CASCADE, | ||||
|         null=True, | ||||
|         help_text=_("Flow used when authorizing this provider."), | ||||
|         related_name="provider_authorization", | ||||
|     ) | ||||
|     invalidation_flow = models.ForeignKey( | ||||
|         "authentik_flows.Flow", | ||||
|         on_delete=models.SET_DEFAULT, | ||||
|         default=None, | ||||
|         null=True, | ||||
|         help_text=_("Flow used ending the session from a provider."), | ||||
|         related_name="provider_invalidation", | ||||
|     ) | ||||
|  | ||||
|     property_mappings = models.ManyToManyField("PropertyMapping", default=None, blank=True) | ||||
|  | ||||
|  | ||||
| @ -1,9 +1,11 @@ | ||||
| """Source decision helper""" | ||||
|  | ||||
| from enum import Enum | ||||
| from typing import Any | ||||
|  | ||||
| from django.contrib import messages | ||||
| from django.db import IntegrityError, transaction | ||||
| from django.db.models.query_utils import Q | ||||
| from django.http import HttpRequest, HttpResponse | ||||
| from django.shortcuts import redirect | ||||
| from django.urls import reverse | ||||
| @ -14,11 +16,12 @@ from authentik.core.models import ( | ||||
|     Group, | ||||
|     GroupSourceConnection, | ||||
|     Source, | ||||
|     SourceGroupMatchingModes, | ||||
|     SourceUserMatchingModes, | ||||
|     User, | ||||
|     UserSourceConnection, | ||||
| ) | ||||
| from authentik.core.sources.mapper import SourceMapper | ||||
| from authentik.core.sources.matcher import Action, SourceMatcher | ||||
| from authentik.core.sources.stage import ( | ||||
|     PLAN_CONTEXT_SOURCES_CONNECTION, | ||||
|     PostSourceStage, | ||||
| @ -51,6 +54,16 @@ SESSION_KEY_OVERRIDE_FLOW_TOKEN = "authentik/flows/source_override_flow_token" | ||||
| PLAN_CONTEXT_SOURCE_GROUPS = "source_groups" | ||||
|  | ||||
|  | ||||
| class Action(Enum): | ||||
|     """Actions that can be decided based on the request | ||||
|     and source settings""" | ||||
|  | ||||
|     LINK = "link" | ||||
|     AUTH = "auth" | ||||
|     ENROLL = "enroll" | ||||
|     DENY = "deny" | ||||
|  | ||||
|  | ||||
| class MessageStage(StageView): | ||||
|     """Show a pre-configured message after the flow is done""" | ||||
|  | ||||
| @ -73,7 +86,6 @@ class SourceFlowManager: | ||||
|  | ||||
|     source: Source | ||||
|     mapper: SourceMapper | ||||
|     matcher: SourceMatcher | ||||
|     request: HttpRequest | ||||
|  | ||||
|     identifier: str | ||||
| @ -96,9 +108,6 @@ class SourceFlowManager: | ||||
|     ) -> None: | ||||
|         self.source = source | ||||
|         self.mapper = SourceMapper(self.source) | ||||
|         self.matcher = SourceMatcher( | ||||
|             self.source, self.user_connection_type, self.group_connection_type | ||||
|         ) | ||||
|         self.request = request | ||||
|         self.identifier = identifier | ||||
|         self.user_info = user_info | ||||
| @ -122,19 +131,66 @@ class SourceFlowManager: | ||||
|  | ||||
|     def get_action(self, **kwargs) -> tuple[Action, UserSourceConnection | None]:  # noqa: PLR0911 | ||||
|         """decide which action should be taken""" | ||||
|         new_connection = self.user_connection_type(source=self.source, identifier=self.identifier) | ||||
|         # When request is authenticated, always link | ||||
|         if self.request.user.is_authenticated: | ||||
|             new_connection = self.user_connection_type( | ||||
|                 source=self.source, identifier=self.identifier | ||||
|             ) | ||||
|             new_connection.user = self.request.user | ||||
|             new_connection = self.update_user_connection(new_connection, **kwargs) | ||||
|             return Action.LINK, new_connection | ||||
|  | ||||
|         action, connection = self.matcher.get_user_action(self.identifier, self.user_properties) | ||||
|         if connection: | ||||
|             connection = self.update_user_connection(connection, **kwargs) | ||||
|         return action, connection | ||||
|         existing_connections = self.user_connection_type.objects.filter( | ||||
|             source=self.source, identifier=self.identifier | ||||
|         ) | ||||
|         if existing_connections.exists(): | ||||
|             connection = existing_connections.first() | ||||
|             return Action.AUTH, self.update_user_connection(connection, **kwargs) | ||||
|         # No connection exists, but we match on identifier, so enroll | ||||
|         if self.source.user_matching_mode == SourceUserMatchingModes.IDENTIFIER: | ||||
|             # We don't save the connection here cause it doesn't have a user assigned yet | ||||
|             return Action.ENROLL, self.update_user_connection(new_connection, **kwargs) | ||||
|  | ||||
|         # Check for existing users with matching attributes | ||||
|         query = Q() | ||||
|         # Either query existing user based on email or username | ||||
|         if self.source.user_matching_mode in [ | ||||
|             SourceUserMatchingModes.EMAIL_LINK, | ||||
|             SourceUserMatchingModes.EMAIL_DENY, | ||||
|         ]: | ||||
|             if not self.user_properties.get("email", None): | ||||
|                 self._logger.warning("Refusing to use none email") | ||||
|                 return Action.DENY, None | ||||
|             query = Q(email__exact=self.user_properties.get("email", None)) | ||||
|         if self.source.user_matching_mode in [ | ||||
|             SourceUserMatchingModes.USERNAME_LINK, | ||||
|             SourceUserMatchingModes.USERNAME_DENY, | ||||
|         ]: | ||||
|             if not self.user_properties.get("username", None): | ||||
|                 self._logger.warning("Refusing to use none username") | ||||
|                 return Action.DENY, None | ||||
|             query = Q(username__exact=self.user_properties.get("username", None)) | ||||
|         self._logger.debug("trying to link with existing user", query=query) | ||||
|         matching_users = User.objects.filter(query) | ||||
|         # No matching users, always enroll | ||||
|         if not matching_users.exists(): | ||||
|             self._logger.debug("no matching users found, enrolling") | ||||
|             return Action.ENROLL, self.update_user_connection(new_connection, **kwargs) | ||||
|  | ||||
|         user = matching_users.first() | ||||
|         if self.source.user_matching_mode in [ | ||||
|             SourceUserMatchingModes.EMAIL_LINK, | ||||
|             SourceUserMatchingModes.USERNAME_LINK, | ||||
|         ]: | ||||
|             new_connection.user = user | ||||
|             new_connection = self.update_user_connection(new_connection, **kwargs) | ||||
|             return Action.LINK, new_connection | ||||
|         if self.source.user_matching_mode in [ | ||||
|             SourceUserMatchingModes.EMAIL_DENY, | ||||
|             SourceUserMatchingModes.USERNAME_DENY, | ||||
|         ]: | ||||
|             self._logger.info("denying source because user exists", user=user) | ||||
|             return Action.DENY, None | ||||
|         # Should never get here as default enroll case is returned above. | ||||
|         return Action.DENY, None  # pragma: no cover | ||||
|  | ||||
|     def update_user_connection( | ||||
|         self, connection: UserSourceConnection, **kwargs | ||||
| @ -272,6 +328,7 @@ class SourceFlowManager: | ||||
|         connection: UserSourceConnection, | ||||
|     ) -> HttpResponse: | ||||
|         """Login user and redirect.""" | ||||
|         flow_kwargs = {PLAN_CONTEXT_PENDING_USER: connection.user} | ||||
|         return self._prepare_flow( | ||||
|             self.source.authentication_flow, | ||||
|             connection, | ||||
| @ -285,11 +342,7 @@ class SourceFlowManager: | ||||
|                     ), | ||||
|                 ) | ||||
|             ], | ||||
|             **{ | ||||
|                 PLAN_CONTEXT_PENDING_USER: connection.user, | ||||
|                 PLAN_CONTEXT_PROMPT: delete_none_values(self.user_properties), | ||||
|                 PLAN_CONTEXT_USER_PATH: self.source.get_user_path(), | ||||
|             }, | ||||
|             **flow_kwargs, | ||||
|         ) | ||||
|  | ||||
|     def handle_existing_link( | ||||
| @ -355,16 +408,74 @@ class SourceFlowManager: | ||||
| class GroupUpdateStage(StageView): | ||||
|     """Dynamically injected stage which updates the user after enrollment/authentication.""" | ||||
|  | ||||
|     def get_action( | ||||
|         self, group_id: str, group_properties: dict[str, Any | dict[str, Any]] | ||||
|     ) -> tuple[Action, GroupSourceConnection | None]: | ||||
|         """decide which action should be taken""" | ||||
|         new_connection = self.group_connection_type(source=self.source, identifier=group_id) | ||||
|  | ||||
|         existing_connections = self.group_connection_type.objects.filter( | ||||
|             source=self.source, identifier=group_id | ||||
|         ) | ||||
|         if existing_connections.exists(): | ||||
|             return Action.LINK, existing_connections.first() | ||||
|         # No connection exists, but we match on identifier, so enroll | ||||
|         if self.source.group_matching_mode == SourceGroupMatchingModes.IDENTIFIER: | ||||
|             # We don't save the connection here cause it doesn't have a user assigned yet | ||||
|             return Action.ENROLL, new_connection | ||||
|  | ||||
|         # Check for existing groups with matching attributes | ||||
|         query = Q() | ||||
|         if self.source.group_matching_mode in [ | ||||
|             SourceGroupMatchingModes.NAME_LINK, | ||||
|             SourceGroupMatchingModes.NAME_DENY, | ||||
|         ]: | ||||
|             if not group_properties.get("name", None): | ||||
|                 LOGGER.warning( | ||||
|                     "Refusing to use none group name", source=self.source, group_id=group_id | ||||
|                 ) | ||||
|                 return Action.DENY, None | ||||
|             query = Q(name__exact=group_properties.get("name")) | ||||
|         LOGGER.debug( | ||||
|             "trying to link with existing group", source=self.source, query=query, group_id=group_id | ||||
|         ) | ||||
|         matching_groups = Group.objects.filter(query) | ||||
|         # No matching groups, always enroll | ||||
|         if not matching_groups.exists(): | ||||
|             LOGGER.debug( | ||||
|                 "no matching groups found, enrolling", source=self.source, group_id=group_id | ||||
|             ) | ||||
|             return Action.ENROLL, new_connection | ||||
|  | ||||
|         group = matching_groups.first() | ||||
|         if self.source.group_matching_mode in [ | ||||
|             SourceGroupMatchingModes.NAME_LINK, | ||||
|         ]: | ||||
|             new_connection.group = group | ||||
|             return Action.LINK, new_connection | ||||
|         if self.source.group_matching_mode in [ | ||||
|             SourceGroupMatchingModes.NAME_DENY, | ||||
|         ]: | ||||
|             LOGGER.info( | ||||
|                 "denying source because group exists", | ||||
|                 source=self.source, | ||||
|                 group=group, | ||||
|                 group_id=group_id, | ||||
|             ) | ||||
|             return Action.DENY, None | ||||
|         # Should never get here as default enroll case is returned above. | ||||
|         return Action.DENY, None  # pragma: no cover | ||||
|  | ||||
|     def handle_group( | ||||
|         self, group_id: str, group_properties: dict[str, Any | dict[str, Any]] | ||||
|     ) -> Group | None: | ||||
|         action, connection = self.matcher.get_group_action(group_id, group_properties) | ||||
|         action, connection = self.get_action(group_id, group_properties) | ||||
|         if action == Action.ENROLL: | ||||
|             group = Group.objects.create(**group_properties) | ||||
|             connection.group = group | ||||
|             connection.save() | ||||
|             return group | ||||
|         elif action in (Action.LINK, Action.AUTH): | ||||
|         elif action == Action.LINK: | ||||
|             group = connection.group | ||||
|             group.update_attributes(group_properties) | ||||
|             connection.save() | ||||
| @ -378,7 +489,6 @@ class GroupUpdateStage(StageView): | ||||
|         self.group_connection_type: GroupSourceConnection = ( | ||||
|             self.executor.current_stage.group_connection_type | ||||
|         ) | ||||
|         self.matcher = SourceMatcher(self.source, None, self.group_connection_type) | ||||
|  | ||||
|         raw_groups: dict[str, dict[str, Any | dict[str, Any]]] = self.executor.plan.context[ | ||||
|             PLAN_CONTEXT_SOURCE_GROUPS | ||||
|  | ||||
| @ -1,152 +0,0 @@ | ||||
| """Source user and group matching""" | ||||
|  | ||||
| from dataclasses import dataclass | ||||
| from enum import Enum | ||||
| from typing import Any | ||||
|  | ||||
| from django.db.models import Q | ||||
| from structlog import get_logger | ||||
|  | ||||
| from authentik.core.models import ( | ||||
|     Group, | ||||
|     GroupSourceConnection, | ||||
|     Source, | ||||
|     SourceGroupMatchingModes, | ||||
|     SourceUserMatchingModes, | ||||
|     User, | ||||
|     UserSourceConnection, | ||||
| ) | ||||
|  | ||||
|  | ||||
| class Action(Enum): | ||||
|     """Actions that can be decided based on the request and source settings""" | ||||
|  | ||||
|     LINK = "link" | ||||
|     AUTH = "auth" | ||||
|     ENROLL = "enroll" | ||||
|     DENY = "deny" | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class MatchableProperty: | ||||
|     property: str | ||||
|     link_mode: SourceUserMatchingModes | SourceGroupMatchingModes | ||||
|     deny_mode: SourceUserMatchingModes | SourceGroupMatchingModes | ||||
|  | ||||
|  | ||||
| class SourceMatcher: | ||||
|     def __init__( | ||||
|         self, | ||||
|         source: Source, | ||||
|         user_connection_type: type[UserSourceConnection], | ||||
|         group_connection_type: type[GroupSourceConnection], | ||||
|     ): | ||||
|         self.source = source | ||||
|         self.user_connection_type = user_connection_type | ||||
|         self.group_connection_type = group_connection_type | ||||
|         self._logger = get_logger().bind(source=self.source) | ||||
|  | ||||
|     def get_action( | ||||
|         self, | ||||
|         object_type: type[User | Group], | ||||
|         matchable_properties: list[MatchableProperty], | ||||
|         identifier: str, | ||||
|         properties: dict[str, Any | dict[str, Any]], | ||||
|     ) -> tuple[Action, UserSourceConnection | GroupSourceConnection | None]: | ||||
|         connection_type = None | ||||
|         matching_mode = None | ||||
|         identifier_matching_mode = None | ||||
|         if object_type == User: | ||||
|             connection_type = self.user_connection_type | ||||
|             matching_mode = self.source.user_matching_mode | ||||
|             identifier_matching_mode = SourceUserMatchingModes.IDENTIFIER | ||||
|         if object_type == Group: | ||||
|             connection_type = self.group_connection_type | ||||
|             matching_mode = self.source.group_matching_mode | ||||
|             identifier_matching_mode = SourceGroupMatchingModes.IDENTIFIER | ||||
|         if not connection_type or not matching_mode or not identifier_matching_mode: | ||||
|             return Action.DENY, None | ||||
|  | ||||
|         new_connection = connection_type(source=self.source, identifier=identifier) | ||||
|  | ||||
|         existing_connections = connection_type.objects.filter( | ||||
|             source=self.source, identifier=identifier | ||||
|         ) | ||||
|         if existing_connections.exists(): | ||||
|             return Action.AUTH, existing_connections.first() | ||||
|         # No connection exists, but we match on identifier, so enroll | ||||
|         if matching_mode == identifier_matching_mode: | ||||
|             # We don't save the connection here cause it doesn't have a user/group assigned yet | ||||
|             return Action.ENROLL, new_connection | ||||
|  | ||||
|         # Check for existing users with matching attributes | ||||
|         query = Q() | ||||
|         for matchable_property in matchable_properties: | ||||
|             property = matchable_property.property | ||||
|             if matching_mode in [matchable_property.link_mode, matchable_property.deny_mode]: | ||||
|                 if not properties.get(property, None): | ||||
|                     self._logger.warning( | ||||
|                         "Refusing to use none property", identifier=identifier, property=property | ||||
|                     ) | ||||
|                     return Action.DENY, None | ||||
|                 query_args = { | ||||
|                     f"{property}__exact": properties[property], | ||||
|                 } | ||||
|                 query = Q(**query_args) | ||||
|         self._logger.debug( | ||||
|             "Trying to link with existing object", query=query, identifier=identifier | ||||
|         ) | ||||
|         matching_objects = object_type.objects.filter(query) | ||||
|         # Not matching objects, always enroll | ||||
|         if not matching_objects.exists(): | ||||
|             self._logger.debug("No matching objects found, enrolling") | ||||
|             return Action.ENROLL, new_connection | ||||
|  | ||||
|         obj = matching_objects.first() | ||||
|         if matching_mode in [mp.link_mode for mp in matchable_properties]: | ||||
|             attr = None | ||||
|             if object_type == User: | ||||
|                 attr = "user" | ||||
|             if object_type == Group: | ||||
|                 attr = "group" | ||||
|             setattr(new_connection, attr, obj) | ||||
|             return Action.LINK, new_connection | ||||
|         if matching_mode in [mp.deny_mode for mp in matchable_properties]: | ||||
|             self._logger.info("Denying source because object exists", obj=obj) | ||||
|             return Action.DENY, None | ||||
|  | ||||
|         # Should never get here as default enroll case is returned above. | ||||
|         return Action.DENY, None  # pragma: no cover | ||||
|  | ||||
|     def get_user_action( | ||||
|         self, identifier: str, properties: dict[str, Any | dict[str, Any]] | ||||
|     ) -> tuple[Action, UserSourceConnection | None]: | ||||
|         return self.get_action( | ||||
|             User, | ||||
|             [ | ||||
|                 MatchableProperty( | ||||
|                     "username", | ||||
|                     SourceUserMatchingModes.USERNAME_LINK, | ||||
|                     SourceUserMatchingModes.USERNAME_DENY, | ||||
|                 ), | ||||
|                 MatchableProperty( | ||||
|                     "email", SourceUserMatchingModes.EMAIL_LINK, SourceUserMatchingModes.EMAIL_DENY | ||||
|                 ), | ||||
|             ], | ||||
|             identifier, | ||||
|             properties, | ||||
|         ) | ||||
|  | ||||
|     def get_group_action( | ||||
|         self, identifier: str, properties: dict[str, Any | dict[str, Any]] | ||||
|     ) -> tuple[Action, GroupSourceConnection | None]: | ||||
|         return self.get_action( | ||||
|             Group, | ||||
|             [ | ||||
|                 MatchableProperty( | ||||
|                     "name", SourceGroupMatchingModes.NAME_LINK, SourceGroupMatchingModes.NAME_DENY | ||||
|                 ), | ||||
|             ], | ||||
|             identifier, | ||||
|             properties, | ||||
|         ) | ||||
							
								
								
									
										43
									
								
								authentik/core/templates/if/end_session.html
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								authentik/core/templates/if/end_session.html
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,43 @@ | ||||
| {% extends 'login/base_full.html' %} | ||||
|  | ||||
| {% load static %} | ||||
| {% load i18n %} | ||||
|  | ||||
| {% block title %} | ||||
| {% trans 'End session' %} - {{ brand.branding_title }} | ||||
| {% endblock %} | ||||
|  | ||||
| {% block card_title %} | ||||
| {% blocktrans with application=application.name %} | ||||
| You've logged out of {{ application }}. | ||||
| {% endblocktrans %} | ||||
| {% endblock %} | ||||
|  | ||||
| {% block card %} | ||||
| <form method="POST" class="pf-c-form"> | ||||
|     <p> | ||||
|         {% blocktrans with application=application.name branding_title=brand.branding_title %} | ||||
|             You've logged out of {{ application }}. You can go back to the overview to launch another application, or log out of your {{ branding_title }} account. | ||||
|         {% endblocktrans %} | ||||
|     </p> | ||||
|  | ||||
|     <a id="ak-back-home" href="{% url 'authentik_core:root-redirect' %}" class="pf-c-button pf-m-primary"> | ||||
|         {% trans 'Go back to overview' %} | ||||
|     </a> | ||||
|  | ||||
|     <a id="logout" href="{% url 'authentik_flows:default-invalidation' %}" class="pf-c-button pf-m-secondary"> | ||||
|         {% blocktrans with branding_title=brand.branding_title %} | ||||
|             Log out of {{ branding_title }} | ||||
|         {% endblocktrans %} | ||||
|     </a> | ||||
|  | ||||
|     {% if application.get_launch_url %} | ||||
|     <a href="{{ application.get_launch_url }}" class="pf-c-button pf-m-secondary"> | ||||
|         {% blocktrans with application=application.name %} | ||||
|             Log back into {{ application }} | ||||
|         {% endblocktrans %} | ||||
|     </a> | ||||
|     {% endif %} | ||||
|  | ||||
| </form> | ||||
| {% endblock %} | ||||
| @ -134,7 +134,6 @@ class TestApplicationsAPI(APITestCase): | ||||
|                             "assigned_application_name": "allowed", | ||||
|                             "assigned_application_slug": "allowed", | ||||
|                             "authentication_flow": None, | ||||
|                             "invalidation_flow": None, | ||||
|                             "authorization_flow": str(self.provider.authorization_flow.pk), | ||||
|                             "component": "ak-provider-oauth2-form", | ||||
|                             "meta_model_name": "authentik_providers_oauth2.oauth2provider", | ||||
| @ -187,7 +186,6 @@ class TestApplicationsAPI(APITestCase): | ||||
|                             "assigned_application_name": "allowed", | ||||
|                             "assigned_application_slug": "allowed", | ||||
|                             "authentication_flow": None, | ||||
|                             "invalidation_flow": None, | ||||
|                             "authorization_flow": str(self.provider.authorization_flow.pk), | ||||
|                             "component": "ak-provider-oauth2-form", | ||||
|                             "meta_model_name": "authentik_providers_oauth2.oauth2provider", | ||||
|  | ||||
| @ -1,59 +0,0 @@ | ||||
| """Test Devices API""" | ||||
|  | ||||
| from json import loads | ||||
|  | ||||
| from django.urls import reverse | ||||
| from rest_framework.test import APITestCase | ||||
|  | ||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_user | ||||
|  | ||||
|  | ||||
| class TestDevicesAPI(APITestCase): | ||||
|     """Test applications API""" | ||||
|  | ||||
|     def setUp(self) -> None: | ||||
|         self.admin = create_test_admin_user() | ||||
|         self.user1 = create_test_user() | ||||
|         self.device1 = self.user1.staticdevice_set.create() | ||||
|         self.user2 = create_test_user() | ||||
|         self.device2 = self.user2.staticdevice_set.create() | ||||
|  | ||||
|     def test_user_api(self): | ||||
|         """Test user API""" | ||||
|         self.client.force_login(self.user1) | ||||
|         response = self.client.get( | ||||
|             reverse( | ||||
|                 "authentik_api:device-list", | ||||
|             ) | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         body = loads(response.content.decode()) | ||||
|         self.assertEqual(len(body), 1) | ||||
|         self.assertEqual(body[0]["pk"], str(self.device1.pk)) | ||||
|  | ||||
|     def test_user_api_as_admin(self): | ||||
|         """Test user API""" | ||||
|         self.client.force_login(self.admin) | ||||
|         response = self.client.get( | ||||
|             reverse( | ||||
|                 "authentik_api:device-list", | ||||
|             ) | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         body = loads(response.content.decode()) | ||||
|         self.assertEqual(len(body), 0) | ||||
|  | ||||
|     def test_admin_api(self): | ||||
|         """Test admin API""" | ||||
|         self.client.force_login(self.admin) | ||||
|         response = self.client.get( | ||||
|             reverse( | ||||
|                 "authentik_api:admin-device-list", | ||||
|             ) | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         body = loads(response.content.decode()) | ||||
|         self.assertEqual(len(body), 2) | ||||
|         self.assertEqual( | ||||
|             {body[0]["pk"], body[1]["pk"]}, {str(self.device1.pk), str(self.device2.pk)} | ||||
|         ) | ||||
| @ -44,26 +44,6 @@ class TestImpersonation(APITestCase): | ||||
|         self.assertEqual(response_body["user"]["username"], self.user.username) | ||||
|         self.assertNotIn("original", response_body) | ||||
|  | ||||
|     def test_impersonate_global(self): | ||||
|         """Test impersonation with global permissions""" | ||||
|         new_user = create_test_user() | ||||
|         assign_perm("authentik_core.impersonate", new_user) | ||||
|         assign_perm("authentik_core.view_user", new_user) | ||||
|         self.client.force_login(new_user) | ||||
|  | ||||
|         response = self.client.post( | ||||
|             reverse( | ||||
|                 "authentik_api:user-impersonate", | ||||
|                 kwargs={"pk": self.other_user.pk}, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 201) | ||||
|  | ||||
|         response = self.client.get(reverse("authentik_api:user-me")) | ||||
|         response_body = loads(response.content.decode()) | ||||
|         self.assertEqual(response_body["user"]["username"], self.other_user.username) | ||||
|         self.assertEqual(response_body["original"]["username"], new_user.username) | ||||
|  | ||||
|     def test_impersonate_scoped(self): | ||||
|         """Test impersonation with scoped permissions""" | ||||
|         new_user = create_test_user() | ||||
|  | ||||
| @ -19,6 +19,7 @@ class TestTransactionalApplicationsAPI(APITestCase): | ||||
|         """Test transactional Application + provider creation""" | ||||
|         self.client.force_login(self.user) | ||||
|         uid = generate_id() | ||||
|         authorization_flow = create_test_flow() | ||||
|         response = self.client.put( | ||||
|             reverse("authentik_api:core-transactional-application"), | ||||
|             data={ | ||||
| @ -29,8 +30,7 @@ class TestTransactionalApplicationsAPI(APITestCase): | ||||
|                 "provider_model": "authentik_providers_oauth2.oauth2provider", | ||||
|                 "provider": { | ||||
|                     "name": uid, | ||||
|                     "authorization_flow": str(create_test_flow().pk), | ||||
|                     "invalidation_flow": str(create_test_flow().pk), | ||||
|                     "authorization_flow": str(authorization_flow.pk), | ||||
|                 }, | ||||
|             }, | ||||
|         ) | ||||
| @ -56,16 +56,10 @@ class TestTransactionalApplicationsAPI(APITestCase): | ||||
|                 "provider": { | ||||
|                     "name": uid, | ||||
|                     "authorization_flow": "", | ||||
|                     "invalidation_flow": "", | ||||
|                 }, | ||||
|             }, | ||||
|         ) | ||||
|         self.assertJSONEqual( | ||||
|             response.content.decode(), | ||||
|             { | ||||
|                 "provider": { | ||||
|                     "authorization_flow": ["This field may not be null."], | ||||
|                     "invalidation_flow": ["This field may not be null."], | ||||
|                 } | ||||
|             }, | ||||
|             {"provider": {"authorization_flow": ["This field may not be null."]}}, | ||||
|         ) | ||||
|  | ||||
| @ -5,6 +5,7 @@ from channels.sessions import CookieMiddleware | ||||
| from django.conf import settings | ||||
| from django.contrib.auth.decorators import login_required | ||||
| from django.urls import path | ||||
| from django.views.decorators.csrf import ensure_csrf_cookie | ||||
|  | ||||
| from authentik.core.api.applications import ApplicationViewSet | ||||
| from authentik.core.api.authenticated_sessions import AuthenticatedSessionViewSet | ||||
| @ -23,6 +24,7 @@ from authentik.core.views.interface import ( | ||||
|     InterfaceView, | ||||
|     RootRedirectView, | ||||
| ) | ||||
| from authentik.core.views.session import EndSessionView | ||||
| from authentik.flows.views.interface import FlowInterfaceView | ||||
| from authentik.root.asgi_middleware import SessionMiddleware | ||||
| from authentik.root.messages.consumer import MessageConsumer | ||||
| @ -43,21 +45,26 @@ urlpatterns = [ | ||||
|     # Interfaces | ||||
|     path( | ||||
|         "if/admin/", | ||||
|         BrandDefaultRedirectView.as_view(template_name="if/admin.html"), | ||||
|         ensure_csrf_cookie(BrandDefaultRedirectView.as_view(template_name="if/admin.html")), | ||||
|         name="if-admin", | ||||
|     ), | ||||
|     path( | ||||
|         "if/user/", | ||||
|         BrandDefaultRedirectView.as_view(template_name="if/user.html"), | ||||
|         ensure_csrf_cookie(BrandDefaultRedirectView.as_view(template_name="if/user.html")), | ||||
|         name="if-user", | ||||
|     ), | ||||
|     path( | ||||
|         "if/flow/<slug:flow_slug>/", | ||||
|         # FIXME: move this url to the flows app...also will cause all | ||||
|         # of the reverse calls to be adjusted | ||||
|         FlowInterfaceView.as_view(), | ||||
|         ensure_csrf_cookie(FlowInterfaceView.as_view()), | ||||
|         name="if-flow", | ||||
|     ), | ||||
|     path( | ||||
|         "if/session-end/<slug:application_slug>/", | ||||
|         ensure_csrf_cookie(EndSessionView.as_view()), | ||||
|         name="if-session-end", | ||||
|     ), | ||||
|     # Fallback for WS | ||||
|     path("ws/outpost/<uuid:pk>/", InterfaceView.as_view(template_name="if/admin.html")), | ||||
|     path( | ||||
|  | ||||
							
								
								
									
										23
									
								
								authentik/core/views/session.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								authentik/core/views/session.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,23 @@ | ||||
| """authentik Session Views""" | ||||
|  | ||||
| from typing import Any | ||||
|  | ||||
| from django.shortcuts import get_object_or_404 | ||||
| from django.views.generic.base import TemplateView | ||||
|  | ||||
| from authentik.core.models import Application | ||||
| from authentik.policies.views import PolicyAccessView | ||||
|  | ||||
|  | ||||
| class EndSessionView(TemplateView, PolicyAccessView): | ||||
|     """Allow the client to end the Session""" | ||||
|  | ||||
|     template_name = "if/end_session.html" | ||||
|  | ||||
|     def resolve_provider_application(self): | ||||
|         self.application = get_object_or_404(Application, slug=self.kwargs["application_slug"]) | ||||
|  | ||||
|     def get_context_data(self, **kwargs: Any) -> dict[str, Any]: | ||||
|         context = super().get_context_data(**kwargs) | ||||
|         context["application"] = self.application | ||||
|         return context | ||||
| @ -68,7 +68,6 @@ class TestEndpointsAPI(APITestCase): | ||||
|                             "name": self.provider.name, | ||||
|                             "authentication_flow": None, | ||||
|                             "authorization_flow": None, | ||||
|                             "invalidation_flow": None, | ||||
|                             "property_mappings": [], | ||||
|                             "connection_expiry": "hours=8", | ||||
|                             "delete_token_on_disconnect": False, | ||||
| @ -121,7 +120,6 @@ class TestEndpointsAPI(APITestCase): | ||||
|                             "name": self.provider.name, | ||||
|                             "authentication_flow": None, | ||||
|                             "authorization_flow": None, | ||||
|                             "invalidation_flow": None, | ||||
|                             "property_mappings": [], | ||||
|                             "component": "ak-provider-rac-form", | ||||
|                             "assigned_application_slug": self.app.slug, | ||||
| @ -151,7 +149,6 @@ class TestEndpointsAPI(APITestCase): | ||||
|                             "name": self.provider.name, | ||||
|                             "authentication_flow": None, | ||||
|                             "authorization_flow": None, | ||||
|                             "invalidation_flow": None, | ||||
|                             "property_mappings": [], | ||||
|                             "component": "ak-provider-rac-form", | ||||
|                             "assigned_application_slug": self.app.slug, | ||||
|  | ||||
| @ -3,6 +3,7 @@ | ||||
| from channels.auth import AuthMiddleware | ||||
| from channels.sessions import CookieMiddleware | ||||
| from django.urls import path | ||||
| from django.views.decorators.csrf import ensure_csrf_cookie | ||||
|  | ||||
| from authentik.enterprise.providers.rac.api.connection_tokens import ConnectionTokenViewSet | ||||
| from authentik.enterprise.providers.rac.api.endpoints import EndpointViewSet | ||||
| @ -18,12 +19,12 @@ from authentik.root.middleware import ChannelsLoggingMiddleware | ||||
| urlpatterns = [ | ||||
|     path( | ||||
|         "application/rac/<slug:app>/<uuid:endpoint>/", | ||||
|         RACStartView.as_view(), | ||||
|         ensure_csrf_cookie(RACStartView.as_view()), | ||||
|         name="start", | ||||
|     ), | ||||
|     path( | ||||
|         "if/rac/<str:token>/", | ||||
|         RACInterface.as_view(), | ||||
|         ensure_csrf_cookie(RACInterface.as_view()), | ||||
|         name="if-rac", | ||||
|     ), | ||||
| ] | ||||
|  | ||||
| @ -17,7 +17,6 @@ TENANT_APPS = [ | ||||
|     "authentik.enterprise.providers.google_workspace", | ||||
|     "authentik.enterprise.providers.microsoft_entra", | ||||
|     "authentik.enterprise.providers.rac", | ||||
|     "authentik.enterprise.stages.authenticator_endpoint_gdtc", | ||||
|     "authentik.enterprise.stages.source", | ||||
| ] | ||||
|  | ||||
|  | ||||
| @ -1,82 +0,0 @@ | ||||
| """AuthenticatorEndpointGDTCStage API Views""" | ||||
|  | ||||
| from django_filters.rest_framework.backends import DjangoFilterBackend | ||||
| from rest_framework import mixins | ||||
| from rest_framework.filters import OrderingFilter, SearchFilter | ||||
| from rest_framework.permissions import IsAdminUser | ||||
| from rest_framework.serializers import ModelSerializer | ||||
| from rest_framework.viewsets import GenericViewSet, ModelViewSet | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.api.authorization import OwnerFilter, OwnerPermissions | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.enterprise.api import EnterpriseRequiredMixin | ||||
| from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import ( | ||||
|     AuthenticatorEndpointGDTCStage, | ||||
|     EndpointDevice, | ||||
| ) | ||||
| from authentik.flows.api.stages import StageSerializer | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| class AuthenticatorEndpointGDTCStageSerializer(EnterpriseRequiredMixin, StageSerializer): | ||||
|     """AuthenticatorEndpointGDTCStage Serializer""" | ||||
|  | ||||
|     class Meta: | ||||
|         model = AuthenticatorEndpointGDTCStage | ||||
|         fields = StageSerializer.Meta.fields + [ | ||||
|             "configure_flow", | ||||
|             "friendly_name", | ||||
|             "credentials", | ||||
|         ] | ||||
|  | ||||
|  | ||||
| class AuthenticatorEndpointGDTCStageViewSet(UsedByMixin, ModelViewSet): | ||||
|     """AuthenticatorEndpointGDTCStage Viewset""" | ||||
|  | ||||
|     queryset = AuthenticatorEndpointGDTCStage.objects.all() | ||||
|     serializer_class = AuthenticatorEndpointGDTCStageSerializer | ||||
|     filterset_fields = [ | ||||
|         "name", | ||||
|         "configure_flow", | ||||
|     ] | ||||
|     search_fields = ["name"] | ||||
|     ordering = ["name"] | ||||
|  | ||||
|  | ||||
| class EndpointDeviceSerializer(ModelSerializer): | ||||
|     """Serializer for Endpoint authenticator devices""" | ||||
|  | ||||
|     class Meta: | ||||
|         model = EndpointDevice | ||||
|         fields = ["pk", "name"] | ||||
|         depth = 2 | ||||
|  | ||||
|  | ||||
| class EndpointDeviceViewSet( | ||||
|     mixins.RetrieveModelMixin, | ||||
|     mixins.ListModelMixin, | ||||
|     UsedByMixin, | ||||
|     GenericViewSet, | ||||
| ): | ||||
|     """Viewset for Endpoint authenticator devices""" | ||||
|  | ||||
|     queryset = EndpointDevice.objects.all() | ||||
|     serializer_class = EndpointDeviceSerializer | ||||
|     search_fields = ["name"] | ||||
|     filterset_fields = ["name"] | ||||
|     ordering = ["name"] | ||||
|     permission_classes = [OwnerPermissions] | ||||
|     filter_backends = [OwnerFilter, DjangoFilterBackend, OrderingFilter, SearchFilter] | ||||
|  | ||||
|  | ||||
| class EndpointAdminDeviceViewSet(ModelViewSet): | ||||
|     """Viewset for Endpoint authenticator devices (for admins)""" | ||||
|  | ||||
|     permission_classes = [IsAdminUser] | ||||
|     queryset = EndpointDevice.objects.all() | ||||
|     serializer_class = EndpointDeviceSerializer | ||||
|     search_fields = ["name"] | ||||
|     filterset_fields = ["name"] | ||||
|     ordering = ["name"] | ||||
| @ -1,13 +0,0 @@ | ||||
| """authentik Endpoint app config""" | ||||
|  | ||||
| from authentik.enterprise.apps import EnterpriseConfig | ||||
|  | ||||
|  | ||||
| class AuthentikStageAuthenticatorEndpointConfig(EnterpriseConfig): | ||||
|     """authentik endpoint config""" | ||||
|  | ||||
|     name = "authentik.enterprise.stages.authenticator_endpoint_gdtc" | ||||
|     label = "authentik_stages_authenticator_endpoint_gdtc" | ||||
|     verbose_name = "authentik Enterprise.Stages.Authenticator.Endpoint GDTC" | ||||
|     default = True | ||||
|     mountpoint = "endpoint/gdtc/" | ||||
| @ -1,115 +0,0 @@ | ||||
| # Generated by Django 5.0.9 on 2024-10-22 11:40 | ||||
|  | ||||
| import django.db.models.deletion | ||||
| import uuid | ||||
| from django.conf import settings | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     initial = True | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_flows", "0027_auto_20231028_1424"), | ||||
|         migrations.swappable_dependency(settings.AUTH_USER_MODEL), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.CreateModel( | ||||
|             name="AuthenticatorEndpointGDTCStage", | ||||
|             fields=[ | ||||
|                 ( | ||||
|                     "stage_ptr", | ||||
|                     models.OneToOneField( | ||||
|                         auto_created=True, | ||||
|                         on_delete=django.db.models.deletion.CASCADE, | ||||
|                         parent_link=True, | ||||
|                         primary_key=True, | ||||
|                         serialize=False, | ||||
|                         to="authentik_flows.stage", | ||||
|                     ), | ||||
|                 ), | ||||
|                 ("friendly_name", models.TextField(null=True)), | ||||
|                 ("credentials", models.JSONField()), | ||||
|                 ( | ||||
|                     "configure_flow", | ||||
|                     models.ForeignKey( | ||||
|                         blank=True, | ||||
|                         help_text="Flow used by an authenticated user to configure this Stage. If empty, user will not be able to configure this stage.", | ||||
|                         null=True, | ||||
|                         on_delete=django.db.models.deletion.SET_NULL, | ||||
|                         to="authentik_flows.flow", | ||||
|                     ), | ||||
|                 ), | ||||
|             ], | ||||
|             options={ | ||||
|                 "verbose_name": "Endpoint Authenticator Google Device Trust Connector Stage", | ||||
|                 "verbose_name_plural": "Endpoint Authenticator Google Device Trust Connector Stages", | ||||
|             }, | ||||
|             bases=("authentik_flows.stage", models.Model), | ||||
|         ), | ||||
|         migrations.CreateModel( | ||||
|             name="EndpointDevice", | ||||
|             fields=[ | ||||
|                 ("created", models.DateTimeField(auto_now_add=True)), | ||||
|                 ("last_updated", models.DateTimeField(auto_now=True)), | ||||
|                 ( | ||||
|                     "name", | ||||
|                     models.CharField( | ||||
|                         help_text="The human-readable name of this device.", max_length=64 | ||||
|                     ), | ||||
|                 ), | ||||
|                 ( | ||||
|                     "confirmed", | ||||
|                     models.BooleanField(default=True, help_text="Is this device ready for use?"), | ||||
|                 ), | ||||
|                 ("last_used", models.DateTimeField(null=True)), | ||||
|                 ("uuid", models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)), | ||||
|                 ( | ||||
|                     "host_identifier", | ||||
|                     models.TextField( | ||||
|                         help_text="A unique identifier for the endpoint device, usually the device serial number", | ||||
|                         unique=True, | ||||
|                     ), | ||||
|                 ), | ||||
|                 ("data", models.JSONField()), | ||||
|                 ( | ||||
|                     "user", | ||||
|                     models.ForeignKey( | ||||
|                         on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL | ||||
|                     ), | ||||
|                 ), | ||||
|             ], | ||||
|             options={ | ||||
|                 "verbose_name": "Endpoint Device", | ||||
|                 "verbose_name_plural": "Endpoint Devices", | ||||
|             }, | ||||
|         ), | ||||
|         migrations.CreateModel( | ||||
|             name="EndpointDeviceConnection", | ||||
|             fields=[ | ||||
|                 ( | ||||
|                     "id", | ||||
|                     models.AutoField( | ||||
|                         auto_created=True, primary_key=True, serialize=False, verbose_name="ID" | ||||
|                     ), | ||||
|                 ), | ||||
|                 ("attributes", models.JSONField()), | ||||
|                 ( | ||||
|                     "device", | ||||
|                     models.ForeignKey( | ||||
|                         on_delete=django.db.models.deletion.CASCADE, | ||||
|                         to="authentik_stages_authenticator_endpoint_gdtc.endpointdevice", | ||||
|                     ), | ||||
|                 ), | ||||
|                 ( | ||||
|                     "stage", | ||||
|                     models.ForeignKey( | ||||
|                         on_delete=django.db.models.deletion.CASCADE, | ||||
|                         to="authentik_stages_authenticator_endpoint_gdtc.authenticatorendpointgdtcstage", | ||||
|                     ), | ||||
|                 ), | ||||
|             ], | ||||
|         ), | ||||
|     ] | ||||
| @ -1,101 +0,0 @@ | ||||
| """Endpoint stage""" | ||||
|  | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from django.contrib.auth import get_user_model | ||||
| from django.db import models | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from google.oauth2.service_account import Credentials | ||||
| from rest_framework.serializers import BaseSerializer, Serializer | ||||
|  | ||||
| from authentik.core.types import UserSettingSerializer | ||||
| from authentik.flows.models import ConfigurableStage, FriendlyNamedStage, Stage | ||||
| from authentik.flows.stage import StageView | ||||
| from authentik.lib.models import SerializerModel | ||||
| from authentik.stages.authenticator.models import Device | ||||
|  | ||||
|  | ||||
| class AuthenticatorEndpointGDTCStage(ConfigurableStage, FriendlyNamedStage, Stage): | ||||
|     """Setup Google Chrome Device-trust connection""" | ||||
|  | ||||
|     credentials = models.JSONField() | ||||
|  | ||||
|     def google_credentials(self): | ||||
|         return { | ||||
|             "credentials": Credentials.from_service_account_info( | ||||
|                 self.credentials, scopes=["https://www.googleapis.com/auth/verifiedaccess"] | ||||
|             ), | ||||
|         } | ||||
|  | ||||
|     @property | ||||
|     def serializer(self) -> type[BaseSerializer]: | ||||
|         from authentik.enterprise.stages.authenticator_endpoint_gdtc.api import ( | ||||
|             AuthenticatorEndpointGDTCStageSerializer, | ||||
|         ) | ||||
|  | ||||
|         return AuthenticatorEndpointGDTCStageSerializer | ||||
|  | ||||
|     @property | ||||
|     def view(self) -> type[StageView]: | ||||
|         from authentik.enterprise.stages.authenticator_endpoint_gdtc.stage import ( | ||||
|             AuthenticatorEndpointStageView, | ||||
|         ) | ||||
|  | ||||
|         return AuthenticatorEndpointStageView | ||||
|  | ||||
|     @property | ||||
|     def component(self) -> str: | ||||
|         return "ak-stage-authenticator-endpoint-gdtc-form" | ||||
|  | ||||
|     def ui_user_settings(self) -> UserSettingSerializer | None: | ||||
|         return UserSettingSerializer( | ||||
|             data={ | ||||
|                 "title": self.friendly_name or str(self._meta.verbose_name), | ||||
|                 "component": "ak-user-settings-authenticator-endpoint", | ||||
|             } | ||||
|         ) | ||||
|  | ||||
|     def __str__(self) -> str: | ||||
|         return f"Endpoint Authenticator Google Device Trust Connector Stage {self.name}" | ||||
|  | ||||
|     class Meta: | ||||
|         verbose_name = _("Endpoint Authenticator Google Device Trust Connector Stage") | ||||
|         verbose_name_plural = _("Endpoint Authenticator Google Device Trust Connector Stages") | ||||
|  | ||||
|  | ||||
| class EndpointDevice(SerializerModel, Device): | ||||
|     """Endpoint Device for a single user""" | ||||
|  | ||||
|     uuid = models.UUIDField(primary_key=True, default=uuid4) | ||||
|     host_identifier = models.TextField( | ||||
|         unique=True, | ||||
|         help_text="A unique identifier for the endpoint device, usually the device serial number", | ||||
|     ) | ||||
|  | ||||
|     user = models.ForeignKey(get_user_model(), on_delete=models.CASCADE) | ||||
|     data = models.JSONField() | ||||
|  | ||||
|     @property | ||||
|     def serializer(self) -> Serializer: | ||||
|         from authentik.enterprise.stages.authenticator_endpoint_gdtc.api import ( | ||||
|             EndpointDeviceSerializer, | ||||
|         ) | ||||
|  | ||||
|         return EndpointDeviceSerializer | ||||
|  | ||||
|     def __str__(self): | ||||
|         return str(self.name) or str(self.user_id) | ||||
|  | ||||
|     class Meta: | ||||
|         verbose_name = _("Endpoint Device") | ||||
|         verbose_name_plural = _("Endpoint Devices") | ||||
|  | ||||
|  | ||||
| class EndpointDeviceConnection(models.Model): | ||||
|     device = models.ForeignKey(EndpointDevice, on_delete=models.CASCADE) | ||||
|     stage = models.ForeignKey(AuthenticatorEndpointGDTCStage, on_delete=models.CASCADE) | ||||
|  | ||||
|     attributes = models.JSONField() | ||||
|  | ||||
|     def __str__(self) -> str: | ||||
|         return f"Endpoint device connection {self.device_id} to {self.stage_id}" | ||||
| @ -1,32 +0,0 @@ | ||||
| from django.http import HttpResponse | ||||
| from django.urls import reverse | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
|  | ||||
| from authentik.flows.challenge import ( | ||||
|     Challenge, | ||||
|     ChallengeResponse, | ||||
|     FrameChallenge, | ||||
|     FrameChallengeResponse, | ||||
| ) | ||||
| from authentik.flows.stage import ChallengeStageView | ||||
|  | ||||
|  | ||||
| class AuthenticatorEndpointStageView(ChallengeStageView): | ||||
|     """Endpoint stage""" | ||||
|  | ||||
|     response_class = FrameChallengeResponse | ||||
|  | ||||
|     def get_challenge(self, *args, **kwargs) -> Challenge: | ||||
|         return FrameChallenge( | ||||
|             data={ | ||||
|                 "component": "xak-flow-frame", | ||||
|                 "url": self.request.build_absolute_uri( | ||||
|                     reverse("authentik_stages_authenticator_endpoint_gdtc:chrome") | ||||
|                 ), | ||||
|                 "loading_overlay": True, | ||||
|                 "loading_text": _("Verifying your browser..."), | ||||
|             } | ||||
|         ) | ||||
|  | ||||
|     def challenge_valid(self, response: ChallengeResponse) -> HttpResponse: | ||||
|         return self.executor.stage_ok() | ||||
| @ -1,9 +0,0 @@ | ||||
| <html> | ||||
| <script> | ||||
|   window.parent.postMessage({ | ||||
|     message: "submit", | ||||
|     source: "goauthentik.io", | ||||
|     context: "flow-executor" | ||||
|   }); | ||||
| </script> | ||||
| </html> | ||||
| @ -1,26 +0,0 @@ | ||||
| """API URLs""" | ||||
|  | ||||
| from django.urls import path | ||||
|  | ||||
| from authentik.enterprise.stages.authenticator_endpoint_gdtc.api import ( | ||||
|     AuthenticatorEndpointGDTCStageViewSet, | ||||
|     EndpointAdminDeviceViewSet, | ||||
|     EndpointDeviceViewSet, | ||||
| ) | ||||
| from authentik.enterprise.stages.authenticator_endpoint_gdtc.views.dtc import ( | ||||
|     GoogleChromeDeviceTrustConnector, | ||||
| ) | ||||
|  | ||||
| urlpatterns = [ | ||||
|     path("chrome/", GoogleChromeDeviceTrustConnector.as_view(), name="chrome"), | ||||
| ] | ||||
|  | ||||
| api_urlpatterns = [ | ||||
|     ("authenticators/endpoint", EndpointDeviceViewSet), | ||||
|     ( | ||||
|         "authenticators/admin/endpoint", | ||||
|         EndpointAdminDeviceViewSet, | ||||
|         "admin-endpointdevice", | ||||
|     ), | ||||
|     ("stages/authenticator/endpoint_gdtc", AuthenticatorEndpointGDTCStageViewSet), | ||||
| ] | ||||
| @ -1,84 +0,0 @@ | ||||
| from json import dumps, loads | ||||
| from typing import Any | ||||
|  | ||||
| from django.http import HttpRequest, HttpResponse, HttpResponseRedirect | ||||
| from django.template.response import TemplateResponse | ||||
| from django.urls import reverse | ||||
| from django.views import View | ||||
| from googleapiclient.discovery import build | ||||
|  | ||||
| from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import ( | ||||
|     AuthenticatorEndpointGDTCStage, | ||||
|     EndpointDevice, | ||||
|     EndpointDeviceConnection, | ||||
| ) | ||||
| from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan | ||||
| from authentik.flows.views.executor import SESSION_KEY_PLAN | ||||
| from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS | ||||
|  | ||||
| # Header we get from chrome that initiates verified access | ||||
| HEADER_DEVICE_TRUST = "X-Device-Trust" | ||||
| # Header we send to the client with the challenge | ||||
| HEADER_ACCESS_CHALLENGE = "X-Verified-Access-Challenge" | ||||
| # Header we get back from the client that we verify with google | ||||
| HEADER_ACCESS_CHALLENGE_RESPONSE = "X-Verified-Access-Challenge-Response" | ||||
| # Header value for x-device-trust that initiates the flow | ||||
| DEVICE_TRUST_VERIFIED_ACCESS = "VerifiedAccess" | ||||
|  | ||||
|  | ||||
| class GoogleChromeDeviceTrustConnector(View): | ||||
|     """Google Chrome Device-trust connector based endpoint authenticator""" | ||||
|  | ||||
|     def get_flow_plan(self) -> FlowPlan: | ||||
|         flow_plan: FlowPlan = self.request.session[SESSION_KEY_PLAN] | ||||
|         return flow_plan | ||||
|  | ||||
|     def setup(self, request: HttpRequest, *args: Any, **kwargs: Any) -> None: | ||||
|         super().setup(request, *args, **kwargs) | ||||
|         stage: AuthenticatorEndpointGDTCStage = self.get_flow_plan().bindings[0].stage | ||||
|         self.google_client = build( | ||||
|             "verifiedaccess", | ||||
|             "v2", | ||||
|             cache_discovery=False, | ||||
|             **stage.google_credentials(), | ||||
|         ) | ||||
|  | ||||
|     def get(self, request: HttpRequest) -> HttpResponse: | ||||
|         x_device_trust = request.headers.get(HEADER_DEVICE_TRUST) | ||||
|         x_access_challenge_response = request.headers.get(HEADER_ACCESS_CHALLENGE_RESPONSE) | ||||
|         if x_device_trust == "VerifiedAccess" and x_access_challenge_response is None: | ||||
|             challenge = self.google_client.challenge().generate().execute() | ||||
|             res = HttpResponseRedirect( | ||||
|                 self.request.build_absolute_uri( | ||||
|                     reverse("authentik_stages_authenticator_endpoint_gdtc:chrome") | ||||
|                 ) | ||||
|             ) | ||||
|             res[HEADER_ACCESS_CHALLENGE] = dumps(challenge) | ||||
|             return res | ||||
|         if x_access_challenge_response: | ||||
|             response = ( | ||||
|                 self.google_client.challenge() | ||||
|                 .verify(body=loads(x_access_challenge_response)) | ||||
|                 .execute() | ||||
|             ) | ||||
|             # Remove deprecated string representation of deviceSignals | ||||
|             response.pop("deviceSignal", None) | ||||
|             flow_plan: FlowPlan = self.get_flow_plan() | ||||
|             device, _ = EndpointDevice.objects.update_or_create( | ||||
|                 host_identifier=response["deviceSignals"]["serialNumber"], | ||||
|                 user=flow_plan.context.get(PLAN_CONTEXT_PENDING_USER), | ||||
|                 defaults={"name": response["deviceSignals"]["hostname"], "data": response}, | ||||
|             ) | ||||
|             EndpointDeviceConnection.objects.update_or_create( | ||||
|                 device=device, | ||||
|                 stage=flow_plan.bindings[0].stage, | ||||
|                 defaults={ | ||||
|                     "attributes": response, | ||||
|                 }, | ||||
|             ) | ||||
|             flow_plan.context.setdefault(PLAN_CONTEXT_METHOD, "trusted_endpoint") | ||||
|             flow_plan.context.setdefault(PLAN_CONTEXT_METHOD_ARGS, {}) | ||||
|             flow_plan.context[PLAN_CONTEXT_METHOD_ARGS].setdefault("endpoints", []) | ||||
|             flow_plan.context[PLAN_CONTEXT_METHOD_ARGS]["endpoints"].append(response) | ||||
|             request.session[SESSION_KEY_PLAN] = flow_plan | ||||
|         return TemplateResponse(request, "stages/authenticator_endpoint/google_chrome_dtc.html") | ||||
| @ -50,7 +50,7 @@ class ASNContextProcessor(MMDBContextProcessor): | ||||
|         """Wrapper for Reader.asn""" | ||||
|         with start_span( | ||||
|             op="authentik.events.asn.asn", | ||||
|             name=ip_address, | ||||
|             description=ip_address, | ||||
|         ): | ||||
|             if not self.configured(): | ||||
|                 return None | ||||
|  | ||||
| @ -51,7 +51,7 @@ class GeoIPContextProcessor(MMDBContextProcessor): | ||||
|         """Wrapper for Reader.city""" | ||||
|         with start_span( | ||||
|             op="authentik.events.geo.city", | ||||
|             name=ip_address, | ||||
|             description=ip_address, | ||||
|         ): | ||||
|             if not self.configured(): | ||||
|                 return None | ||||
|  | ||||
| @ -1,16 +1,13 @@ | ||||
| """authentik events signal listener""" | ||||
|  | ||||
| from importlib import import_module | ||||
| from typing import Any | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.contrib.auth.signals import user_logged_in, user_logged_out | ||||
| from django.db.models.signals import post_save, pre_delete | ||||
| from django.dispatch import receiver | ||||
| from django.http import HttpRequest | ||||
| from rest_framework.request import Request | ||||
|  | ||||
| from authentik.core.models import AuthenticatedSession, User | ||||
| from authentik.core.models import User | ||||
| from authentik.core.signals import login_failed, password_changed | ||||
| from authentik.events.apps import SYSTEM_TASK_STATUS | ||||
| from authentik.events.models import Event, EventAction, SystemTask | ||||
| @ -26,7 +23,6 @@ from authentik.stages.user_write.signals import user_write | ||||
| from authentik.tenants.utils import get_current_tenant | ||||
|  | ||||
| SESSION_LOGIN_EVENT = "login_event" | ||||
| _session_engine = import_module(settings.SESSION_ENGINE) | ||||
|  | ||||
|  | ||||
| @receiver(user_logged_in) | ||||
| @ -47,20 +43,11 @@ def on_user_logged_in(sender, request: HttpRequest, user: User, **_): | ||||
|             kwargs[PLAN_CONTEXT_OUTPOST] = flow_plan.context[PLAN_CONTEXT_OUTPOST] | ||||
|     event = Event.new(EventAction.LOGIN, **kwargs).from_http(request, user=user) | ||||
|     request.session[SESSION_LOGIN_EVENT] = event | ||||
|     request.session.save() | ||||
|  | ||||
|  | ||||
| def get_login_event(request_or_session: HttpRequest | AuthenticatedSession | None) -> Event | None: | ||||
| def get_login_event(request: HttpRequest) -> Event | None: | ||||
|     """Wrapper to get login event that can be mocked in tests""" | ||||
|     session = None | ||||
|     if not request_or_session: | ||||
|         return None | ||||
|     if isinstance(request_or_session, HttpRequest | Request): | ||||
|         session = request_or_session.session | ||||
|     if isinstance(request_or_session, AuthenticatedSession): | ||||
|         SessionStore = _session_engine.SessionStore | ||||
|         session = SessionStore(request_or_session.session_key) | ||||
|     return session.get(SESSION_LOGIN_EVENT, None) | ||||
|     return request.session.get(SESSION_LOGIN_EVENT, None) | ||||
|  | ||||
|  | ||||
| @receiver(user_logged_out) | ||||
|  | ||||
| @ -8,7 +8,7 @@ from uuid import UUID | ||||
| from django.core.serializers.json import DjangoJSONEncoder | ||||
| from django.db import models | ||||
| from django.http import JsonResponse | ||||
| from rest_framework.fields import BooleanField, CharField, ChoiceField, DictField | ||||
| from rest_framework.fields import CharField, ChoiceField, DictField | ||||
| from rest_framework.request import Request | ||||
|  | ||||
| from authentik.core.api.utils import PassiveSerializer | ||||
| @ -110,21 +110,8 @@ class FlowErrorChallenge(Challenge): | ||||
| class AccessDeniedChallenge(WithUserInfoChallenge): | ||||
|     """Challenge when a flow's active stage calls `stage_invalid()`.""" | ||||
|  | ||||
|     component = CharField(default="ak-stage-access-denied") | ||||
|  | ||||
|     error_message = CharField(required=False) | ||||
|  | ||||
|  | ||||
| class SessionEndChallenge(WithUserInfoChallenge): | ||||
|     """Challenge for ending a session""" | ||||
|  | ||||
|     component = CharField(default="ak-stage-session-end") | ||||
|  | ||||
|     application_name = CharField(required=False) | ||||
|     application_launch_url = CharField(required=False) | ||||
|  | ||||
|     invalidation_flow_url = CharField(required=False) | ||||
|     brand_name = CharField(required=True) | ||||
|     component = CharField(default="ak-stage-access-denied") | ||||
|  | ||||
|  | ||||
| class PermissionDict(TypedDict): | ||||
| @ -160,20 +147,6 @@ class AutoSubmitChallengeResponse(ChallengeResponse): | ||||
|     component = CharField(default="ak-stage-autosubmit") | ||||
|  | ||||
|  | ||||
| class FrameChallenge(Challenge): | ||||
|     """Challenge type to render a frame""" | ||||
|  | ||||
|     component = CharField(default="xak-flow-frame") | ||||
|     url = CharField() | ||||
|     loading_overlay = BooleanField(default=False) | ||||
|     loading_text = CharField() | ||||
|  | ||||
|  | ||||
| class FrameChallengeResponse(ChallengeResponse): | ||||
|  | ||||
|     component = CharField(default="xak-flow-frame") | ||||
|  | ||||
|  | ||||
| class DataclassEncoder(DjangoJSONEncoder): | ||||
|     """Convert any dataclass to json""" | ||||
|  | ||||
|  | ||||
| @ -6,18 +6,20 @@ from django.db.backends.base.schema import BaseDatabaseSchemaEditor | ||||
|  | ||||
|  | ||||
| def set_oobe_flow_authentication(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||
|     from guardian.conf import settings as guardian_settings | ||||
|     from guardian.shortcuts import get_anonymous_user | ||||
|  | ||||
|     Flow = apps.get_model("authentik_flows", "Flow") | ||||
|     User = apps.get_model("authentik_core", "User") | ||||
|  | ||||
|     db_alias = schema_editor.connection.alias | ||||
|  | ||||
|     users = ( | ||||
|         User.objects.using(db_alias) | ||||
|         .exclude(username="akadmin") | ||||
|         .exclude(username=guardian_settings.ANONYMOUS_USER_NAME) | ||||
|     ) | ||||
|     users = User.objects.using(db_alias).exclude(username="akadmin") | ||||
|     try: | ||||
|         users = users.exclude(pk=get_anonymous_user().pk) | ||||
|  | ||||
|     except Exception:  # nosec | ||||
|         pass | ||||
|  | ||||
|     if users.exists(): | ||||
|         Flow.objects.using(db_alias).filter(slug="initial-setup").update( | ||||
|             authentication="require_superuser" | ||||
|  | ||||
| @ -107,9 +107,7 @@ class Stage(SerializerModel): | ||||
|  | ||||
|  | ||||
| def in_memory_stage(view: type["StageView"], **kwargs) -> Stage: | ||||
|     """Creates an in-memory stage instance, based on a `view` as view. | ||||
|     Any key-word arguments are set as attributes on the stage object, | ||||
|     accessible via `self.executor.current_stage`.""" | ||||
|     """Creates an in-memory stage instance, based on a `view` as view.""" | ||||
|     stage = Stage() | ||||
|     # Because we can't pickle a locally generated function, | ||||
|     # we set the view as a separate property and reference a generic function | ||||
|  | ||||
| @ -166,7 +166,7 @@ class FlowPlanner: | ||||
|     def plan(self, request: HttpRequest, default_context: dict[str, Any] | None = None) -> FlowPlan: | ||||
|         """Check each of the flows' policies, check policies for each stage with PolicyBinding | ||||
|         and return ordered list""" | ||||
|         with start_span(op="authentik.flow.planner.plan", name=self.flow.slug) as span: | ||||
|         with start_span(op="authentik.flow.planner.plan", description=self.flow.slug) as span: | ||||
|             span: Span | ||||
|             span.set_data("flow", self.flow) | ||||
|             span.set_data("request", request) | ||||
| @ -233,7 +233,7 @@ class FlowPlanner: | ||||
|         with ( | ||||
|             start_span( | ||||
|                 op="authentik.flow.planner.build_plan", | ||||
|                 name=self.flow.slug, | ||||
|                 description=self.flow.slug, | ||||
|             ) as span, | ||||
|             HIST_FLOWS_PLAN_TIME.labels(flow_slug=self.flow.slug).time(), | ||||
|         ): | ||||
|  | ||||
| @ -13,7 +13,7 @@ from rest_framework.request import Request | ||||
| from sentry_sdk import start_span | ||||
| from structlog.stdlib import BoundLogger, get_logger | ||||
|  | ||||
| from authentik.core.models import Application, User | ||||
| from authentik.core.models import User | ||||
| from authentik.flows.challenge import ( | ||||
|     AccessDeniedChallenge, | ||||
|     Challenge, | ||||
| @ -21,7 +21,6 @@ from authentik.flows.challenge import ( | ||||
|     ContextualFlowInfo, | ||||
|     HttpChallengeResponse, | ||||
|     RedirectChallenge, | ||||
|     SessionEndChallenge, | ||||
|     WithUserInfoChallenge, | ||||
| ) | ||||
| from authentik.flows.exceptions import StageInvalidException | ||||
| @ -126,7 +125,7 @@ class ChallengeStageView(StageView): | ||||
|             with ( | ||||
|                 start_span( | ||||
|                     op="authentik.flow.stage.challenge_invalid", | ||||
|                     name=self.__class__.__name__, | ||||
|                     description=self.__class__.__name__, | ||||
|                 ), | ||||
|                 HIST_FLOWS_STAGE_TIME.labels( | ||||
|                     stage_type=self.__class__.__name__, method="challenge_invalid" | ||||
| @ -136,7 +135,7 @@ class ChallengeStageView(StageView): | ||||
|         with ( | ||||
|             start_span( | ||||
|                 op="authentik.flow.stage.challenge_valid", | ||||
|                 name=self.__class__.__name__, | ||||
|                 description=self.__class__.__name__, | ||||
|             ), | ||||
|             HIST_FLOWS_STAGE_TIME.labels( | ||||
|                 stage_type=self.__class__.__name__, method="challenge_valid" | ||||
| @ -162,7 +161,7 @@ class ChallengeStageView(StageView): | ||||
|         with ( | ||||
|             start_span( | ||||
|                 op="authentik.flow.stage.get_challenge", | ||||
|                 name=self.__class__.__name__, | ||||
|                 description=self.__class__.__name__, | ||||
|             ), | ||||
|             HIST_FLOWS_STAGE_TIME.labels( | ||||
|                 stage_type=self.__class__.__name__, method="get_challenge" | ||||
| @ -175,7 +174,7 @@ class ChallengeStageView(StageView): | ||||
|                 return self.executor.stage_invalid() | ||||
|         with start_span( | ||||
|             op="authentik.flow.stage._get_challenge", | ||||
|             name=self.__class__.__name__, | ||||
|             description=self.__class__.__name__, | ||||
|         ): | ||||
|             if not hasattr(challenge, "initial_data"): | ||||
|                 challenge.initial_data = {} | ||||
| @ -231,7 +230,7 @@ class ChallengeStageView(StageView): | ||||
|         return HttpChallengeResponse(challenge_response) | ||||
|  | ||||
|  | ||||
| class AccessDeniedStage(ChallengeStageView): | ||||
| class AccessDeniedChallengeView(ChallengeStageView): | ||||
|     """Used internally by FlowExecutor's stage_invalid()""" | ||||
|  | ||||
|     error_message: str | None | ||||
| @ -269,31 +268,3 @@ class RedirectStage(ChallengeStageView): | ||||
|  | ||||
|     def challenge_valid(self, response: ChallengeResponse) -> HttpResponse: | ||||
|         return HttpChallengeResponse(self.get_challenge()) | ||||
|  | ||||
|  | ||||
| class SessionEndStage(ChallengeStageView): | ||||
|     """Stage inserted when a flow is used as invalidation flow. By default shows actions | ||||
|     that the user is likely to take after signing out of a provider.""" | ||||
|  | ||||
|     def get_challenge(self, *args, **kwargs) -> Challenge: | ||||
|         application: Application | None = self.executor.plan.context.get(PLAN_CONTEXT_APPLICATION) | ||||
|         data = { | ||||
|             "component": "ak-stage-session-end", | ||||
|             "brand_name": self.request.brand.branding_title, | ||||
|         } | ||||
|         if application: | ||||
|             data["application_name"] = application.name | ||||
|             data["application_launch_url"] = application.get_launch_url(self.get_pending_user()) | ||||
|         if self.request.brand.flow_invalidation: | ||||
|             data["invalidation_flow_url"] = reverse( | ||||
|                 "authentik_core:if-flow", | ||||
|                 kwargs={ | ||||
|                     "flow_slug": self.request.brand.flow_invalidation.slug, | ||||
|                 }, | ||||
|             ) | ||||
|         return SessionEndChallenge(data=data) | ||||
|  | ||||
|     # This can never be reached since this challenge is created on demand and only the | ||||
|     # .get() method is called | ||||
|     def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:  # pragma: no cover | ||||
|         return self.executor.cancel() | ||||
|  | ||||
| @ -46,7 +46,6 @@ class TestFlowInspector(APITestCase): | ||||
|             res.content, | ||||
|             { | ||||
|                 "allow_show_password": False, | ||||
|                 "captcha_stage": None, | ||||
|                 "component": "ak-stage-identification", | ||||
|                 "flow_info": { | ||||
|                     "background": flow.background_url, | ||||
|  | ||||
| @ -54,7 +54,7 @@ from authentik.flows.planner import ( | ||||
|     FlowPlan, | ||||
|     FlowPlanner, | ||||
| ) | ||||
| from authentik.flows.stage import AccessDeniedStage, StageView | ||||
| from authentik.flows.stage import AccessDeniedChallengeView, StageView | ||||
| from authentik.lib.sentry import SentryIgnoredException | ||||
| from authentik.lib.utils.errors import exception_to_string | ||||
| from authentik.lib.utils.reflection import all_subclasses, class_to_path | ||||
| @ -153,7 +153,7 @@ class FlowExecutorView(APIView): | ||||
|         return plan | ||||
|  | ||||
|     def dispatch(self, request: HttpRequest, flow_slug: str) -> HttpResponse: | ||||
|         with start_span(op="authentik.flow.executor.dispatch", name=self.flow.slug) as span: | ||||
|         with start_span(op="authentik.flow.executor.dispatch", description=self.flow.slug) as span: | ||||
|             span.set_data("authentik Flow", self.flow.slug) | ||||
|             get_params = QueryDict(request.GET.get(QS_QUERY, "")) | ||||
|             if QS_KEY_TOKEN in get_params: | ||||
| @ -273,7 +273,7 @@ class FlowExecutorView(APIView): | ||||
|             with ( | ||||
|                 start_span( | ||||
|                     op="authentik.flow.executor.stage", | ||||
|                     name=class_path, | ||||
|                     description=class_path, | ||||
|                 ) as span, | ||||
|                 HIST_FLOW_EXECUTION_STAGE_TIME.labels( | ||||
|                     method=request.method.upper(), | ||||
| @ -324,7 +324,7 @@ class FlowExecutorView(APIView): | ||||
|             with ( | ||||
|                 start_span( | ||||
|                     op="authentik.flow.executor.stage", | ||||
|                     name=class_path, | ||||
|                     description=class_path, | ||||
|                 ) as span, | ||||
|                 HIST_FLOW_EXECUTION_STAGE_TIME.labels( | ||||
|                     method=request.method.upper(), | ||||
| @ -441,7 +441,7 @@ class FlowExecutorView(APIView): | ||||
|             ) | ||||
|             return self.restart_flow(keep_context) | ||||
|         self.cancel() | ||||
|         challenge_view = AccessDeniedStage(self, error_message) | ||||
|         challenge_view = AccessDeniedChallengeView(self, error_message) | ||||
|         challenge_view.request = self.request | ||||
|         return to_stage_response(self.request, challenge_view.get(self.request)) | ||||
|  | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| # update website/docs/install-config/configuration/configuration.mdx | ||||
| # update website/docs/installation/configuration.mdx | ||||
| # This is the default configuration file | ||||
| postgresql: | ||||
|   host: localhost | ||||
| @ -105,10 +105,6 @@ ldap: | ||||
|   tls: | ||||
|     ciphers: null | ||||
|  | ||||
| sources: | ||||
|   kerberos: | ||||
|     task_timeout_hours: 2 | ||||
|  | ||||
| reputation: | ||||
|   expiry: 86400 | ||||
|  | ||||
|  | ||||
| @ -21,14 +21,7 @@ class DebugSession(Session): | ||||
|  | ||||
|     def send(self, req: PreparedRequest, *args, **kwargs): | ||||
|         request_id = str(uuid4()) | ||||
|         LOGGER.debug( | ||||
|             "HTTP request sent", | ||||
|             uid=request_id, | ||||
|             url=req.url, | ||||
|             method=req.method, | ||||
|             headers=req.headers, | ||||
|             body=req.body, | ||||
|         ) | ||||
|         LOGGER.debug("HTTP request sent", uid=request_id, path=req.path_url, headers=req.headers) | ||||
|         resp = super().send(req, *args, **kwargs) | ||||
|         LOGGER.debug( | ||||
|             "HTTP response received", | ||||
|  | ||||
| @ -53,7 +53,7 @@ class ServiceConnectionInvalid(SentryIgnoredException): | ||||
| class OutpostConfig: | ||||
|     """Configuration an outpost uses to configure it self""" | ||||
|  | ||||
|     # update website/docs/add-secure-apps/outposts/_config.md | ||||
|     # update website/docs/outposts/_config.md | ||||
|  | ||||
|     authentik_host: str = "" | ||||
|     authentik_host_insecure: bool = False | ||||
|  | ||||
| @ -113,7 +113,7 @@ class PolicyEngine: | ||||
|         with ( | ||||
|             start_span( | ||||
|                 op="authentik.policy.engine.build", | ||||
|                 name=self.__pbm, | ||||
|                 description=self.__pbm, | ||||
|             ) as span, | ||||
|             HIST_POLICIES_ENGINE_TOTAL_TIME.labels( | ||||
|                 obj_type=class_to_path(self.__pbm.__class__), | ||||
|  | ||||
| @ -108,7 +108,7 @@ class EventMatcherPolicy(Policy): | ||||
|                 result=result, | ||||
|             ) | ||||
|             matches.append(result) | ||||
|         passing = all(x.passing for x in matches) | ||||
|         passing = any(x.passing for x in matches) | ||||
|         messages = chain(*[x.messages for x in matches]) | ||||
|         result = PolicyResult(passing, *messages) | ||||
|         result.source_results = matches | ||||
|  | ||||
| @ -77,24 +77,11 @@ class TestEventMatcherPolicy(TestCase): | ||||
|         request = PolicyRequest(get_anonymous_user()) | ||||
|         request.context["event"] = event | ||||
|         policy: EventMatcherPolicy = EventMatcherPolicy.objects.create( | ||||
|             client_ip="1.2.3.5", app="foo" | ||||
|             client_ip="1.2.3.5", app="bar" | ||||
|         ) | ||||
|         response = policy.passes(request) | ||||
|         self.assertFalse(response.passing) | ||||
|  | ||||
|     def test_multiple(self): | ||||
|         """Test multiple""" | ||||
|         event = Event.new(EventAction.LOGIN) | ||||
|         event.app = "foo" | ||||
|         event.client_ip = "1.2.3.4" | ||||
|         request = PolicyRequest(get_anonymous_user()) | ||||
|         request.context["event"] = event | ||||
|         policy: EventMatcherPolicy = EventMatcherPolicy.objects.create( | ||||
|             client_ip="1.2.3.4", app="foo" | ||||
|         ) | ||||
|         response = policy.passes(request) | ||||
|         self.assertTrue(response.passing) | ||||
|  | ||||
|     def test_invalid(self): | ||||
|         """Test passing event""" | ||||
|         request = PolicyRequest(get_anonymous_user()) | ||||
|  | ||||
| @ -87,7 +87,6 @@ class LDAPOutpostConfigSerializer(ModelSerializer): | ||||
|  | ||||
|     application_slug = SerializerMethodField() | ||||
|     bind_flow_slug = CharField(source="authorization_flow.slug") | ||||
|     unbind_flow_slug = SerializerMethodField() | ||||
|  | ||||
|     def get_application_slug(self, instance: LDAPProvider) -> str: | ||||
|         """Prioritise backchannel slug over direct application slug""" | ||||
| @ -95,16 +94,6 @@ class LDAPOutpostConfigSerializer(ModelSerializer): | ||||
|             return instance.backchannel_application.slug | ||||
|         return instance.application.slug | ||||
|  | ||||
|     def get_unbind_flow_slug(self, instance: LDAPProvider) -> str | None: | ||||
|         """Get slug for unbind flow, defaulting to brand's default flow.""" | ||||
|         flow = instance.invalidation_flow | ||||
|         if not flow and "request" in self.context: | ||||
|             request = self.context.get("request") | ||||
|             flow = request.brand.flow_invalidation | ||||
|         if not flow: | ||||
|             return None | ||||
|         return flow.slug | ||||
|  | ||||
|     class Meta: | ||||
|         model = LDAPProvider | ||||
|         fields = [ | ||||
| @ -112,7 +101,6 @@ class LDAPOutpostConfigSerializer(ModelSerializer): | ||||
|             "name", | ||||
|             "base_dn", | ||||
|             "bind_flow_slug", | ||||
|             "unbind_flow_slug", | ||||
|             "application_slug", | ||||
|             "certificate", | ||||
|             "tls_server_name", | ||||
|  | ||||
| @ -39,7 +39,6 @@ class OAuth2ProviderSerializer(ProviderSerializer): | ||||
|             "refresh_token_validity", | ||||
|             "include_claims_in_id_token", | ||||
|             "signing_key", | ||||
|             "encryption_key", | ||||
|             "redirect_uris", | ||||
|             "sub_mode", | ||||
|             "property_mappings", | ||||
|  | ||||
| @ -1,7 +1,6 @@ | ||||
| """id_token utils""" | ||||
|  | ||||
| from dataclasses import asdict, dataclass, field | ||||
| from hashlib import sha256 | ||||
| from typing import TYPE_CHECKING, Any | ||||
|  | ||||
| from django.db import models | ||||
| @ -24,13 +23,8 @@ if TYPE_CHECKING: | ||||
|     from authentik.providers.oauth2.models import BaseGrantModel, OAuth2Provider | ||||
|  | ||||
|  | ||||
| def hash_session_key(session_key: str) -> str: | ||||
|     """Hash the session key for inclusion in JWTs as `sid`""" | ||||
|     return sha256(session_key.encode("ascii")).hexdigest() | ||||
|  | ||||
|  | ||||
| class SubModes(models.TextChoices): | ||||
|     """Mode after which 'sub' attribute is generated, for compatibility reasons""" | ||||
|     """Mode after which 'sub' attribute is generateed, for compatibility reasons""" | ||||
|  | ||||
|     HASHED_USER_ID = "hashed_user_id", _("Based on the Hashed User ID") | ||||
|     USER_ID = "user_id", _("Based on user ID") | ||||
| @ -57,8 +51,7 @@ class IDToken: | ||||
|     and potentially other requested Claims. The ID Token is represented as a | ||||
|     JSON Web Token (JWT) [JWT]. | ||||
|  | ||||
|     https://openid.net/specs/openid-connect-core-1_0.html#IDToken | ||||
|     https://www.iana.org/assignments/jwt/jwt.xhtml""" | ||||
|     https://openid.net/specs/openid-connect-core-1_0.html#IDToken""" | ||||
|  | ||||
|     # Issuer, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.1 | ||||
|     iss: str | None = None | ||||
| @ -86,8 +79,6 @@ class IDToken: | ||||
|     nonce: str | None = None | ||||
|     # Access Token hash value, http://openid.net/specs/openid-connect-core-1_0.html | ||||
|     at_hash: str | None = None | ||||
|     # Session ID, https://openid.net/specs/openid-connect-frontchannel-1_0.html#ClaimsContents | ||||
|     sid: str | None = None | ||||
|  | ||||
|     claims: dict[str, Any] = field(default_factory=dict) | ||||
|  | ||||
| @ -125,11 +116,9 @@ class IDToken: | ||||
|         now = timezone.now() | ||||
|         id_token.iat = int(now.timestamp()) | ||||
|         id_token.auth_time = int(token.auth_time.timestamp()) | ||||
|         if token.session: | ||||
|             id_token.sid = hash_session_key(token.session.session_key) | ||||
|  | ||||
|         # We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time | ||||
|         auth_event = get_login_event(token.session) | ||||
|         auth_event = get_login_event(request) | ||||
|         if auth_event: | ||||
|             # Also check which method was used for authentication | ||||
|             method = auth_event.context.get(PLAN_CONTEXT_METHOD, "") | ||||
|  | ||||
| @ -3,7 +3,6 @@ | ||||
| import django.db.models.deletion | ||||
| from django.apps.registry import Apps | ||||
| from django.db import migrations, models | ||||
| from django.db.backends.base.schema import BaseDatabaseSchemaEditor | ||||
|  | ||||
| import authentik.lib.utils.time | ||||
|  | ||||
| @ -15,7 +14,7 @@ scope_uid_map = { | ||||
| } | ||||
|  | ||||
|  | ||||
| def set_managed_flag(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||
| def set_managed_flag(apps: Apps, schema_editor): | ||||
|     ScopeMapping = apps.get_model("authentik_providers_oauth2", "ScopeMapping") | ||||
|     db_alias = schema_editor.connection.alias | ||||
|     for mapping in ScopeMapping.objects.using(db_alias).filter(name__startswith="Autogenerated "): | ||||
|  | ||||
| @ -1,42 +0,0 @@ | ||||
| # Generated by Django 5.0.9 on 2024-10-16 14:53 | ||||
|  | ||||
| import django.db.models.deletion | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_crypto", "0004_alter_certificatekeypair_name"), | ||||
|         ( | ||||
|             "authentik_providers_oauth2", | ||||
|             "0020_remove_accesstoken_authentik_p_token_4bc870_idx_and_more", | ||||
|         ), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.AddField( | ||||
|             model_name="oauth2provider", | ||||
|             name="encryption_key", | ||||
|             field=models.ForeignKey( | ||||
|                 help_text="Key used to encrypt the tokens. When set, tokens will be encrypted and returned as JWEs.", | ||||
|                 null=True, | ||||
|                 on_delete=django.db.models.deletion.SET_NULL, | ||||
|                 related_name="oauth2provider_encryption_key_set", | ||||
|                 to="authentik_crypto.certificatekeypair", | ||||
|                 verbose_name="Encryption Key", | ||||
|             ), | ||||
|         ), | ||||
|         migrations.AlterField( | ||||
|             model_name="oauth2provider", | ||||
|             name="signing_key", | ||||
|             field=models.ForeignKey( | ||||
|                 help_text="Key used to sign the tokens.", | ||||
|                 null=True, | ||||
|                 on_delete=django.db.models.deletion.SET_NULL, | ||||
|                 related_name="oauth2provider_signing_key_set", | ||||
|                 to="authentik_crypto.certificatekeypair", | ||||
|                 verbose_name="Signing Key", | ||||
|             ), | ||||
|         ), | ||||
|     ] | ||||
| @ -1,113 +0,0 @@ | ||||
| # Generated by Django 5.0.9 on 2024-10-23 13:38 | ||||
|  | ||||
| from hashlib import sha256 | ||||
| import django.db.models.deletion | ||||
| from django.db import migrations, models | ||||
| from django.apps.registry import Apps | ||||
| from django.db.backends.base.schema import BaseDatabaseSchemaEditor | ||||
| from authentik.lib.migrations import progress_bar | ||||
|  | ||||
|  | ||||
| def migrate_session(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||
|     AuthenticatedSession = apps.get_model("authentik_core", "authenticatedsession") | ||||
|     AuthorizationCode = apps.get_model("authentik_providers_oauth2", "authorizationcode") | ||||
|     AccessToken = apps.get_model("authentik_providers_oauth2", "accesstoken") | ||||
|     RefreshToken = apps.get_model("authentik_providers_oauth2", "refreshtoken") | ||||
|     db_alias = schema_editor.connection.alias | ||||
|  | ||||
|     print(f"\nFetching session keys, this might take a couple of minutes...") | ||||
|     session_ids = {} | ||||
|     for session in progress_bar(AuthenticatedSession.objects.using(db_alias).all()): | ||||
|         session_ids[sha256(session.session_key.encode("ascii")).hexdigest()] = session.session_key | ||||
|     for model in [AuthorizationCode, AccessToken, RefreshToken]: | ||||
|         print( | ||||
|             f"\nAdding session to {model._meta.verbose_name}, this might take a couple of minutes..." | ||||
|         ) | ||||
|         for code in progress_bar(model.objects.using(db_alias).all()): | ||||
|             if code.session_id_old not in session_ids: | ||||
|                 continue | ||||
|             code.session = ( | ||||
|                 AuthenticatedSession.objects.using(db_alias) | ||||
|                 .filter(session_key=session_ids[code.session_id_old]) | ||||
|                 .first() | ||||
|             ) | ||||
|             code.save() | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_core", "0040_provider_invalidation_flow"), | ||||
|         ("authentik_providers_oauth2", "0021_oauth2provider_encryption_key_and_more"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.RenameField( | ||||
|             model_name="accesstoken", | ||||
|             old_name="session_id", | ||||
|             new_name="session_id_old", | ||||
|         ), | ||||
|         migrations.RenameField( | ||||
|             model_name="authorizationcode", | ||||
|             old_name="session_id", | ||||
|             new_name="session_id_old", | ||||
|         ), | ||||
|         migrations.RenameField( | ||||
|             model_name="refreshtoken", | ||||
|             old_name="session_id", | ||||
|             new_name="session_id_old", | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="accesstoken", | ||||
|             name="session", | ||||
|             field=models.ForeignKey( | ||||
|                 default=None, | ||||
|                 null=True, | ||||
|                 on_delete=django.db.models.deletion.SET_DEFAULT, | ||||
|                 to="authentik_core.authenticatedsession", | ||||
|             ), | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="authorizationcode", | ||||
|             name="session", | ||||
|             field=models.ForeignKey( | ||||
|                 default=None, | ||||
|                 null=True, | ||||
|                 on_delete=django.db.models.deletion.SET_DEFAULT, | ||||
|                 to="authentik_core.authenticatedsession", | ||||
|             ), | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="devicetoken", | ||||
|             name="session", | ||||
|             field=models.ForeignKey( | ||||
|                 default=None, | ||||
|                 null=True, | ||||
|                 on_delete=django.db.models.deletion.SET_DEFAULT, | ||||
|                 to="authentik_core.authenticatedsession", | ||||
|             ), | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="refreshtoken", | ||||
|             name="session", | ||||
|             field=models.ForeignKey( | ||||
|                 default=None, | ||||
|                 null=True, | ||||
|                 on_delete=django.db.models.deletion.SET_DEFAULT, | ||||
|                 to="authentik_core.authenticatedsession", | ||||
|             ), | ||||
|         ), | ||||
|         migrations.RunPython(migrate_session), | ||||
|         migrations.RemoveField( | ||||
|             model_name="accesstoken", | ||||
|             name="session_id_old", | ||||
|         ), | ||||
|         migrations.RemoveField( | ||||
|             model_name="authorizationcode", | ||||
|             name="session_id_old", | ||||
|         ), | ||||
|         migrations.RemoveField( | ||||
|             model_name="refreshtoken", | ||||
|             name="session_id_old", | ||||
|         ), | ||||
|     ] | ||||
| @ -18,21 +18,12 @@ from django.http import HttpRequest | ||||
| from django.templatetags.static import static | ||||
| from django.urls import reverse | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from jwcrypto.common import json_encode | ||||
| from jwcrypto.jwe import JWE | ||||
| from jwcrypto.jwk import JWK | ||||
| from jwt import encode | ||||
| from rest_framework.serializers import Serializer | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.brands.models import WebfingerProvider | ||||
| from authentik.core.models import ( | ||||
|     AuthenticatedSession, | ||||
|     ExpiringModel, | ||||
|     PropertyMapping, | ||||
|     Provider, | ||||
|     User, | ||||
| ) | ||||
| from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User | ||||
| from authentik.crypto.models import CertificateKeyPair | ||||
| from authentik.lib.generators import generate_code_fixed_length, generate_id, generate_key | ||||
| from authentik.lib.models import SerializerModel | ||||
| @ -215,19 +206,9 @@ class OAuth2Provider(WebfingerProvider, Provider): | ||||
|         verbose_name=_("Signing Key"), | ||||
|         on_delete=models.SET_NULL, | ||||
|         null=True, | ||||
|         help_text=_("Key used to sign the tokens."), | ||||
|         related_name="oauth2provider_signing_key_set", | ||||
|     ) | ||||
|     encryption_key = models.ForeignKey( | ||||
|         CertificateKeyPair, | ||||
|         verbose_name=_("Encryption Key"), | ||||
|         on_delete=models.SET_NULL, | ||||
|         null=True, | ||||
|         help_text=_( | ||||
|             "Key used to encrypt the tokens. When set, " | ||||
|             "tokens will be encrypted and returned as JWEs." | ||||
|             "Key used to sign the tokens. Only required when JWT Algorithm is set to RS256." | ||||
|         ), | ||||
|         related_name="oauth2provider_encryption_key_set", | ||||
|     ) | ||||
|  | ||||
|     jwks_sources = models.ManyToManyField( | ||||
| @ -306,27 +287,7 @@ class OAuth2Provider(WebfingerProvider, Provider): | ||||
|         if self.signing_key: | ||||
|             headers["kid"] = self.signing_key.kid | ||||
|         key, alg = self.jwt_key | ||||
|         encoded = encode(payload, key, algorithm=alg, headers=headers) | ||||
|         if self.encryption_key: | ||||
|             return self.encrypt(encoded) | ||||
|         return encoded | ||||
|  | ||||
|     def encrypt(self, raw: str) -> str: | ||||
|         """Encrypt JWT""" | ||||
|         key = JWK.from_pem(self.encryption_key.certificate_data.encode()) | ||||
|         jwe = JWE( | ||||
|             raw, | ||||
|             json_encode( | ||||
|                 { | ||||
|                     "alg": "RSA-OAEP-256", | ||||
|                     "enc": "A256CBC-HS512", | ||||
|                     "typ": "JWE", | ||||
|                     "kid": self.encryption_key.kid, | ||||
|                 } | ||||
|             ), | ||||
|         ) | ||||
|         jwe.add_recipient(key) | ||||
|         return jwe.serialize(compact=True) | ||||
|         return encode(payload, key, algorithm=alg, headers=headers) | ||||
|  | ||||
|     def webfinger(self, resource: str, request: HttpRequest): | ||||
|         return { | ||||
| @ -359,9 +320,7 @@ class BaseGrantModel(models.Model): | ||||
|     revoked = models.BooleanField(default=False) | ||||
|     _scope = models.TextField(default="", verbose_name=_("Scopes")) | ||||
|     auth_time = models.DateTimeField(verbose_name="Authentication time") | ||||
|     session = models.ForeignKey( | ||||
|         AuthenticatedSession, null=True, on_delete=models.SET_DEFAULT, default=None | ||||
|     ) | ||||
|     session_id = models.CharField(default="", blank=True) | ||||
|  | ||||
|     class Meta: | ||||
|         abstract = True | ||||
| @ -499,9 +458,6 @@ class DeviceToken(ExpiringModel): | ||||
|     device_code = models.TextField(default=generate_key) | ||||
|     user_code = models.TextField(default=generate_code_fixed_length) | ||||
|     _scope = models.TextField(default="", verbose_name=_("Scopes")) | ||||
|     session = models.ForeignKey( | ||||
|         AuthenticatedSession, null=True, on_delete=models.SET_DEFAULT, default=None | ||||
|     ) | ||||
|  | ||||
|     @property | ||||
|     def scope(self) -> list[str]: | ||||
|  | ||||
| @ -1,3 +1,5 @@ | ||||
| from hashlib import sha256 | ||||
|  | ||||
| from django.contrib.auth.signals import user_logged_out | ||||
| from django.dispatch import receiver | ||||
| from django.http import HttpRequest | ||||
| @ -11,4 +13,5 @@ def user_logged_out_oauth_access_token(sender, request: HttpRequest, user: User, | ||||
|     """Revoke access tokens upon user logout""" | ||||
|     if not request.session or not request.session.session_key: | ||||
|         return | ||||
|     AccessToken.objects.filter(user=user, session__session_key=request.session.session_key).delete() | ||||
|     hashed_session_key = sha256(request.session.session_key.encode("ascii")).hexdigest() | ||||
|     AccessToken.objects.filter(user=user, session_id=hashed_session_key).delete() | ||||
|  | ||||
| @ -412,73 +412,6 @@ class TestAuthorize(OAuthTestCase): | ||||
|                 delta=5, | ||||
|             ) | ||||
|  | ||||
|     @apply_blueprint("system/providers-oauth2.yaml") | ||||
|     def test_full_implicit_enc(self): | ||||
|         """Test full authorization with encryption""" | ||||
|         flow = create_test_flow() | ||||
|         provider: OAuth2Provider = OAuth2Provider.objects.create( | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=flow, | ||||
|             redirect_uris="http://localhost", | ||||
|             signing_key=self.keypair, | ||||
|             encryption_key=self.keypair, | ||||
|         ) | ||||
|         provider.property_mappings.set( | ||||
|             ScopeMapping.objects.filter( | ||||
|                 managed__in=[ | ||||
|                     "goauthentik.io/providers/oauth2/scope-openid", | ||||
|                     "goauthentik.io/providers/oauth2/scope-email", | ||||
|                     "goauthentik.io/providers/oauth2/scope-profile", | ||||
|                 ] | ||||
|             ) | ||||
|         ) | ||||
|         provider.property_mappings.add( | ||||
|             ScopeMapping.objects.create( | ||||
|                 name=generate_id(), scope_name="test", expression="""return {"sub": "foo"}""" | ||||
|             ) | ||||
|         ) | ||||
|         Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider) | ||||
|         state = generate_id() | ||||
|         user = create_test_admin_user() | ||||
|         self.client.force_login(user) | ||||
|         with patch( | ||||
|             "authentik.providers.oauth2.id_token.get_login_event", | ||||
|             MagicMock( | ||||
|                 return_value=Event( | ||||
|                     action=EventAction.LOGIN, | ||||
|                     context={PLAN_CONTEXT_METHOD: "password"}, | ||||
|                     created=now(), | ||||
|                 ) | ||||
|             ), | ||||
|         ): | ||||
|             # Step 1, initiate params and get redirect to flow | ||||
|             self.client.get( | ||||
|                 reverse("authentik_providers_oauth2:authorize"), | ||||
|                 data={ | ||||
|                     "response_type": "id_token", | ||||
|                     "client_id": "test", | ||||
|                     "state": state, | ||||
|                     "scope": "openid test", | ||||
|                     "redirect_uri": "http://localhost", | ||||
|                     "nonce": generate_id(), | ||||
|                 }, | ||||
|             ) | ||||
|             response = self.client.get( | ||||
|                 reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), | ||||
|             ) | ||||
|             self.assertEqual(response.status_code, 200) | ||||
|             token: AccessToken = AccessToken.objects.filter(user=user).first() | ||||
|             expires = timedelta_from_string(provider.access_token_validity).total_seconds() | ||||
|             jwt = self.validate_jwe(token, provider) | ||||
|             self.assertEqual(jwt["amr"], ["pwd"]) | ||||
|             self.assertEqual(jwt["sub"], "foo") | ||||
|             self.assertAlmostEqual( | ||||
|                 jwt["exp"] - now().timestamp(), | ||||
|                 expires, | ||||
|                 delta=5, | ||||
|             ) | ||||
|  | ||||
|     def test_full_fragment_code(self): | ||||
|         """Test full authorization""" | ||||
|         flow = create_test_flow() | ||||
|  | ||||
| @ -93,24 +93,6 @@ class TestJWKS(OAuthTestCase): | ||||
|         self.assertEqual(len(body["keys"]), 1) | ||||
|         PyJWKSet.from_dict(body) | ||||
|  | ||||
|     def test_enc(self): | ||||
|         """Test with JWE""" | ||||
|         provider = OAuth2Provider.objects.create( | ||||
|             name="test", | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://local.invalid", | ||||
|             signing_key=create_test_cert(PrivateKeyAlg.ECDSA), | ||||
|             encryption_key=create_test_cert(PrivateKeyAlg.ECDSA), | ||||
|         ) | ||||
|         app = Application.objects.create(name="test", slug="test", provider=provider) | ||||
|         response = self.client.get( | ||||
|             reverse("authentik_providers_oauth2:jwks", kwargs={"application_slug": app.slug}) | ||||
|         ) | ||||
|         body = json.loads(response.content.decode()) | ||||
|         self.assertEqual(len(body["keys"]), 2) | ||||
|         PyJWKSet.from_dict(body) | ||||
|  | ||||
|     def test_ecdsa_coords_mismatched(self): | ||||
|         """Test JWKS request with ES256""" | ||||
|         cert = CertificateKeyPair.objects.create( | ||||
|  | ||||
| @ -152,36 +152,6 @@ class TestToken(OAuthTestCase): | ||||
|         ) | ||||
|         self.validate_jwt(access, provider) | ||||
|  | ||||
|     def test_auth_code_enc(self): | ||||
|         """test request param""" | ||||
|         provider = OAuth2Provider.objects.create( | ||||
|             name=generate_id(), | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://local.invalid", | ||||
|             signing_key=self.keypair, | ||||
|             encryption_key=self.keypair, | ||||
|         ) | ||||
|         # Needs to be assigned to an application for iss to be set | ||||
|         self.app.provider = provider | ||||
|         self.app.save() | ||||
|         header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() | ||||
|         user = create_test_admin_user() | ||||
|         code = AuthorizationCode.objects.create( | ||||
|             code="foobar", provider=provider, user=user, auth_time=timezone.now() | ||||
|         ) | ||||
|         response = self.client.post( | ||||
|             reverse("authentik_providers_oauth2:token"), | ||||
|             data={ | ||||
|                 "grant_type": GRANT_TYPE_AUTHORIZATION_CODE, | ||||
|                 "code": code.code, | ||||
|                 "redirect_uri": "http://local.invalid", | ||||
|             }, | ||||
|             HTTP_AUTHORIZATION=f"Basic {header}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         access: AccessToken = AccessToken.objects.filter(user=user, provider=provider).first() | ||||
|         self.validate_jwe(access, provider) | ||||
|  | ||||
|     @apply_blueprint("system/providers-oauth2.yaml") | ||||
|     def test_refresh_token_view(self): | ||||
|         """test request param""" | ||||
|  | ||||
| @ -34,7 +34,7 @@ class TestTokenClientCredentialsJWTSource(OAuthTestCase): | ||||
|         self.factory = RequestFactory() | ||||
|         self.cert = create_test_cert() | ||||
|  | ||||
|         jwk = JWKSView().get_jwk_for_key(self.cert, "sig") | ||||
|         jwk = JWKSView().get_jwk_for_key(self.cert) | ||||
|         self.source: OAuthSource = OAuthSource.objects.create( | ||||
|             name=generate_id(), | ||||
|             slug=generate_id(), | ||||
|  | ||||
| @ -3,8 +3,6 @@ | ||||
| from typing import Any | ||||
|  | ||||
| from django.test import TestCase | ||||
| from jwcrypto.jwe import JWE | ||||
| from jwcrypto.jwk import JWK | ||||
| from jwt import decode | ||||
|  | ||||
| from authentik.core.tests.utils import create_test_cert | ||||
| @ -34,15 +32,6 @@ class OAuthTestCase(TestCase): | ||||
|         if key in container: | ||||
|             self.assertIsNotNone(container[key]) | ||||
|  | ||||
|     def validate_jwe(self, token: AccessToken, provider: OAuth2Provider) -> dict[str, Any]: | ||||
|         """Validate JWEs""" | ||||
|         private_key = JWK.from_pem(provider.encryption_key.key_data.encode()) | ||||
|  | ||||
|         jwetoken = JWE() | ||||
|         jwetoken.deserialize(token.token, key=private_key) | ||||
|         token.token = jwetoken.payload.decode() | ||||
|         return self.validate_jwt(token, provider) | ||||
|  | ||||
|     def validate_jwt(self, token: AccessToken, provider: OAuth2Provider) -> dict[str, Any]: | ||||
|         """Validate that all required fields are set""" | ||||
|         key, alg = provider.jwt_key | ||||
|  | ||||
| @ -12,7 +12,6 @@ from authentik.providers.oauth2.api.tokens import ( | ||||
| ) | ||||
| from authentik.providers.oauth2.views.authorize import AuthorizationFlowInitView | ||||
| from authentik.providers.oauth2.views.device_backchannel import DeviceView | ||||
| from authentik.providers.oauth2.views.end_session import EndSessionView | ||||
| from authentik.providers.oauth2.views.introspection import TokenIntrospectionView | ||||
| from authentik.providers.oauth2.views.jwks import JWKSView | ||||
| from authentik.providers.oauth2.views.provider import ProviderInfoView | ||||
| @ -45,7 +44,7 @@ urlpatterns = [ | ||||
|     ), | ||||
|     path( | ||||
|         "<slug:application_slug>/end-session/", | ||||
|         EndSessionView.as_view(), | ||||
|         RedirectView.as_view(pattern_name="authentik_core:if-session-end", query_string=True), | ||||
|         name="end-session", | ||||
|     ), | ||||
|     path("<slug:application_slug>/jwks/", JWKSView.as_view(), name="jwks"), | ||||
|  | ||||
| @ -2,6 +2,7 @@ | ||||
|  | ||||
| from dataclasses import InitVar, dataclass, field | ||||
| from datetime import timedelta | ||||
| from hashlib import sha256 | ||||
| from json import dumps | ||||
| from re import error as RegexError | ||||
| from re import fullmatch | ||||
| @ -15,7 +16,7 @@ from django.utils import timezone | ||||
| from django.utils.translation import gettext as _ | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.core.models import Application, AuthenticatedSession | ||||
| from authentik.core.models import Application | ||||
| from authentik.events.models import Event, EventAction | ||||
| from authentik.events.signals import get_login_event | ||||
| from authentik.flows.challenge import ( | ||||
| @ -317,9 +318,7 @@ class OAuthAuthorizationParams: | ||||
|             expires=now + timedelta_from_string(self.provider.access_code_validity), | ||||
|             scope=self.scope, | ||||
|             nonce=self.nonce, | ||||
|             session=AuthenticatedSession.objects.filter( | ||||
|                 session_key=request.session.session_key | ||||
|             ).first(), | ||||
|             session_id=sha256(request.session.session_key.encode("ascii")).hexdigest(), | ||||
|         ) | ||||
|  | ||||
|         if self.code_challenge and self.code_challenge_method: | ||||
| @ -611,9 +610,7 @@ class OAuthFulfillmentStage(StageView): | ||||
|             expires=access_token_expiry, | ||||
|             provider=self.provider, | ||||
|             auth_time=auth_event.created if auth_event else now, | ||||
|             session=AuthenticatedSession.objects.filter( | ||||
|                 session_key=self.request.session.session_key | ||||
|             ).first(), | ||||
|             session_id=sha256(self.request.session.session_key.encode("ascii")).hexdigest(), | ||||
|         ) | ||||
|  | ||||
|         id_token = IDToken.new(self.provider, token, self.request) | ||||
|  | ||||
| @ -1,45 +0,0 @@ | ||||
| """oauth2 provider end_session Views""" | ||||
|  | ||||
| from django.http import Http404, HttpRequest, HttpResponse | ||||
| from django.shortcuts import get_object_or_404 | ||||
|  | ||||
| from authentik.core.models import Application | ||||
| from authentik.flows.models import Flow, in_memory_stage | ||||
| from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, FlowPlanner | ||||
| from authentik.flows.stage import SessionEndStage | ||||
| from authentik.flows.views.executor import SESSION_KEY_PLAN | ||||
| from authentik.lib.utils.urls import redirect_with_qs | ||||
| from authentik.policies.views import PolicyAccessView | ||||
|  | ||||
|  | ||||
| class EndSessionView(PolicyAccessView): | ||||
|     """Redirect to application's provider's invalidation flow""" | ||||
|  | ||||
|     flow: Flow | ||||
|  | ||||
|     def resolve_provider_application(self): | ||||
|         self.application = get_object_or_404(Application, slug=self.kwargs["application_slug"]) | ||||
|         self.provider = self.application.get_provider() | ||||
|         if not self.provider: | ||||
|             raise Http404 | ||||
|         self.flow = self.provider.invalidation_flow or self.request.brand.flow_invalidation | ||||
|         if not self.flow: | ||||
|             raise Http404 | ||||
|  | ||||
|     def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: | ||||
|         """Dispatch the flow planner for the invalidation flow""" | ||||
|         planner = FlowPlanner(self.flow) | ||||
|         planner.allow_empty_flows = True | ||||
|         plan = planner.plan( | ||||
|             request, | ||||
|             { | ||||
|                 PLAN_CONTEXT_APPLICATION: self.application, | ||||
|             }, | ||||
|         ) | ||||
|         plan.insert_stage(in_memory_stage(SessionEndStage)) | ||||
|         request.session[SESSION_KEY_PLAN] = plan | ||||
|         return redirect_with_qs( | ||||
|             "authentik_core:if-flow", | ||||
|             self.request.GET, | ||||
|             flow_slug=self.flow.slug, | ||||
|         ) | ||||
| @ -64,42 +64,36 @@ def to_base64url_uint(val: int, min_length: int = 0) -> bytes: | ||||
| class JWKSView(View): | ||||
|     """Show RSA Key data for Provider""" | ||||
|  | ||||
|     def get_jwk_for_key(self, key: CertificateKeyPair, use: str) -> dict | None: | ||||
|     def get_jwk_for_key(self, key: CertificateKeyPair) -> dict | None: | ||||
|         """Convert a certificate-key pair into JWK""" | ||||
|         private_key = key.private_key | ||||
|         key_data = None | ||||
|         if not private_key: | ||||
|             return key_data | ||||
|  | ||||
|         key_data = {} | ||||
|  | ||||
|         if use == "sig": | ||||
|             if isinstance(private_key, RSAPrivateKey): | ||||
|                 key_data["alg"] = JWTAlgorithms.RS256 | ||||
|             elif isinstance(private_key, EllipticCurvePrivateKey): | ||||
|                 key_data["alg"] = JWTAlgorithms.ES256 | ||||
|         elif use == "enc": | ||||
|             key_data["alg"] = "RSA-OAEP-256" | ||||
|             key_data["enc"] = "A256CBC-HS512" | ||||
|  | ||||
|         if isinstance(private_key, RSAPrivateKey): | ||||
|             public_key: RSAPublicKey = private_key.public_key() | ||||
|             public_numbers = public_key.public_numbers() | ||||
|             key_data["kid"] = key.kid | ||||
|             key_data["kty"] = "RSA" | ||||
|             key_data["use"] = use | ||||
|             key_data["n"] = to_base64url_uint(public_numbers.n).decode() | ||||
|             key_data["e"] = to_base64url_uint(public_numbers.e).decode() | ||||
|             key_data = { | ||||
|                 "kid": key.kid, | ||||
|                 "kty": "RSA", | ||||
|                 "alg": JWTAlgorithms.RS256, | ||||
|                 "use": "sig", | ||||
|                 "n": to_base64url_uint(public_numbers.n).decode(), | ||||
|                 "e": to_base64url_uint(public_numbers.e).decode(), | ||||
|             } | ||||
|         elif isinstance(private_key, EllipticCurvePrivateKey): | ||||
|             public_key: EllipticCurvePublicKey = private_key.public_key() | ||||
|             public_numbers = public_key.public_numbers() | ||||
|             curve_type = type(public_key.curve) | ||||
|             key_data["kid"] = key.kid | ||||
|             key_data["kty"] = "EC" | ||||
|             key_data["use"] = use | ||||
|             key_data["x"] = to_base64url_uint(public_numbers.x, min_length_map[curve_type]).decode() | ||||
|             key_data["y"] = to_base64url_uint(public_numbers.y, min_length_map[curve_type]).decode() | ||||
|             key_data["crv"] = ec_crv_map.get(curve_type, public_key.curve.name) | ||||
|             key_data = { | ||||
|                 "kid": key.kid, | ||||
|                 "kty": "EC", | ||||
|                 "alg": JWTAlgorithms.ES256, | ||||
|                 "use": "sig", | ||||
|                 "x": to_base64url_uint(public_numbers.x, min_length_map[curve_type]).decode(), | ||||
|                 "y": to_base64url_uint(public_numbers.y, min_length_map[curve_type]).decode(), | ||||
|                 "crv": ec_crv_map.get(curve_type, public_key.curve.name), | ||||
|             } | ||||
|         else: | ||||
|             return key_data | ||||
|         key_data["x5c"] = [b64encode(key.certificate.public_bytes(Encoding.DER)).decode("utf-8")] | ||||
| @ -119,19 +113,14 @@ class JWKSView(View): | ||||
|         """Show JWK Key data for Provider""" | ||||
|         application = get_object_or_404(Application, slug=application_slug) | ||||
|         provider: OAuth2Provider = get_object_or_404(OAuth2Provider, pk=application.provider_id) | ||||
|         signing_key: CertificateKeyPair = provider.signing_key | ||||
|  | ||||
|         response_data = {} | ||||
|  | ||||
|         if signing_key := provider.signing_key: | ||||
|             jwk = self.get_jwk_for_key(signing_key, "sig") | ||||
|         if signing_key: | ||||
|             jwk = self.get_jwk_for_key(signing_key) | ||||
|             if jwk: | ||||
|                 response_data.setdefault("keys", []) | ||||
|                 response_data["keys"].append(jwk) | ||||
|         if encryption_key := provider.encryption_key: | ||||
|             jwk = self.get_jwk_for_key(encryption_key, "enc") | ||||
|             if jwk: | ||||
|                 response_data.setdefault("keys", []) | ||||
|                 response_data["keys"].append(jwk) | ||||
|                 response_data["keys"] = [jwk] | ||||
|  | ||||
|         response = JsonResponse(response_data) | ||||
|         response["Access-Control-Allow-Origin"] = "*" | ||||
|  | ||||
| @ -46,7 +46,7 @@ class ProviderInfoView(View): | ||||
|         if SCOPE_OPENID not in scopes: | ||||
|             scopes.append(SCOPE_OPENID) | ||||
|         _, supported_alg = provider.jwt_key | ||||
|         config = { | ||||
|         return { | ||||
|             "issuer": provider.get_issuer(self.request), | ||||
|             "authorization_endpoint": self.request.build_absolute_uri( | ||||
|                 reverse("authentik_providers_oauth2:authorize") | ||||
| @ -114,10 +114,6 @@ class ProviderInfoView(View): | ||||
|             "claims_parameter_supported": False, | ||||
|             "code_challenge_methods_supported": [PKCE_METHOD_PLAIN, PKCE_METHOD_S256], | ||||
|         } | ||||
|         if provider.encryption_key: | ||||
|             config["id_token_encryption_alg_values_supported"] = ["RSA-OAEP-256"] | ||||
|             config["id_token_encryption_enc_values_supported"] = ["A256CBC-HS512"] | ||||
|         return config | ||||
|  | ||||
|     def get_claims(self, provider: OAuth2Provider) -> list[str]: | ||||
|         """Get a list of supported claims based on configured scope mappings""" | ||||
|  | ||||
| @ -439,14 +439,15 @@ class TokenParams: | ||||
|                 # (22 chars being the length of the "template") | ||||
|                 username=f"ak-{self.provider.name[:150-22]}-client_credentials", | ||||
|                 defaults={ | ||||
|                     "attributes": { | ||||
|                         USER_ATTRIBUTE_GENERATED: True, | ||||
|                     }, | ||||
|                     "last_login": timezone.now(), | ||||
|                     "name": f"Autogenerated user from application {app.name} (client credentials)", | ||||
|                     "path": f"{USER_PATH_SYSTEM_PREFIX}/apps/{app.slug}", | ||||
|                     "type": UserTypes.SERVICE_ACCOUNT, | ||||
|                 }, | ||||
|             ) | ||||
|             self.user.attributes[USER_ATTRIBUTE_GENERATED] = True | ||||
|             self.user.save() | ||||
|         self.__check_policy_access(app, request) | ||||
|  | ||||
|         Event.new( | ||||
| @ -470,6 +471,9 @@ class TokenParams: | ||||
|             self.user, created = User.objects.update_or_create( | ||||
|                 username=f"{self.provider.name}-{token.get('sub')}", | ||||
|                 defaults={ | ||||
|                     "attributes": { | ||||
|                         USER_ATTRIBUTE_GENERATED: True, | ||||
|                     }, | ||||
|                     "last_login": timezone.now(), | ||||
|                     "name": ( | ||||
|                         f"Autogenerated user from application {app.name} (client credentials JWT)" | ||||
| @ -478,8 +482,6 @@ class TokenParams: | ||||
|                     "type": UserTypes.SERVICE_ACCOUNT, | ||||
|                 }, | ||||
|             ) | ||||
|             self.user.attributes[USER_ATTRIBUTE_GENERATED] = True | ||||
|             self.user.save() | ||||
|             exp = token.get("exp") | ||||
|             if created and exp: | ||||
|                 self.user.attributes[USER_ATTRIBUTE_EXPIRES] = exp | ||||
| @ -550,7 +552,7 @@ class TokenView(View): | ||||
|             # Keep same scopes as previous token | ||||
|             scope=self.params.authorization_code.scope, | ||||
|             auth_time=self.params.authorization_code.auth_time, | ||||
|             session=self.params.authorization_code.session, | ||||
|             session_id=self.params.authorization_code.session_id, | ||||
|         ) | ||||
|         access_id_token = IDToken.new( | ||||
|             self.provider, | ||||
| @ -578,7 +580,7 @@ class TokenView(View): | ||||
|                 expires=refresh_token_expiry, | ||||
|                 provider=self.provider, | ||||
|                 auth_time=self.params.authorization_code.auth_time, | ||||
|                 session=self.params.authorization_code.session, | ||||
|                 session_id=self.params.authorization_code.session_id, | ||||
|             ) | ||||
|             id_token = IDToken.new( | ||||
|                 self.provider, | ||||
| @ -611,7 +613,7 @@ class TokenView(View): | ||||
|             # Keep same scopes as previous token | ||||
|             scope=self.params.refresh_token.scope, | ||||
|             auth_time=self.params.refresh_token.auth_time, | ||||
|             session=self.params.refresh_token.session, | ||||
|             session_id=self.params.refresh_token.session_id, | ||||
|         ) | ||||
|         access_token.id_token = IDToken.new( | ||||
|             self.provider, | ||||
| @ -627,7 +629,7 @@ class TokenView(View): | ||||
|             expires=refresh_token_expiry, | ||||
|             provider=self.provider, | ||||
|             auth_time=self.params.refresh_token.auth_time, | ||||
|             session=self.params.refresh_token.session, | ||||
|             session_id=self.params.refresh_token.session_id, | ||||
|         ) | ||||
|         id_token = IDToken.new( | ||||
|             self.provider, | ||||
| @ -685,14 +687,13 @@ class TokenView(View): | ||||
|             raise DeviceCodeError("authorization_pending") | ||||
|         now = timezone.now() | ||||
|         access_token_expiry = now + timedelta_from_string(self.provider.access_token_validity) | ||||
|         auth_event = get_login_event(self.params.device_code.session) | ||||
|         auth_event = get_login_event(self.request) | ||||
|         access_token = AccessToken( | ||||
|             provider=self.provider, | ||||
|             user=self.params.device_code.user, | ||||
|             expires=access_token_expiry, | ||||
|             scope=self.params.device_code.scope, | ||||
|             auth_time=auth_event.created if auth_event else now, | ||||
|             session=self.params.device_code.session, | ||||
|         ) | ||||
|         access_token.id_token = IDToken.new( | ||||
|             self.provider, | ||||
|  | ||||
| @ -1,12 +1,13 @@ | ||||
| """proxy provider tasks""" | ||||
|  | ||||
| from hashlib import sha256 | ||||
|  | ||||
| from asgiref.sync import async_to_sync | ||||
| from channels.layers import get_channel_layer | ||||
| from django.db import DatabaseError, InternalError, ProgrammingError | ||||
|  | ||||
| from authentik.outposts.consumer import OUTPOST_GROUP | ||||
| from authentik.outposts.models import Outpost, OutpostType | ||||
| from authentik.providers.oauth2.id_token import hash_session_key | ||||
| from authentik.providers.proxy.models import ProxyProvider | ||||
| from authentik.root.celery import CELERY_APP | ||||
|  | ||||
| @ -25,7 +26,7 @@ def proxy_set_defaults(): | ||||
| def proxy_on_logout(session_id: str): | ||||
|     """Update outpost instances connected to a single outpost""" | ||||
|     layer = get_channel_layer() | ||||
|     hashed_session_id = hash_session_key(session_id) | ||||
|     hashed_session_id = sha256(session_id.encode("ascii")).hexdigest() | ||||
|     for outpost in Outpost.objects.filter(type=OutpostType.PROXY): | ||||
|         group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)} | ||||
|         async_to_sync(layer.group_send)( | ||||
|  | ||||
| @ -24,7 +24,6 @@ class ProxyProviderTests(APITestCase): | ||||
|                 "name": generate_id(), | ||||
|                 "mode": ProxyMode.PROXY, | ||||
|                 "authorization_flow": create_test_flow().pk.hex, | ||||
|                 "invalidation_flow": create_test_flow().pk.hex, | ||||
|                 "external_host": "http://localhost", | ||||
|                 "internal_host": "http://localhost", | ||||
|                 "basic_auth_enabled": True, | ||||
| @ -42,7 +41,6 @@ class ProxyProviderTests(APITestCase): | ||||
|                 "name": generate_id(), | ||||
|                 "mode": ProxyMode.PROXY, | ||||
|                 "authorization_flow": create_test_flow().pk.hex, | ||||
|                 "invalidation_flow": create_test_flow().pk.hex, | ||||
|                 "external_host": "http://localhost", | ||||
|                 "internal_host": "http://localhost", | ||||
|                 "basic_auth_enabled": True, | ||||
| @ -66,7 +64,6 @@ class ProxyProviderTests(APITestCase): | ||||
|                 "name": generate_id(), | ||||
|                 "mode": ProxyMode.PROXY, | ||||
|                 "authorization_flow": create_test_flow().pk.hex, | ||||
|                 "invalidation_flow": create_test_flow().pk.hex, | ||||
|                 "external_host": "http://localhost", | ||||
|             }, | ||||
|         ) | ||||
| @ -85,7 +82,6 @@ class ProxyProviderTests(APITestCase): | ||||
|                 "name": name, | ||||
|                 "mode": ProxyMode.PROXY, | ||||
|                 "authorization_flow": create_test_flow().pk.hex, | ||||
|                 "invalidation_flow": create_test_flow().pk.hex, | ||||
|                 "external_host": "http://localhost", | ||||
|                 "internal_host": "http://localhost", | ||||
|             }, | ||||
| @ -103,7 +99,6 @@ class ProxyProviderTests(APITestCase): | ||||
|                 "name": name, | ||||
|                 "mode": ProxyMode.PROXY, | ||||
|                 "authorization_flow": create_test_flow().pk.hex, | ||||
|                 "invalidation_flow": create_test_flow().pk.hex, | ||||
|                 "external_host": "http://localhost", | ||||
|                 "internal_host": "http://localhost", | ||||
|             }, | ||||
| @ -119,7 +114,6 @@ class ProxyProviderTests(APITestCase): | ||||
|                 "name": name, | ||||
|                 "mode": ProxyMode.PROXY, | ||||
|                 "authorization_flow": create_test_flow().pk.hex, | ||||
|                 "invalidation_flow": create_test_flow().pk.hex, | ||||
|                 "external_host": "http://localhost", | ||||
|                 "internal_host": "http://localhost", | ||||
|             }, | ||||
|  | ||||
| @ -188,9 +188,6 @@ class SAMLProviderImportSerializer(PassiveSerializer): | ||||
|     authorization_flow = PrimaryKeyRelatedField( | ||||
|         queryset=Flow.objects.filter(designation=FlowDesignation.AUTHORIZATION), | ||||
|     ) | ||||
|     invalidation_flow = PrimaryKeyRelatedField( | ||||
|         queryset=Flow.objects.filter(designation=FlowDesignation.INVALIDATION), | ||||
|     ) | ||||
|     file = FileField() | ||||
|  | ||||
|  | ||||
| @ -280,9 +277,7 @@ class SAMLProviderViewSet(UsedByMixin, ModelViewSet): | ||||
|         try: | ||||
|             metadata = ServiceProviderMetadataParser().parse(file.read().decode()) | ||||
|             metadata.to_provider( | ||||
|                 data.validated_data["name"], | ||||
|                 data.validated_data["authorization_flow"], | ||||
|                 data.validated_data["invalidation_flow"], | ||||
|                 data.validated_data["name"], data.validated_data["authorization_flow"] | ||||
|             ) | ||||
|         except ValueError as exc:  # pragma: no cover | ||||
|             LOGGER.warning(str(exc)) | ||||
|  | ||||
| @ -50,7 +50,6 @@ class AssertionProcessor: | ||||
|  | ||||
|     _issue_instant: str | ||||
|     _assertion_id: str | ||||
|     _response_id: str | ||||
|  | ||||
|     _valid_not_before: str | ||||
|     _session_not_on_or_after: str | ||||
| @ -63,7 +62,6 @@ class AssertionProcessor: | ||||
|  | ||||
|         self._issue_instant = get_time_string() | ||||
|         self._assertion_id = get_random_id() | ||||
|         self._response_id = get_random_id() | ||||
|  | ||||
|         self._valid_not_before = get_time_string( | ||||
|             timedelta_from_string(self.provider.assertion_valid_not_before) | ||||
| @ -132,9 +130,7 @@ class AssertionProcessor: | ||||
|         """Generate AuthnStatement with AuthnContext and ContextClassRef Elements.""" | ||||
|         auth_n_statement = Element(f"{{{NS_SAML_ASSERTION}}}AuthnStatement") | ||||
|         auth_n_statement.attrib["AuthnInstant"] = self._valid_not_before | ||||
|         auth_n_statement.attrib["SessionIndex"] = sha256( | ||||
|             self.http_request.session.session_key.encode("ascii") | ||||
|         ).hexdigest() | ||||
|         auth_n_statement.attrib["SessionIndex"] = self._assertion_id | ||||
|         auth_n_statement.attrib["SessionNotOnOrAfter"] = self._session_not_on_or_after | ||||
|  | ||||
|         auth_n_context = SubElement(auth_n_statement, f"{{{NS_SAML_ASSERTION}}}AuthnContext") | ||||
| @ -289,7 +285,7 @@ class AssertionProcessor: | ||||
|         response.attrib["Version"] = "2.0" | ||||
|         response.attrib["IssueInstant"] = self._issue_instant | ||||
|         response.attrib["Destination"] = self.provider.acs_url | ||||
|         response.attrib["ID"] = self._response_id | ||||
|         response.attrib["ID"] = get_random_id() | ||||
|         if self.auth_n_request.id: | ||||
|             response.attrib["InResponseTo"] = self.auth_n_request.id | ||||
|  | ||||
| @ -312,7 +308,7 @@ class AssertionProcessor: | ||||
|         ref = xmlsec.template.add_reference( | ||||
|             signature_node, | ||||
|             digest_algorithm_transform, | ||||
|             uri="#" + element.attrib["ID"], | ||||
|             uri="#" + self._assertion_id, | ||||
|         ) | ||||
|         xmlsec.template.add_transform(ref, xmlsec.constants.TransformEnveloped) | ||||
|         xmlsec.template.add_transform(ref, xmlsec.constants.TransformExclC14N) | ||||
|  | ||||
| @ -49,13 +49,12 @@ class ServiceProviderMetadata: | ||||
|  | ||||
|     signing_keypair: CertificateKeyPair | None = None | ||||
|  | ||||
|     def to_provider( | ||||
|         self, name: str, authorization_flow: Flow, invalidation_flow: Flow | ||||
|     ) -> SAMLProvider: | ||||
|     def to_provider(self, name: str, authorization_flow: Flow) -> SAMLProvider: | ||||
|         """Create a SAMLProvider instance from the details. `name` is required, | ||||
|         as depending on the metadata CertificateKeypairs might have to be created.""" | ||||
|         provider = SAMLProvider.objects.create( | ||||
|             name=name, authorization_flow=authorization_flow, invalidation_flow=invalidation_flow | ||||
|             name=name, | ||||
|             authorization_flow=authorization_flow, | ||||
|         ) | ||||
|         provider.issuer = self.entity_id | ||||
|         provider.sp_binding = self.acs_binding | ||||
|  | ||||
| @ -47,12 +47,11 @@ class TestSAMLProviderAPI(APITestCase): | ||||
|             data={ | ||||
|                 "name": generate_id(), | ||||
|                 "authorization_flow": create_test_flow().pk, | ||||
|                 "invalidation_flow": create_test_flow().pk, | ||||
|                 "acs_url": "http://localhost", | ||||
|                 "signing_kp": cert.pk, | ||||
|             }, | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 400) | ||||
|         self.assertEqual(400, response.status_code) | ||||
|         self.assertJSONEqual( | ||||
|             response.content, | ||||
|             { | ||||
| @ -69,13 +68,12 @@ class TestSAMLProviderAPI(APITestCase): | ||||
|             data={ | ||||
|                 "name": generate_id(), | ||||
|                 "authorization_flow": create_test_flow().pk, | ||||
|                 "invalidation_flow": create_test_flow().pk, | ||||
|                 "acs_url": "http://localhost", | ||||
|                 "signing_kp": cert.pk, | ||||
|                 "sign_assertion": True, | ||||
|             }, | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 201) | ||||
|         self.assertEqual(201, response.status_code) | ||||
|  | ||||
|     def test_metadata(self): | ||||
|         """Test metadata export (normal)""" | ||||
| @ -133,7 +131,6 @@ class TestSAMLProviderAPI(APITestCase): | ||||
|                     "file": metadata, | ||||
|                     "name": generate_id(), | ||||
|                     "authorization_flow": create_test_flow(FlowDesignation.AUTHORIZATION).pk, | ||||
|                     "invalidation_flow": create_test_flow(FlowDesignation.INVALIDATION).pk, | ||||
|                 }, | ||||
|                 format="multipart", | ||||
|             ) | ||||
|  | ||||
| @ -180,10 +180,6 @@ class TestAuthNRequest(TestCase): | ||||
|         # Now create a response and convert it to string (provider) | ||||
|         response_proc = AssertionProcessor(self.provider, http_request, parsed_request) | ||||
|         response = response_proc.build_response() | ||||
|         # Ensure both response and assertion ID are in the response twice (once as ID attribute, | ||||
|         # once as ds:Reference URI) | ||||
|         self.assertEqual(response.count(response_proc._assertion_id), 2) | ||||
|         self.assertEqual(response.count(response_proc._response_id), 2) | ||||
|  | ||||
|         # Now parse the response (source) | ||||
|         http_request.POST = QueryDict(mutable=True) | ||||
|  | ||||
| @ -82,7 +82,7 @@ class TestServiceProviderMetadataParser(TestCase): | ||||
|     def test_simple(self): | ||||
|         """Test simple metadata without Signing""" | ||||
|         metadata = ServiceProviderMetadataParser().parse(load_fixture("fixtures/simple.xml")) | ||||
|         provider = metadata.to_provider("test", self.flow, self.flow) | ||||
|         provider = metadata.to_provider("test", self.flow) | ||||
|         self.assertEqual(provider.acs_url, "http://localhost:8080/saml/acs") | ||||
|         self.assertEqual(provider.issuer, "http://localhost:8080/saml/metadata") | ||||
|         self.assertEqual(provider.sp_binding, SAMLBindings.POST) | ||||
| @ -95,7 +95,7 @@ class TestServiceProviderMetadataParser(TestCase): | ||||
|         """Test Metadata with signing cert""" | ||||
|         create_test_cert() | ||||
|         metadata = ServiceProviderMetadataParser().parse(load_fixture("fixtures/cert.xml")) | ||||
|         provider = metadata.to_provider("test", self.flow, self.flow) | ||||
|         provider = metadata.to_provider("test", self.flow) | ||||
|         self.assertEqual(provider.acs_url, "http://localhost:8080/apps/user_saml/saml/acs") | ||||
|         self.assertEqual(provider.issuer, "http://localhost:8080/apps/user_saml/saml/metadata") | ||||
|         self.assertEqual(provider.sp_binding, SAMLBindings.POST) | ||||
|  | ||||
| @ -1,8 +1,8 @@ | ||||
| """SLO Views""" | ||||
|  | ||||
| from django.http import Http404, HttpRequest | ||||
| from django.http import HttpRequest | ||||
| from django.http.response import HttpResponse | ||||
| from django.shortcuts import get_object_or_404 | ||||
| from django.shortcuts import get_object_or_404, redirect | ||||
| from django.utils.decorators import method_decorator | ||||
| from django.views.decorators.clickjacking import xframe_options_sameorigin | ||||
| from django.views.decorators.csrf import csrf_exempt | ||||
| @ -10,11 +10,6 @@ from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.core.models import Application | ||||
| from authentik.events.models import Event, EventAction | ||||
| from authentik.flows.models import Flow, in_memory_stage | ||||
| from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, FlowPlanner | ||||
| from authentik.flows.stage import SessionEndStage | ||||
| from authentik.flows.views.executor import SESSION_KEY_PLAN | ||||
| from authentik.lib.utils.urls import redirect_with_qs | ||||
| from authentik.lib.views import bad_request_message | ||||
| from authentik.policies.views import PolicyAccessView | ||||
| from authentik.providers.saml.exceptions import CannotHandleAssertion | ||||
| @ -33,16 +28,11 @@ class SAMLSLOView(PolicyAccessView): | ||||
|     """ "SAML SLO Base View, which plans a flow and injects our final stage. | ||||
|     Calls get/post handler.""" | ||||
|  | ||||
|     flow: Flow | ||||
|  | ||||
|     def resolve_provider_application(self): | ||||
|         self.application = get_object_or_404(Application, slug=self.kwargs["application_slug"]) | ||||
|         self.provider: SAMLProvider = get_object_or_404( | ||||
|             SAMLProvider, pk=self.application.provider_id | ||||
|         ) | ||||
|         self.flow = self.provider.invalidation_flow or self.request.brand.flow_invalidation | ||||
|         if not self.flow: | ||||
|             raise Http404 | ||||
|  | ||||
|     def check_saml_request(self) -> HttpRequest | None: | ||||
|         """Handler to verify the SAML Request. Must be implemented by a subclass""" | ||||
| @ -55,20 +45,9 @@ class SAMLSLOView(PolicyAccessView): | ||||
|         method_response = self.check_saml_request() | ||||
|         if method_response: | ||||
|             return method_response | ||||
|         planner = FlowPlanner(self.flow) | ||||
|         planner.allow_empty_flows = True | ||||
|         plan = planner.plan( | ||||
|             request, | ||||
|             { | ||||
|                 PLAN_CONTEXT_APPLICATION: self.application, | ||||
|             }, | ||||
|         ) | ||||
|         plan.insert_stage(in_memory_stage(SessionEndStage)) | ||||
|         request.session[SESSION_KEY_PLAN] = plan | ||||
|         return redirect_with_qs( | ||||
|             "authentik_core:if-flow", | ||||
|             self.request.GET, | ||||
|             flow_slug=self.flow.slug, | ||||
|         return redirect( | ||||
|             "authentik_core:if-session-end", | ||||
|             application_slug=self.kwargs["application_slug"], | ||||
|         ) | ||||
|  | ||||
|     def post(self, request: HttpRequest, application_slug: str) -> HttpResponse: | ||||
|  | ||||
| @ -26,7 +26,6 @@ class SCIMProviderSerializer(ProviderSerializer): | ||||
|             "verbose_name_plural", | ||||
|             "meta_model_name", | ||||
|             "url", | ||||
|             "verify_certificates", | ||||
|             "token", | ||||
|             "exclude_users_service_account", | ||||
|             "filter_group", | ||||
|  | ||||
| @ -42,7 +42,6 @@ class SCIMClient[TModel: "Model", TConnection: "Model", TSchema: "BaseModel"]( | ||||
|     def __init__(self, provider: SCIMProvider): | ||||
|         super().__init__(provider) | ||||
|         self._session = get_http_session() | ||||
|         self._session.verify = provider.verify_certificates | ||||
|         self.provider = provider | ||||
|         # Remove trailing slashes as we assume the URL doesn't have any | ||||
|         base_url = provider.url | ||||
|  | ||||
| @ -2,10 +2,9 @@ | ||||
|  | ||||
| from itertools import batched | ||||
|  | ||||
| from django.db import transaction | ||||
| from pydantic import ValidationError | ||||
| from pydanticscim.group import GroupMember | ||||
| from pydanticscim.responses import PatchOp | ||||
| from pydanticscim.responses import PatchOp, PatchOperation | ||||
|  | ||||
| from authentik.core.models import Group | ||||
| from authentik.lib.sync.mapper import PropertyMappingManager | ||||
| @ -20,7 +19,7 @@ from authentik.providers.scim.clients.base import SCIMClient | ||||
| from authentik.providers.scim.clients.exceptions import ( | ||||
|     SCIMRequestException, | ||||
| ) | ||||
| from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOperation, PatchRequest | ||||
| from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchRequest | ||||
| from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema | ||||
| from authentik.providers.scim.models import ( | ||||
|     SCIMMapping, | ||||
| @ -105,47 +104,13 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): | ||||
|             provider=self.provider, group=group, scim_id=scim_id | ||||
|         ) | ||||
|         users = list(group.users.order_by("id").values_list("id", flat=True)) | ||||
|         self._patch_add_users(connection, users) | ||||
|         self._patch_add_users(group, users) | ||||
|         return connection | ||||
|  | ||||
|     def update(self, group: Group, connection: SCIMProviderGroup): | ||||
|         """Update existing group""" | ||||
|         scim_group = self.to_schema(group, connection) | ||||
|         scim_group.id = connection.scim_id | ||||
|         try: | ||||
|             if self._config.patch.supported: | ||||
|                 return self._update_patch(group, scim_group, connection) | ||||
|             return self._update_put(group, scim_group, connection) | ||||
|         except NotFoundSyncException: | ||||
|             # Resource missing is handled by self.write, which will re-create the group | ||||
|             raise | ||||
|  | ||||
|     def _update_patch( | ||||
|         self, group: Group, scim_group: SCIMGroupSchema, connection: SCIMProviderGroup | ||||
|     ): | ||||
|         """Update a group via PATCH request""" | ||||
|         # Patch group's attributes instead of replacing it and re-adding users if we can | ||||
|         self._request( | ||||
|             "PATCH", | ||||
|             f"/Groups/{connection.scim_id}", | ||||
|             json=PatchRequest( | ||||
|                 Operations=[ | ||||
|                     PatchOperation( | ||||
|                         op=PatchOp.replace, | ||||
|                         path=None, | ||||
|                         value=scim_group.model_dump(mode="json", exclude_unset=True), | ||||
|                     ) | ||||
|                 ] | ||||
|             ).model_dump( | ||||
|                 mode="json", | ||||
|                 exclude_unset=True, | ||||
|                 exclude_none=True, | ||||
|             ), | ||||
|         ) | ||||
|         return self.patch_compare_users(group) | ||||
|  | ||||
|     def _update_put(self, group: Group, scim_group: SCIMGroupSchema, connection: SCIMProviderGroup): | ||||
|         """Update a group via PUT request""" | ||||
|         try: | ||||
|             self._request( | ||||
|                 "PUT", | ||||
| @ -155,25 +120,33 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): | ||||
|                     exclude_unset=True, | ||||
|                 ), | ||||
|             ) | ||||
|             return self.patch_compare_users(group) | ||||
|             users = list(group.users.order_by("id").values_list("id", flat=True)) | ||||
|             return self._patch_add_users(group, users) | ||||
|         except NotFoundSyncException: | ||||
|             # Resource missing is handled by self.write, which will re-create the group | ||||
|             raise | ||||
|         except (SCIMRequestException, ObjectExistsSyncException): | ||||
|             # Some providers don't support PUT on groups, so this is mainly a fix for the initial | ||||
|             # sync, send patch add requests for all the users the group currently has | ||||
|             return self._update_patch(group, scim_group, connection) | ||||
|             users = list(group.users.order_by("id").values_list("id", flat=True)) | ||||
|             self._patch_add_users(group, users) | ||||
|             # Also update the group name | ||||
|             return self._patch( | ||||
|                 scim_group.id, | ||||
|                 PatchOperation( | ||||
|                     op=PatchOp.replace, | ||||
|                     path="displayName", | ||||
|                     value=scim_group.displayName, | ||||
|                 ), | ||||
|             ) | ||||
|  | ||||
|     def update_group(self, group: Group, action: Direction, users_set: set[int]): | ||||
|         """Update a group, either using PUT to replace it or PATCH if supported""" | ||||
|         scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first() | ||||
|         if not scim_group: | ||||
|             self.logger.warning( | ||||
|                 "could not sync group membership, group does not exist", group=group | ||||
|             ) | ||||
|             return | ||||
|         if self._config.patch.supported: | ||||
|             if action == Direction.add: | ||||
|                 return self._patch_add_users(scim_group, users_set) | ||||
|                 return self._patch_add_users(group, users_set) | ||||
|             if action == Direction.remove: | ||||
|                 return self._patch_remove_users(scim_group, users_set) | ||||
|                 return self._patch_remove_users(group, users_set) | ||||
|         try: | ||||
|             return self.write(group) | ||||
|         except SCIMRequestException as exc: | ||||
| @ -181,24 +154,19 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): | ||||
|                 # Assume that provider does not support PUT and also doesn't support | ||||
|                 # ServiceProviderConfig, so try PATCH as a fallback | ||||
|                 if action == Direction.add: | ||||
|                     return self._patch_add_users(scim_group, users_set) | ||||
|                     return self._patch_add_users(group, users_set) | ||||
|                 if action == Direction.remove: | ||||
|                     return self._patch_remove_users(scim_group, users_set) | ||||
|                     return self._patch_remove_users(group, users_set) | ||||
|             raise exc | ||||
|  | ||||
|     def _patch_chunked( | ||||
|     def _patch( | ||||
|         self, | ||||
|         group_id: str, | ||||
|         *ops: PatchOperation, | ||||
|     ): | ||||
|         """Helper function that chunks patch requests based on the maxOperations attribute. | ||||
|         This is not strictly according to specs but there's nothing in the schema that allows the | ||||
|         us to know what the maximum patch operations per request should be.""" | ||||
|         chunk_size = self._config.bulk.maxOperations | ||||
|         if chunk_size < 1: | ||||
|             chunk_size = len(ops) | ||||
|         if len(ops) < 1: | ||||
|             return | ||||
|         for chunk in batched(ops, chunk_size): | ||||
|             req = PatchRequest(Operations=list(chunk)) | ||||
|             self._request( | ||||
| @ -209,70 +177,16 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): | ||||
|                 ), | ||||
|             ) | ||||
|  | ||||
|     @transaction.atomic | ||||
|     def patch_compare_users(self, group: Group): | ||||
|         """Compare users with a SCIM group and add/remove any differences""" | ||||
|         # Get scim group first | ||||
|     def _patch_add_users(self, group: Group, users_set: set[int]): | ||||
|         """Add users in users_set to group""" | ||||
|         if len(users_set) < 1: | ||||
|             return | ||||
|         scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first() | ||||
|         if not scim_group: | ||||
|             self.logger.warning( | ||||
|                 "could not sync group membership, group does not exist", group=group | ||||
|             ) | ||||
|             return | ||||
|         # Get a list of all users in the authentik group | ||||
|         raw_users_should = list(group.users.order_by("id").values_list("id", flat=True)) | ||||
|         # Lookup the SCIM IDs of the users | ||||
|         users_should: list[str] = list( | ||||
|             SCIMProviderUser.objects.filter( | ||||
|                 user__pk__in=raw_users_should, provider=self.provider | ||||
|             ).values_list("scim_id", flat=True) | ||||
|         ) | ||||
|         if len(raw_users_should) != len(users_should): | ||||
|             self.logger.warning( | ||||
|                 "User count mismatch, not all users in the group are synced to SCIM yet.", | ||||
|                 group=group, | ||||
|             ) | ||||
|         # Get current group status | ||||
|         current_group = SCIMGroupSchema.model_validate( | ||||
|             self._request("GET", f"/Groups/{scim_group.scim_id}") | ||||
|         ) | ||||
|         users_to_add = [] | ||||
|         users_to_remove = [] | ||||
|         # Check users currently in group and if they shouldn't be in the group and remove them | ||||
|         for user in current_group.members or []: | ||||
|             if user.value not in users_should: | ||||
|                 users_to_remove.append(user.value) | ||||
|         # Check users that should be in the group and add them | ||||
|         for user in users_should: | ||||
|             if len([x for x in current_group.members if x.value == user]) < 1: | ||||
|                 users_to_add.append(user) | ||||
|         # Only send request if we need to make changes | ||||
|         if len(users_to_add) < 1 and len(users_to_remove) < 1: | ||||
|             return | ||||
|         return self._patch_chunked( | ||||
|             scim_group.scim_id, | ||||
|             *[ | ||||
|                 PatchOperation( | ||||
|                     op=PatchOp.add, | ||||
|                     path="members", | ||||
|                     value=[{"value": x}], | ||||
|                 ) | ||||
|                 for x in users_to_add | ||||
|             ], | ||||
|             *[ | ||||
|                 PatchOperation( | ||||
|                     op=PatchOp.remove, | ||||
|                     path="members", | ||||
|                     value=[{"value": x}], | ||||
|                 ) | ||||
|                 for x in users_to_remove | ||||
|             ], | ||||
|         ) | ||||
|  | ||||
|     def _patch_add_users(self, scim_group: SCIMProviderGroup, users_set: set[int]): | ||||
|         """Add users in users_set to group""" | ||||
|         if len(users_set) < 1: | ||||
|             return | ||||
|         user_ids = list( | ||||
|             SCIMProviderUser.objects.filter( | ||||
|                 user__pk__in=users_set, provider=self.provider | ||||
| @ -280,7 +194,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): | ||||
|         ) | ||||
|         if len(user_ids) < 1: | ||||
|             return | ||||
|         self._patch_chunked( | ||||
|         self._patch( | ||||
|             scim_group.scim_id, | ||||
|             *[ | ||||
|                 PatchOperation( | ||||
| @ -292,10 +206,16 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): | ||||
|             ], | ||||
|         ) | ||||
|  | ||||
|     def _patch_remove_users(self, scim_group: SCIMProviderGroup, users_set: set[int]): | ||||
|     def _patch_remove_users(self, group: Group, users_set: set[int]): | ||||
|         """Remove users in users_set from group""" | ||||
|         if len(users_set) < 1: | ||||
|             return | ||||
|         scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first() | ||||
|         if not scim_group: | ||||
|             self.logger.warning( | ||||
|                 "could not sync group membership, group does not exist", group=group | ||||
|             ) | ||||
|             return | ||||
|         user_ids = list( | ||||
|             SCIMProviderUser.objects.filter( | ||||
|                 user__pk__in=users_set, provider=self.provider | ||||
| @ -303,7 +223,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]): | ||||
|         ) | ||||
|         if len(user_ids) < 1: | ||||
|             return | ||||
|         self._patch_chunked( | ||||
|         self._patch( | ||||
|             scim_group.scim_id, | ||||
|             *[ | ||||
|                 PatchOperation( | ||||
|  | ||||
| @ -2,7 +2,6 @@ | ||||
|  | ||||
| from pydantic import Field | ||||
| from pydanticscim.group import Group as BaseGroup | ||||
| from pydanticscim.responses import PatchOperation as BasePatchOperation | ||||
| from pydanticscim.responses import PatchRequest as BasePatchRequest | ||||
| from pydanticscim.responses import SCIMError as BaseSCIMError | ||||
| from pydanticscim.service_provider import Bulk as BaseBulk | ||||
| @ -69,12 +68,6 @@ class PatchRequest(BasePatchRequest): | ||||
|     schemas: tuple[str] = ("urn:ietf:params:scim:api:messages:2.0:PatchOp",) | ||||
|  | ||||
|  | ||||
| class PatchOperation(BasePatchOperation): | ||||
|     """PatchOperation with optional path""" | ||||
|  | ||||
|     path: str | None | ||||
|  | ||||
|  | ||||
| class SCIMError(BaseSCIMError): | ||||
|     """SCIM error with optional status code""" | ||||
|  | ||||
|  | ||||
| @ -1,18 +0,0 @@ | ||||
| # Generated by Django 5.0.9 on 2024-09-19 14:02 | ||||
|  | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_providers_scim", "0009_alter_scimmapping_options"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.AddField( | ||||
|             model_name="scimprovider", | ||||
|             name="verify_certificates", | ||||
|             field=models.BooleanField(default=True), | ||||
|         ), | ||||
|     ] | ||||
| @ -68,7 +68,6 @@ class SCIMProvider(OutgoingSyncProvider, BackchannelProvider): | ||||
|  | ||||
|     url = models.TextField(help_text=_("Base URL to SCIM requests, usually ends in /v2")) | ||||
|     token = models.TextField(help_text=_("Authentication token")) | ||||
|     verify_certificates = models.BooleanField(default=True) | ||||
|  | ||||
|     property_mappings_group = models.ManyToManyField( | ||||
|         PropertyMapping, | ||||
|  | ||||
| @ -252,118 +252,3 @@ class SCIMMembershipTests(TestCase): | ||||
|                     ], | ||||
|                 }, | ||||
|             ) | ||||
|  | ||||
|     def test_member_add_save(self): | ||||
|         """Test member add + save""" | ||||
|         config = ServiceProviderConfiguration.default() | ||||
|  | ||||
|         config.patch.supported = True | ||||
|         user_scim_id = generate_id() | ||||
|         group_scim_id = generate_id() | ||||
|         uid = generate_id() | ||||
|         group = Group.objects.create( | ||||
|             name=uid, | ||||
|         ) | ||||
|  | ||||
|         user = User.objects.create(username=generate_id()) | ||||
|  | ||||
|         # Test initial sync of group creation | ||||
|         with Mocker() as mocker: | ||||
|             mocker.get( | ||||
|                 "https://localhost/ServiceProviderConfig", | ||||
|                 json=config.model_dump(), | ||||
|             ) | ||||
|             mocker.post( | ||||
|                 "https://localhost/Users", | ||||
|                 json={ | ||||
|                     "id": user_scim_id, | ||||
|                 }, | ||||
|             ) | ||||
|             mocker.post( | ||||
|                 "https://localhost/Groups", | ||||
|                 json={ | ||||
|                     "id": group_scim_id, | ||||
|                 }, | ||||
|             ) | ||||
|  | ||||
|             self.configure() | ||||
|             sync_tasks.trigger_single_task(self.provider, scim_sync).get() | ||||
|  | ||||
|             self.assertEqual(mocker.call_count, 6) | ||||
|             self.assertEqual(mocker.request_history[0].method, "GET") | ||||
|             self.assertEqual(mocker.request_history[1].method, "GET") | ||||
|             self.assertEqual(mocker.request_history[2].method, "GET") | ||||
|             self.assertEqual(mocker.request_history[3].method, "POST") | ||||
|             self.assertEqual(mocker.request_history[4].method, "GET") | ||||
|             self.assertEqual(mocker.request_history[5].method, "POST") | ||||
|             self.assertJSONEqual( | ||||
|                 mocker.request_history[3].body, | ||||
|                 { | ||||
|                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], | ||||
|                     "emails": [], | ||||
|                     "active": True, | ||||
|                     "externalId": user.uid, | ||||
|                     "name": {"familyName": " ", "formatted": " ", "givenName": ""}, | ||||
|                     "displayName": "", | ||||
|                     "userName": user.username, | ||||
|                 }, | ||||
|             ) | ||||
|             self.assertJSONEqual( | ||||
|                 mocker.request_history[5].body, | ||||
|                 { | ||||
|                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], | ||||
|                     "externalId": str(group.pk), | ||||
|                     "displayName": group.name, | ||||
|                 }, | ||||
|             ) | ||||
|  | ||||
|         with Mocker() as mocker: | ||||
|             mocker.get( | ||||
|                 "https://localhost/ServiceProviderConfig", | ||||
|                 json=config.model_dump(), | ||||
|             ) | ||||
|             mocker.get( | ||||
|                 f"https://localhost/Groups/{group_scim_id}", | ||||
|                 json={}, | ||||
|             ) | ||||
|             mocker.patch( | ||||
|                 f"https://localhost/Groups/{group_scim_id}", | ||||
|                 json={}, | ||||
|             ) | ||||
|             group.users.add(user) | ||||
|             group.save() | ||||
|             self.assertEqual(mocker.call_count, 5) | ||||
|             self.assertEqual(mocker.request_history[0].method, "GET") | ||||
|             self.assertEqual(mocker.request_history[1].method, "PATCH") | ||||
|             self.assertEqual(mocker.request_history[2].method, "GET") | ||||
|             self.assertEqual(mocker.request_history[3].method, "PATCH") | ||||
|             self.assertEqual(mocker.request_history[4].method, "GET") | ||||
|             self.assertJSONEqual( | ||||
|                 mocker.request_history[1].body, | ||||
|                 { | ||||
|                     "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"], | ||||
|                     "Operations": [ | ||||
|                         { | ||||
|                             "op": "add", | ||||
|                             "path": "members", | ||||
|                             "value": [{"value": user_scim_id}], | ||||
|                         } | ||||
|                     ], | ||||
|                 }, | ||||
|             ) | ||||
|             self.assertJSONEqual( | ||||
|                 mocker.request_history[3].body, | ||||
|                 { | ||||
|                     "Operations": [ | ||||
|                         { | ||||
|                             "op": "replace", | ||||
|                             "value": { | ||||
|                                 "id": group_scim_id, | ||||
|                                 "displayName": group.name, | ||||
|                                 "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], | ||||
|                                 "externalId": str(group.pk), | ||||
|                             }, | ||||
|                         } | ||||
|                     ] | ||||
|                 }, | ||||
|             ) | ||||
|  | ||||
| @ -22,7 +22,7 @@ def create_admin_group(user: User) -> Group: | ||||
|     return group | ||||
|  | ||||
|  | ||||
| def create_recovery_token(user: User, expiry: datetime, generated_from: str) -> tuple[Token, str]: | ||||
| def create_recovery_token(user: User, expiry: datetime, generated_from: str) -> (Token, str): | ||||
|     """Create recovery token and associated link""" | ||||
|     _now = now() | ||||
|     token = Token.objects.create( | ||||
|  | ||||
| @ -41,9 +41,7 @@ class SessionMiddleware(UpstreamSessionMiddleware): | ||||
|             # Since go does not consider localhost with http a secure origin | ||||
|             # we can't set the secure flag. | ||||
|             user_agent = request.META.get("HTTP_USER_AGENT", "") | ||||
|             if user_agent.startswith("goauthentik.io/outpost/") or ( | ||||
|                 "safari" in user_agent.lower() and "chrome" not in user_agent.lower() | ||||
|             ): | ||||
|             if user_agent.startswith("goauthentik.io/outpost/") or "safari" in user_agent.lower(): | ||||
|                 return False | ||||
|             return True | ||||
|         return False | ||||
|  | ||||
| @ -38,7 +38,6 @@ LANGUAGE_COOKIE_NAME = "authentik_language" | ||||
| SESSION_COOKIE_NAME = "authentik_session" | ||||
| SESSION_COOKIE_DOMAIN = CONFIG.get("cookie_domain", None) | ||||
| APPEND_SLASH = False | ||||
| X_FRAME_OPTIONS = "SAMEORIGIN" | ||||
|  | ||||
| AUTHENTICATION_BACKENDS = [ | ||||
|     "django.contrib.auth.backends.ModelBackend", | ||||
| @ -91,7 +90,6 @@ TENANT_APPS = [ | ||||
|     "authentik.providers.scim", | ||||
|     "authentik.rbac", | ||||
|     "authentik.recovery", | ||||
|     "authentik.sources.kerberos", | ||||
|     "authentik.sources.ldap", | ||||
|     "authentik.sources.oauth", | ||||
|     "authentik.sources.plex", | ||||
|  | ||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	