Compare commits

3 Commits

Author SHA1 Message Date
929e42d3f2 fix other issues
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2024-10-15 14:51:20 +02:00
863958b4d6 make sure we don't break something else
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2024-10-15 14:22:25 +02:00
2249b9307e core: expiring model: don't synchronously delete
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2024-10-15 14:15:24 +02:00
291 changed files with 1993 additions and 24876 deletions

View File

@ -1,5 +1,5 @@
[bumpversion]
current_version = 2024.10.0
current_version = 2024.8.3
tag = True
commit = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))?

View File

@ -14,7 +14,7 @@ runs:
run: |
pipx install poetry || true
sudo apt-get update
sudo apt-get install --no-install-recommends -y libpq-dev openssl libxmlsec1-dev pkg-config gettext libkrb5-dev krb5-kdc krb5-user krb5-admin-server
sudo apt-get install --no-install-recommends -y libpq-dev openssl libxmlsec1-dev pkg-config gettext
- name: Setup python and restore poetry
uses: actions/setup-python@v5
with:

View File

@ -180,7 +180,7 @@ jobs:
uses: ./.github/actions/setup
- name: Setup e2e env (chrome, etc)
run: |
docker compose -f tests/e2e/docker-compose.yml up -d --quiet-pull
docker compose -f tests/e2e/docker-compose.yml up -d
- id: cache-web
uses: actions/cache@v4
with:

View File

@ -6,7 +6,6 @@
"authn",
"entra",
"goauthentik",
"jwe",
"jwks",
"kubernetes",
"oidc",

View File

@ -110,7 +110,7 @@ RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloa
RUN --mount=type=cache,id=apt-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/var/cache/apt \
apt-get update && \
# Required for installing pip packages
apt-get install -y --no-install-recommends build-essential pkg-config libpq-dev libkrb5-dev
apt-get install -y --no-install-recommends build-essential pkg-config libpq-dev
RUN --mount=type=bind,target=./pyproject.toml,src=./pyproject.toml \
--mount=type=bind,target=./poetry.lock,src=./poetry.lock \
@ -141,7 +141,7 @@ WORKDIR /
# We cannot cache this layer otherwise we'll end up with a bigger image
RUN apt-get update && \
# Required for runtime
apt-get install -y --no-install-recommends libpq5 libmaxminddb0 ca-certificates libkrb5-3 libkadm5clnt-mit12 libkdb5-10 && \
apt-get install -y --no-install-recommends libpq5 libmaxminddb0 ca-certificates && \
# Required for bootstrap & healtcheck
apt-get install -y --no-install-recommends runit && \
apt-get clean && \
@ -161,7 +161,6 @@ COPY ./tests /tests
COPY ./manage.py /
COPY ./blueprints /blueprints
COPY ./lifecycle/ /lifecycle
COPY ./authentik/sources/kerberos/krb5.conf /etc/krb5.conf
COPY --from=go-builder /go/authentik /bin/authentik
COPY --from=python-deps /ak-root/venv /ak-root/venv
COPY --from=web-builder /work/web/dist/ /web/dist/

View File

@ -18,10 +18,10 @@ Even if the issue is not a CVE, we still greatly appreciate your help in hardeni
(.x being the latest patch release for each version)
| Version | Supported |
| --------- | --------- |
| 2024.8.x | ✅ |
| 2024.10.x | ✅ |
| Version | Supported |
| -------- | --------- |
| 2024.6.x | ✅ |
| 2024.8.x | ✅ |
## Reporting a Vulnerability

View File

@ -2,7 +2,7 @@
from os import environ
__version__ = "2024.10.0"
__version__ = "2024.8.3"
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"

View File

@ -1,33 +0,0 @@
from rest_framework.permissions import IsAdminUser
from rest_framework.viewsets import ReadOnlyModelViewSet
from authentik.admin.models import VersionHistory
from authentik.core.api.utils import ModelSerializer
class VersionHistorySerializer(ModelSerializer):
"""VersionHistory Serializer"""
class Meta:
model = VersionHistory
fields = [
"id",
"timestamp",
"version",
"build",
]
class VersionHistoryViewSet(ReadOnlyModelViewSet):
"""VersionHistory Viewset"""
queryset = VersionHistory.objects.all()
serializer_class = VersionHistorySerializer
permission_classes = [IsAdminUser]
filterset_fields = [
"version",
"build",
]
search_fields = ["version", "build"]
ordering = ["-timestamp"]
pagination_class = None

View File

@ -1,22 +0,0 @@
"""authentik admin models"""
from django.db import models
from django.utils.translation import gettext_lazy as _
class VersionHistory(models.Model):
id = models.BigAutoField(primary_key=True)
timestamp = models.DateTimeField()
version = models.TextField()
build = models.TextField()
class Meta:
managed = False
db_table = "authentik_version_history"
ordering = ("-timestamp",)
verbose_name = _("Version history")
verbose_name_plural = _("Version history")
default_permissions = []
def __str__(self):
return f"{self.version}.{self.build} ({self.timestamp})"

View File

@ -6,7 +6,6 @@ from authentik.admin.api.meta import AppsViewSet, ModelViewSet
from authentik.admin.api.metrics import AdministrationMetricsViewSet
from authentik.admin.api.system import SystemView
from authentik.admin.api.version import VersionView
from authentik.admin.api.version_history import VersionHistoryViewSet
from authentik.admin.api.workers import WorkerView
api_urlpatterns = [
@ -18,7 +17,6 @@ api_urlpatterns = [
name="admin_metrics",
),
path("admin/version/", VersionView.as_view(), name="admin_version"),
("admin/version/history", VersionHistoryViewSet, "version_history"),
path("admin/workers/", WorkerView.as_view(), name="admin_workers"),
path("admin/system/", SystemView.as_view(), name="admin_system"),
]

View File

@ -51,10 +51,6 @@ from authentik.enterprise.providers.microsoft_entra.models import (
MicrosoftEntraProviderUser,
)
from authentik.enterprise.providers.rac.models import ConnectionToken
from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import (
EndpointDevice,
EndpointDeviceConnection,
)
from authentik.events.logs import LogEvent, capture_logs
from authentik.events.models import SystemTask
from authentik.events.utils import cleanse_dict
@ -123,8 +119,6 @@ def excluded_models() -> list[type[Model]]:
GoogleWorkspaceProviderGroup,
MicrosoftEntraProviderUser,
MicrosoftEntraProviderGroup,
EndpointDevice,
EndpointDeviceConnection,
)

View File

@ -6,45 +6,34 @@ from rest_framework.fields import (
BooleanField,
CharField,
DateTimeField,
IntegerField,
SerializerMethodField,
)
from rest_framework.permissions import IsAuthenticated
from rest_framework.permissions import IsAdminUser, IsAuthenticated
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import ViewSet
from authentik.core.api.utils import MetaNameSerializer
from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import EndpointDevice
from authentik.rbac.decorators import permission_required
from authentik.stages.authenticator import device_classes, devices_for_user
from authentik.stages.authenticator.models import Device
from authentik.stages.authenticator_webauthn.models import WebAuthnDevice
class DeviceSerializer(MetaNameSerializer):
"""Serializer for Duo authenticator devices"""
pk = CharField()
pk = IntegerField()
name = CharField()
type = SerializerMethodField()
confirmed = BooleanField()
created = DateTimeField(read_only=True)
last_updated = DateTimeField(read_only=True)
last_used = DateTimeField(read_only=True, allow_null=True)
extra_description = SerializerMethodField()
def get_type(self, instance: Device) -> str:
"""Get type of device"""
return instance._meta.label
def get_extra_description(self, instance: Device) -> str:
"""Get extra description"""
if isinstance(instance, WebAuthnDevice):
return instance.device_type.description
if isinstance(instance, EndpointDevice):
return instance.data.get("deviceSignals", {}).get("deviceModel")
return ""
class DeviceViewSet(ViewSet):
"""Viewset for authenticator devices"""
@ -63,7 +52,7 @@ class AdminDeviceViewSet(ViewSet):
"""Viewset for authenticator devices"""
serializer_class = DeviceSerializer
permission_classes = []
permission_classes = [IsAdminUser]
def get_devices(self, **kwargs):
"""Get all devices in all child classes"""
@ -81,10 +70,6 @@ class AdminDeviceViewSet(ViewSet):
],
responses={200: DeviceSerializer(many=True)},
)
@permission_required(
None,
[f"{model._meta.app_label}.view_{model._meta.model_name}" for model in device_classes()],
)
def list(self, request: Request) -> Response:
"""Get all devices for current user"""
kwargs = {}

View File

@ -4,7 +4,6 @@ import code
import platform
import sys
import traceback
from pprint import pprint
from django.apps import apps
from django.core.management.base import BaseCommand
@ -35,9 +34,7 @@ class Command(BaseCommand):
def get_namespace(self):
"""Prepare namespace with all models"""
namespace = {
"pprint": pprint,
}
namespace = {}
# Gather Django models and constants from each app
for app in apps.get_app_configs():

View File

@ -330,13 +330,11 @@ class User(SerializerModel, GuardianUserMixin, AttributesMixin, AbstractUser):
"""superuser == staff user"""
return self.is_superuser # type: ignore
def set_password(self, raw_password, signal=True, sender=None):
def set_password(self, raw_password, signal=True):
if self.pk and signal:
from authentik.core.signals import password_changed
if not sender:
sender = self
password_changed.send(sender=sender, user=self, password=raw_password)
password_changed.send(sender=self, user=self, password=raw_password)
self.password_change_date = now()
return super().set_password(raw_password)
@ -804,12 +802,25 @@ class ExpiringModel(models.Model):
return self.delete(*args, **kwargs)
@classmethod
def filter_not_expired(cls, **kwargs) -> QuerySet["Token"]:
def _not_expired_filter(cls):
return Q(expires__gt=now(), expiring=True) | Q(expiring=False)
@classmethod
def filter_not_expired(cls, delete_expired=False, **kwargs) -> QuerySet["ExpiringModel"]:
"""Filer for tokens which are not expired yet or are not expiring,
and match filters in `kwargs`"""
for obj in cls.objects.filter(**kwargs).filter(Q(expires__lt=now(), expiring=True)):
obj.delete()
return cls.objects.filter(**kwargs)
if delete_expired:
cls.delete_expired(**kwargs)
return cls.objects.filter(cls._not_expired_filter()).filter(**kwargs)
@classmethod
def delete_expired(cls, **kwargs) -> int:
objects = cls.objects.all().exclude(cls._not_expired_filter()).filter(**kwargs)
amount = 0
for obj in objects:
obj.expire_action()
amount += 1
return amount
@property
def is_expired(self) -> bool:
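
Note on the hunk above: this is the headline commit of the compare ("core: expiring model: don't synchronously delete"). On the compared branch, filter_not_expired() no longer deletes expired rows inline; it filters with _not_expired_filter(), and the new delete_expired() classmethod runs expire_action() on each expired object and returns the count. A minimal usage sketch based only on the signatures shown above, using Token (an ExpiringModel subclass); the identifier value is hypothetical:

    from authentik.core.models import Token

    # Explicitly purge expired tokens matching a filter; returns how many were expired.
    amount = Token.delete_expired(identifier="ak-outpost-example-api")  # identifier is hypothetical

    # Or purge before querying, as the Outpost model does further down in this compare.
    tokens = Token.filter_not_expired(delete_expired=True, identifier="ak-outpost-example-api")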

View File

@ -1,9 +1,11 @@
"""Source decision helper"""
from enum import Enum
from typing import Any
from django.contrib import messages
from django.db import IntegrityError, transaction
from django.db.models.query_utils import Q
from django.http import HttpRequest, HttpResponse
from django.shortcuts import redirect
from django.urls import reverse
@ -14,11 +16,12 @@ from authentik.core.models import (
Group,
GroupSourceConnection,
Source,
SourceGroupMatchingModes,
SourceUserMatchingModes,
User,
UserSourceConnection,
)
from authentik.core.sources.mapper import SourceMapper
from authentik.core.sources.matcher import Action, SourceMatcher
from authentik.core.sources.stage import (
PLAN_CONTEXT_SOURCES_CONNECTION,
PostSourceStage,
@ -51,6 +54,16 @@ SESSION_KEY_OVERRIDE_FLOW_TOKEN = "authentik/flows/source_override_flow_token"
PLAN_CONTEXT_SOURCE_GROUPS = "source_groups"
class Action(Enum):
"""Actions that can be decided based on the request
and source settings"""
LINK = "link"
AUTH = "auth"
ENROLL = "enroll"
DENY = "deny"
class MessageStage(StageView):
"""Show a pre-configured message after the flow is done"""
@ -73,7 +86,6 @@ class SourceFlowManager:
source: Source
mapper: SourceMapper
matcher: SourceMatcher
request: HttpRequest
identifier: str
@ -96,9 +108,6 @@ class SourceFlowManager:
) -> None:
self.source = source
self.mapper = SourceMapper(self.source)
self.matcher = SourceMatcher(
self.source, self.user_connection_type, self.group_connection_type
)
self.request = request
self.identifier = identifier
self.user_info = user_info
@ -122,19 +131,66 @@ class SourceFlowManager:
def get_action(self, **kwargs) -> tuple[Action, UserSourceConnection | None]: # noqa: PLR0911
"""decide which action should be taken"""
new_connection = self.user_connection_type(source=self.source, identifier=self.identifier)
# When request is authenticated, always link
if self.request.user.is_authenticated:
new_connection = self.user_connection_type(
source=self.source, identifier=self.identifier
)
new_connection.user = self.request.user
new_connection = self.update_user_connection(new_connection, **kwargs)
return Action.LINK, new_connection
action, connection = self.matcher.get_user_action(self.identifier, self.user_properties)
if connection:
connection = self.update_user_connection(connection, **kwargs)
return action, connection
existing_connections = self.user_connection_type.objects.filter(
source=self.source, identifier=self.identifier
)
if existing_connections.exists():
connection = existing_connections.first()
return Action.AUTH, self.update_user_connection(connection, **kwargs)
# No connection exists, but we match on identifier, so enroll
if self.source.user_matching_mode == SourceUserMatchingModes.IDENTIFIER:
# We don't save the connection here cause it doesn't have a user assigned yet
return Action.ENROLL, self.update_user_connection(new_connection, **kwargs)
# Check for existing users with matching attributes
query = Q()
# Either query existing user based on email or username
if self.source.user_matching_mode in [
SourceUserMatchingModes.EMAIL_LINK,
SourceUserMatchingModes.EMAIL_DENY,
]:
if not self.user_properties.get("email", None):
self._logger.warning("Refusing to use none email")
return Action.DENY, None
query = Q(email__exact=self.user_properties.get("email", None))
if self.source.user_matching_mode in [
SourceUserMatchingModes.USERNAME_LINK,
SourceUserMatchingModes.USERNAME_DENY,
]:
if not self.user_properties.get("username", None):
self._logger.warning("Refusing to use none username")
return Action.DENY, None
query = Q(username__exact=self.user_properties.get("username", None))
self._logger.debug("trying to link with existing user", query=query)
matching_users = User.objects.filter(query)
# No matching users, always enroll
if not matching_users.exists():
self._logger.debug("no matching users found, enrolling")
return Action.ENROLL, self.update_user_connection(new_connection, **kwargs)
user = matching_users.first()
if self.source.user_matching_mode in [
SourceUserMatchingModes.EMAIL_LINK,
SourceUserMatchingModes.USERNAME_LINK,
]:
new_connection.user = user
new_connection = self.update_user_connection(new_connection, **kwargs)
return Action.LINK, new_connection
if self.source.user_matching_mode in [
SourceUserMatchingModes.EMAIL_DENY,
SourceUserMatchingModes.USERNAME_DENY,
]:
self._logger.info("denying source because user exists", user=user)
return Action.DENY, None
# Should never get here as default enroll case is returned above.
return Action.DENY, None # pragma: no cover
def update_user_connection(
self, connection: UserSourceConnection, **kwargs
@ -272,6 +328,7 @@ class SourceFlowManager:
connection: UserSourceConnection,
) -> HttpResponse:
"""Login user and redirect."""
flow_kwargs = {PLAN_CONTEXT_PENDING_USER: connection.user}
return self._prepare_flow(
self.source.authentication_flow,
connection,
@ -285,11 +342,7 @@ class SourceFlowManager:
),
)
],
**{
PLAN_CONTEXT_PENDING_USER: connection.user,
PLAN_CONTEXT_PROMPT: delete_none_values(self.user_properties),
PLAN_CONTEXT_USER_PATH: self.source.get_user_path(),
},
**flow_kwargs,
)
def handle_existing_link(
@ -355,16 +408,74 @@ class SourceFlowManager:
class GroupUpdateStage(StageView):
"""Dynamically injected stage which updates the user after enrollment/authentication."""
def get_action(
self, group_id: str, group_properties: dict[str, Any | dict[str, Any]]
) -> tuple[Action, GroupSourceConnection | None]:
"""decide which action should be taken"""
new_connection = self.group_connection_type(source=self.source, identifier=group_id)
existing_connections = self.group_connection_type.objects.filter(
source=self.source, identifier=group_id
)
if existing_connections.exists():
return Action.LINK, existing_connections.first()
# No connection exists, but we match on identifier, so enroll
if self.source.group_matching_mode == SourceGroupMatchingModes.IDENTIFIER:
# We don't save the connection here cause it doesn't have a user assigned yet
return Action.ENROLL, new_connection
# Check for existing groups with matching attributes
query = Q()
if self.source.group_matching_mode in [
SourceGroupMatchingModes.NAME_LINK,
SourceGroupMatchingModes.NAME_DENY,
]:
if not group_properties.get("name", None):
LOGGER.warning(
"Refusing to use none group name", source=self.source, group_id=group_id
)
return Action.DENY, None
query = Q(name__exact=group_properties.get("name"))
LOGGER.debug(
"trying to link with existing group", source=self.source, query=query, group_id=group_id
)
matching_groups = Group.objects.filter(query)
# No matching groups, always enroll
if not matching_groups.exists():
LOGGER.debug(
"no matching groups found, enrolling", source=self.source, group_id=group_id
)
return Action.ENROLL, new_connection
group = matching_groups.first()
if self.source.group_matching_mode in [
SourceGroupMatchingModes.NAME_LINK,
]:
new_connection.group = group
return Action.LINK, new_connection
if self.source.group_matching_mode in [
SourceGroupMatchingModes.NAME_DENY,
]:
LOGGER.info(
"denying source because group exists",
source=self.source,
group=group,
group_id=group_id,
)
return Action.DENY, None
# Should never get here as default enroll case is returned above.
return Action.DENY, None # pragma: no cover
def handle_group(
self, group_id: str, group_properties: dict[str, Any | dict[str, Any]]
) -> Group | None:
action, connection = self.matcher.get_group_action(group_id, group_properties)
action, connection = self.get_action(group_id, group_properties)
if action == Action.ENROLL:
group = Group.objects.create(**group_properties)
connection.group = group
connection.save()
return group
elif action in (Action.LINK, Action.AUTH):
elif action == Action.LINK:
group = connection.group
group.update_attributes(group_properties)
connection.save()
@ -378,7 +489,6 @@ class GroupUpdateStage(StageView):
self.group_connection_type: GroupSourceConnection = (
self.executor.current_stage.group_connection_type
)
self.matcher = SourceMatcher(self.source, None, self.group_connection_type)
raw_groups: dict[str, dict[str, Any | dict[str, Any]]] = self.executor.plan.context[
PLAN_CONTEXT_SOURCE_GROUPS

View File

@ -1,152 +0,0 @@
"""Source user and group matching"""
from dataclasses import dataclass
from enum import Enum
from typing import Any
from django.db.models import Q
from structlog import get_logger
from authentik.core.models import (
Group,
GroupSourceConnection,
Source,
SourceGroupMatchingModes,
SourceUserMatchingModes,
User,
UserSourceConnection,
)
class Action(Enum):
"""Actions that can be decided based on the request and source settings"""
LINK = "link"
AUTH = "auth"
ENROLL = "enroll"
DENY = "deny"
@dataclass
class MatchableProperty:
property: str
link_mode: SourceUserMatchingModes | SourceGroupMatchingModes
deny_mode: SourceUserMatchingModes | SourceGroupMatchingModes
class SourceMatcher:
def __init__(
self,
source: Source,
user_connection_type: type[UserSourceConnection],
group_connection_type: type[GroupSourceConnection],
):
self.source = source
self.user_connection_type = user_connection_type
self.group_connection_type = group_connection_type
self._logger = get_logger().bind(source=self.source)
def get_action(
self,
object_type: type[User | Group],
matchable_properties: list[MatchableProperty],
identifier: str,
properties: dict[str, Any | dict[str, Any]],
) -> tuple[Action, UserSourceConnection | GroupSourceConnection | None]:
connection_type = None
matching_mode = None
identifier_matching_mode = None
if object_type == User:
connection_type = self.user_connection_type
matching_mode = self.source.user_matching_mode
identifier_matching_mode = SourceUserMatchingModes.IDENTIFIER
if object_type == Group:
connection_type = self.group_connection_type
matching_mode = self.source.group_matching_mode
identifier_matching_mode = SourceGroupMatchingModes.IDENTIFIER
if not connection_type or not matching_mode or not identifier_matching_mode:
return Action.DENY, None
new_connection = connection_type(source=self.source, identifier=identifier)
existing_connections = connection_type.objects.filter(
source=self.source, identifier=identifier
)
if existing_connections.exists():
return Action.AUTH, existing_connections.first()
# No connection exists, but we match on identifier, so enroll
if matching_mode == identifier_matching_mode:
# We don't save the connection here cause it doesn't have a user/group assigned yet
return Action.ENROLL, new_connection
# Check for existing users with matching attributes
query = Q()
for matchable_property in matchable_properties:
property = matchable_property.property
if matching_mode in [matchable_property.link_mode, matchable_property.deny_mode]:
if not properties.get(property, None):
self._logger.warning(
"Refusing to use none property", identifier=identifier, property=property
)
return Action.DENY, None
query_args = {
f"{property}__exact": properties[property],
}
query = Q(**query_args)
self._logger.debug(
"Trying to link with existing object", query=query, identifier=identifier
)
matching_objects = object_type.objects.filter(query)
# Not matching objects, always enroll
if not matching_objects.exists():
self._logger.debug("No matching objects found, enrolling")
return Action.ENROLL, new_connection
obj = matching_objects.first()
if matching_mode in [mp.link_mode for mp in matchable_properties]:
attr = None
if object_type == User:
attr = "user"
if object_type == Group:
attr = "group"
setattr(new_connection, attr, obj)
return Action.LINK, new_connection
if matching_mode in [mp.deny_mode for mp in matchable_properties]:
self._logger.info("Denying source because object exists", obj=obj)
return Action.DENY, None
# Should never get here as default enroll case is returned above.
return Action.DENY, None # pragma: no cover
def get_user_action(
self, identifier: str, properties: dict[str, Any | dict[str, Any]]
) -> tuple[Action, UserSourceConnection | None]:
return self.get_action(
User,
[
MatchableProperty(
"username",
SourceUserMatchingModes.USERNAME_LINK,
SourceUserMatchingModes.USERNAME_DENY,
),
MatchableProperty(
"email", SourceUserMatchingModes.EMAIL_LINK, SourceUserMatchingModes.EMAIL_DENY
),
],
identifier,
properties,
)
def get_group_action(
self, identifier: str, properties: dict[str, Any | dict[str, Any]]
) -> tuple[Action, GroupSourceConnection | None]:
return self.get_action(
Group,
[
MatchableProperty(
"name", SourceGroupMatchingModes.NAME_LINK, SourceGroupMatchingModes.NAME_DENY
),
],
identifier,
properties,
)

View File

@ -30,12 +30,7 @@ def clean_expired_models(self: SystemTask):
messages = []
for cls in ExpiringModel.__subclasses__():
cls: ExpiringModel
objects = (
cls.objects.all().exclude(expiring=False).exclude(expiring=True, expires__gt=now())
)
amount = objects.count()
for obj in objects:
obj.expire_action()
amount = cls.delete_expired()
LOGGER.debug("Expired models", model=cls, amount=amount)
messages.append(f"Expired {amount} {cls._meta.verbose_name_plural}")
# Special case

View File

@ -1,59 +0,0 @@
"""Test Devices API"""
from json import loads
from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.core.tests.utils import create_test_admin_user, create_test_user
class TestDevicesAPI(APITestCase):
"""Test applications API"""
def setUp(self) -> None:
self.admin = create_test_admin_user()
self.user1 = create_test_user()
self.device1 = self.user1.staticdevice_set.create()
self.user2 = create_test_user()
self.device2 = self.user2.staticdevice_set.create()
def test_user_api(self):
"""Test user API"""
self.client.force_login(self.user1)
response = self.client.get(
reverse(
"authentik_api:device-list",
)
)
self.assertEqual(response.status_code, 200)
body = loads(response.content.decode())
self.assertEqual(len(body), 1)
self.assertEqual(body[0]["pk"], str(self.device1.pk))
def test_user_api_as_admin(self):
"""Test user API"""
self.client.force_login(self.admin)
response = self.client.get(
reverse(
"authentik_api:device-list",
)
)
self.assertEqual(response.status_code, 200)
body = loads(response.content.decode())
self.assertEqual(len(body), 0)
def test_admin_api(self):
"""Test admin API"""
self.client.force_login(self.admin)
response = self.client.get(
reverse(
"authentik_api:admin-device-list",
)
)
self.assertEqual(response.status_code, 200)
body = loads(response.content.decode())
self.assertEqual(len(body), 2)
self.assertEqual(
{body[0]["pk"], body[1]["pk"]}, {str(self.device1.pk), str(self.device2.pk)}
)

View File

@ -5,6 +5,7 @@ from channels.sessions import CookieMiddleware
from django.conf import settings
from django.contrib.auth.decorators import login_required
from django.urls import path
from django.views.decorators.csrf import ensure_csrf_cookie
from authentik.core.api.applications import ApplicationViewSet
from authentik.core.api.authenticated_sessions import AuthenticatedSessionViewSet
@ -43,19 +44,19 @@ urlpatterns = [
# Interfaces
path(
"if/admin/",
BrandDefaultRedirectView.as_view(template_name="if/admin.html"),
ensure_csrf_cookie(BrandDefaultRedirectView.as_view(template_name="if/admin.html")),
name="if-admin",
),
path(
"if/user/",
BrandDefaultRedirectView.as_view(template_name="if/user.html"),
ensure_csrf_cookie(BrandDefaultRedirectView.as_view(template_name="if/user.html")),
name="if-user",
),
path(
"if/flow/<slug:flow_slug>/",
# FIXME: move this url to the flows app...also will cause all
# of the reverse calls to be adjusted
FlowInterfaceView.as_view(),
ensure_csrf_cookie(FlowInterfaceView.as_view()),
name="if-flow",
),
# Fallback for WS

View File

@ -3,6 +3,7 @@
from channels.auth import AuthMiddleware
from channels.sessions import CookieMiddleware
from django.urls import path
from django.views.decorators.csrf import ensure_csrf_cookie
from authentik.enterprise.providers.rac.api.connection_tokens import ConnectionTokenViewSet
from authentik.enterprise.providers.rac.api.endpoints import EndpointViewSet
@ -18,12 +19,12 @@ from authentik.root.middleware import ChannelsLoggingMiddleware
urlpatterns = [
path(
"application/rac/<slug:app>/<uuid:endpoint>/",
RACStartView.as_view(),
ensure_csrf_cookie(RACStartView.as_view()),
name="start",
),
path(
"if/rac/<str:token>/",
RACInterface.as_view(),
ensure_csrf_cookie(RACInterface.as_view()),
name="if-rac",
),
]

View File

@ -17,7 +17,6 @@ TENANT_APPS = [
"authentik.enterprise.providers.google_workspace",
"authentik.enterprise.providers.microsoft_entra",
"authentik.enterprise.providers.rac",
"authentik.enterprise.stages.authenticator_endpoint_gdtc",
"authentik.enterprise.stages.source",
]

View File

@ -1,82 +0,0 @@
"""AuthenticatorEndpointGDTCStage API Views"""
from django_filters.rest_framework.backends import DjangoFilterBackend
from rest_framework import mixins
from rest_framework.filters import OrderingFilter, SearchFilter
from rest_framework.permissions import IsAdminUser
from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import GenericViewSet, ModelViewSet
from structlog.stdlib import get_logger
from authentik.api.authorization import OwnerFilter, OwnerPermissions
from authentik.core.api.used_by import UsedByMixin
from authentik.enterprise.api import EnterpriseRequiredMixin
from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import (
AuthenticatorEndpointGDTCStage,
EndpointDevice,
)
from authentik.flows.api.stages import StageSerializer
LOGGER = get_logger()
class AuthenticatorEndpointGDTCStageSerializer(EnterpriseRequiredMixin, StageSerializer):
"""AuthenticatorEndpointGDTCStage Serializer"""
class Meta:
model = AuthenticatorEndpointGDTCStage
fields = StageSerializer.Meta.fields + [
"configure_flow",
"friendly_name",
"credentials",
]
class AuthenticatorEndpointGDTCStageViewSet(UsedByMixin, ModelViewSet):
"""AuthenticatorEndpointGDTCStage Viewset"""
queryset = AuthenticatorEndpointGDTCStage.objects.all()
serializer_class = AuthenticatorEndpointGDTCStageSerializer
filterset_fields = [
"name",
"configure_flow",
]
search_fields = ["name"]
ordering = ["name"]
class EndpointDeviceSerializer(ModelSerializer):
"""Serializer for Endpoint authenticator devices"""
class Meta:
model = EndpointDevice
fields = ["pk", "name"]
depth = 2
class EndpointDeviceViewSet(
mixins.RetrieveModelMixin,
mixins.ListModelMixin,
UsedByMixin,
GenericViewSet,
):
"""Viewset for Endpoint authenticator devices"""
queryset = EndpointDevice.objects.all()
serializer_class = EndpointDeviceSerializer
search_fields = ["name"]
filterset_fields = ["name"]
ordering = ["name"]
permission_classes = [OwnerPermissions]
filter_backends = [OwnerFilter, DjangoFilterBackend, OrderingFilter, SearchFilter]
class EndpointAdminDeviceViewSet(ModelViewSet):
"""Viewset for Endpoint authenticator devices (for admins)"""
permission_classes = [IsAdminUser]
queryset = EndpointDevice.objects.all()
serializer_class = EndpointDeviceSerializer
search_fields = ["name"]
filterset_fields = ["name"]
ordering = ["name"]

View File

@ -1,13 +0,0 @@
"""authentik Endpoint app config"""
from authentik.enterprise.apps import EnterpriseConfig
class AuthentikStageAuthenticatorEndpointConfig(EnterpriseConfig):
"""authentik endpoint config"""
name = "authentik.enterprise.stages.authenticator_endpoint_gdtc"
label = "authentik_stages_authenticator_endpoint_gdtc"
verbose_name = "authentik Enterprise.Stages.Authenticator.Endpoint GDTC"
default = True
mountpoint = "endpoint/gdtc/"

View File

@ -1,115 +0,0 @@
# Generated by Django 5.0.9 on 2024-10-22 11:40
import django.db.models.deletion
import uuid
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
("authentik_flows", "0027_auto_20231028_1424"),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.CreateModel(
name="AuthenticatorEndpointGDTCStage",
fields=[
(
"stage_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_flows.stage",
),
),
("friendly_name", models.TextField(null=True)),
("credentials", models.JSONField()),
(
"configure_flow",
models.ForeignKey(
blank=True,
help_text="Flow used by an authenticated user to configure this Stage. If empty, user will not be able to configure this stage.",
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to="authentik_flows.flow",
),
),
],
options={
"verbose_name": "Endpoint Authenticator Google Device Trust Connector Stage",
"verbose_name_plural": "Endpoint Authenticator Google Device Trust Connector Stages",
},
bases=("authentik_flows.stage", models.Model),
),
migrations.CreateModel(
name="EndpointDevice",
fields=[
("created", models.DateTimeField(auto_now_add=True)),
("last_updated", models.DateTimeField(auto_now=True)),
(
"name",
models.CharField(
help_text="The human-readable name of this device.", max_length=64
),
),
(
"confirmed",
models.BooleanField(default=True, help_text="Is this device ready for use?"),
),
("last_used", models.DateTimeField(null=True)),
("uuid", models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)),
(
"host_identifier",
models.TextField(
help_text="A unique identifier for the endpoint device, usually the device serial number",
unique=True,
),
),
("data", models.JSONField()),
(
"user",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL
),
),
],
options={
"verbose_name": "Endpoint Device",
"verbose_name_plural": "Endpoint Devices",
},
),
migrations.CreateModel(
name="EndpointDeviceConnection",
fields=[
(
"id",
models.AutoField(
auto_created=True, primary_key=True, serialize=False, verbose_name="ID"
),
),
("attributes", models.JSONField()),
(
"device",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
to="authentik_stages_authenticator_endpoint_gdtc.endpointdevice",
),
),
(
"stage",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
to="authentik_stages_authenticator_endpoint_gdtc.authenticatorendpointgdtcstage",
),
),
],
),
]

View File

@ -1,101 +0,0 @@
"""Endpoint stage"""
from uuid import uuid4
from django.contrib.auth import get_user_model
from django.db import models
from django.utils.translation import gettext_lazy as _
from google.oauth2.service_account import Credentials
from rest_framework.serializers import BaseSerializer, Serializer
from authentik.core.types import UserSettingSerializer
from authentik.flows.models import ConfigurableStage, FriendlyNamedStage, Stage
from authentik.flows.stage import StageView
from authentik.lib.models import SerializerModel
from authentik.stages.authenticator.models import Device
class AuthenticatorEndpointGDTCStage(ConfigurableStage, FriendlyNamedStage, Stage):
"""Setup Google Chrome Device-trust connection"""
credentials = models.JSONField()
def google_credentials(self):
return {
"credentials": Credentials.from_service_account_info(
self.credentials, scopes=["https://www.googleapis.com/auth/verifiedaccess"]
),
}
@property
def serializer(self) -> type[BaseSerializer]:
from authentik.enterprise.stages.authenticator_endpoint_gdtc.api import (
AuthenticatorEndpointGDTCStageSerializer,
)
return AuthenticatorEndpointGDTCStageSerializer
@property
def view(self) -> type[StageView]:
from authentik.enterprise.stages.authenticator_endpoint_gdtc.stage import (
AuthenticatorEndpointStageView,
)
return AuthenticatorEndpointStageView
@property
def component(self) -> str:
return "ak-stage-authenticator-endpoint-gdtc-form"
def ui_user_settings(self) -> UserSettingSerializer | None:
return UserSettingSerializer(
data={
"title": self.friendly_name or str(self._meta.verbose_name),
"component": "ak-user-settings-authenticator-endpoint",
}
)
def __str__(self) -> str:
return f"Endpoint Authenticator Google Device Trust Connector Stage {self.name}"
class Meta:
verbose_name = _("Endpoint Authenticator Google Device Trust Connector Stage")
verbose_name_plural = _("Endpoint Authenticator Google Device Trust Connector Stages")
class EndpointDevice(SerializerModel, Device):
"""Endpoint Device for a single user"""
uuid = models.UUIDField(primary_key=True, default=uuid4)
host_identifier = models.TextField(
unique=True,
help_text="A unique identifier for the endpoint device, usually the device serial number",
)
user = models.ForeignKey(get_user_model(), on_delete=models.CASCADE)
data = models.JSONField()
@property
def serializer(self) -> Serializer:
from authentik.enterprise.stages.authenticator_endpoint_gdtc.api import (
EndpointDeviceSerializer,
)
return EndpointDeviceSerializer
def __str__(self):
return str(self.name) or str(self.user_id)
class Meta:
verbose_name = _("Endpoint Device")
verbose_name_plural = _("Endpoint Devices")
class EndpointDeviceConnection(models.Model):
device = models.ForeignKey(EndpointDevice, on_delete=models.CASCADE)
stage = models.ForeignKey(AuthenticatorEndpointGDTCStage, on_delete=models.CASCADE)
attributes = models.JSONField()
def __str__(self) -> str:
return f"Endpoint device connection {self.device_id} to {self.stage_id}"

View File

@ -1,32 +0,0 @@
from django.http import HttpResponse
from django.urls import reverse
from django.utils.translation import gettext_lazy as _
from authentik.flows.challenge import (
Challenge,
ChallengeResponse,
FrameChallenge,
FrameChallengeResponse,
)
from authentik.flows.stage import ChallengeStageView
class AuthenticatorEndpointStageView(ChallengeStageView):
"""Endpoint stage"""
response_class = FrameChallengeResponse
def get_challenge(self, *args, **kwargs) -> Challenge:
return FrameChallenge(
data={
"component": "xak-flow-frame",
"url": self.request.build_absolute_uri(
reverse("authentik_stages_authenticator_endpoint_gdtc:chrome")
),
"loading_overlay": True,
"loading_text": _("Verifying your browser..."),
}
)
def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:
return self.executor.stage_ok()

View File

@ -1,9 +0,0 @@
<html>
<script>
window.parent.postMessage({
message: "submit",
source: "goauthentik.io",
context: "flow-executor"
});
</script>
</html>

View File

@ -1,26 +0,0 @@
"""API URLs"""
from django.urls import path
from authentik.enterprise.stages.authenticator_endpoint_gdtc.api import (
AuthenticatorEndpointGDTCStageViewSet,
EndpointAdminDeviceViewSet,
EndpointDeviceViewSet,
)
from authentik.enterprise.stages.authenticator_endpoint_gdtc.views.dtc import (
GoogleChromeDeviceTrustConnector,
)
urlpatterns = [
path("chrome/", GoogleChromeDeviceTrustConnector.as_view(), name="chrome"),
]
api_urlpatterns = [
("authenticators/endpoint", EndpointDeviceViewSet),
(
"authenticators/admin/endpoint",
EndpointAdminDeviceViewSet,
"admin-endpointdevice",
),
("stages/authenticator/endpoint_gdtc", AuthenticatorEndpointGDTCStageViewSet),
]

View File

@ -1,84 +0,0 @@
from json import dumps, loads
from typing import Any
from django.http import HttpRequest, HttpResponse, HttpResponseRedirect
from django.template.response import TemplateResponse
from django.urls import reverse
from django.views import View
from googleapiclient.discovery import build
from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import (
AuthenticatorEndpointGDTCStage,
EndpointDevice,
EndpointDeviceConnection,
)
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS
# Header we get from chrome that initiates verified access
HEADER_DEVICE_TRUST = "X-Device-Trust"
# Header we send to the client with the challenge
HEADER_ACCESS_CHALLENGE = "X-Verified-Access-Challenge"
# Header we get back from the client that we verify with google
HEADER_ACCESS_CHALLENGE_RESPONSE = "X-Verified-Access-Challenge-Response"
# Header value for x-device-trust that initiates the flow
DEVICE_TRUST_VERIFIED_ACCESS = "VerifiedAccess"
class GoogleChromeDeviceTrustConnector(View):
"""Google Chrome Device-trust connector based endpoint authenticator"""
def get_flow_plan(self) -> FlowPlan:
flow_plan: FlowPlan = self.request.session[SESSION_KEY_PLAN]
return flow_plan
def setup(self, request: HttpRequest, *args: Any, **kwargs: Any) -> None:
super().setup(request, *args, **kwargs)
stage: AuthenticatorEndpointGDTCStage = self.get_flow_plan().bindings[0].stage
self.google_client = build(
"verifiedaccess",
"v2",
cache_discovery=False,
**stage.google_credentials(),
)
def get(self, request: HttpRequest) -> HttpResponse:
x_device_trust = request.headers.get(HEADER_DEVICE_TRUST)
x_access_challenge_response = request.headers.get(HEADER_ACCESS_CHALLENGE_RESPONSE)
if x_device_trust == "VerifiedAccess" and x_access_challenge_response is None:
challenge = self.google_client.challenge().generate().execute()
res = HttpResponseRedirect(
self.request.build_absolute_uri(
reverse("authentik_stages_authenticator_endpoint_gdtc:chrome")
)
)
res[HEADER_ACCESS_CHALLENGE] = dumps(challenge)
return res
if x_access_challenge_response:
response = (
self.google_client.challenge()
.verify(body=loads(x_access_challenge_response))
.execute()
)
# Remove deprecated string representation of deviceSignals
response.pop("deviceSignal", None)
flow_plan: FlowPlan = self.get_flow_plan()
device, _ = EndpointDevice.objects.update_or_create(
host_identifier=response["deviceSignals"]["serialNumber"],
user=flow_plan.context.get(PLAN_CONTEXT_PENDING_USER),
defaults={"name": response["deviceSignals"]["hostname"], "data": response},
)
EndpointDeviceConnection.objects.update_or_create(
device=device,
stage=flow_plan.bindings[0].stage,
defaults={
"attributes": response,
},
)
flow_plan.context.setdefault(PLAN_CONTEXT_METHOD, "trusted_endpoint")
flow_plan.context.setdefault(PLAN_CONTEXT_METHOD_ARGS, {})
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS].setdefault("endpoints", [])
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS]["endpoints"].append(response)
request.session[SESSION_KEY_PLAN] = flow_plan
return TemplateResponse(request, "stages/authenticator_endpoint/google_chrome_dtc.html")

View File

@ -1,16 +1,13 @@
"""authentik events signal listener"""
from importlib import import_module
from typing import Any
from django.conf import settings
from django.contrib.auth.signals import user_logged_in, user_logged_out
from django.db.models.signals import post_save, pre_delete
from django.dispatch import receiver
from django.http import HttpRequest
from rest_framework.request import Request
from authentik.core.models import AuthenticatedSession, User
from authentik.core.models import User
from authentik.core.signals import login_failed, password_changed
from authentik.events.apps import SYSTEM_TASK_STATUS
from authentik.events.models import Event, EventAction, SystemTask
@ -26,7 +23,6 @@ from authentik.stages.user_write.signals import user_write
from authentik.tenants.utils import get_current_tenant
SESSION_LOGIN_EVENT = "login_event"
_session_engine = import_module(settings.SESSION_ENGINE)
@receiver(user_logged_in)
@ -47,20 +43,11 @@ def on_user_logged_in(sender, request: HttpRequest, user: User, **_):
kwargs[PLAN_CONTEXT_OUTPOST] = flow_plan.context[PLAN_CONTEXT_OUTPOST]
event = Event.new(EventAction.LOGIN, **kwargs).from_http(request, user=user)
request.session[SESSION_LOGIN_EVENT] = event
request.session.save()
def get_login_event(request_or_session: HttpRequest | AuthenticatedSession | None) -> Event | None:
def get_login_event(request: HttpRequest) -> Event | None:
"""Wrapper to get login event that can be mocked in tests"""
session = None
if not request_or_session:
return None
if isinstance(request_or_session, HttpRequest | Request):
session = request_or_session.session
if isinstance(request_or_session, AuthenticatedSession):
SessionStore = _session_engine.SessionStore
session = SessionStore(request_or_session.session_key)
return session.get(SESSION_LOGIN_EVENT, None)
return request.session.get(SESSION_LOGIN_EVENT, None)
@receiver(user_logged_out)

View File

@ -8,7 +8,7 @@ from uuid import UUID
from django.core.serializers.json import DjangoJSONEncoder
from django.db import models
from django.http import JsonResponse
from rest_framework.fields import BooleanField, CharField, ChoiceField, DictField
from rest_framework.fields import CharField, ChoiceField, DictField
from rest_framework.request import Request
from authentik.core.api.utils import PassiveSerializer
@ -160,20 +160,6 @@ class AutoSubmitChallengeResponse(ChallengeResponse):
component = CharField(default="ak-stage-autosubmit")
class FrameChallenge(Challenge):
"""Challenge type to render a frame"""
component = CharField(default="xak-flow-frame")
url = CharField()
loading_overlay = BooleanField(default=False)
loading_text = CharField()
class FrameChallengeResponse(ChallengeResponse):
component = CharField(default="xak-flow-frame")
class DataclassEncoder(DjangoJSONEncoder):
"""Convert any dataclass to json"""

View File

@ -46,7 +46,6 @@ class TestFlowInspector(APITestCase):
res.content,
{
"allow_show_password": False,
"captcha_stage": None,
"component": "ak-stage-identification",
"flow_info": {
"background": flow.background_url,

View File

@ -105,10 +105,6 @@ ldap:
tls:
ciphers: null
sources:
kerberos:
task_timeout_hours: 2
reputation:
expiry: 86400

View File

@ -21,14 +21,7 @@ class DebugSession(Session):
def send(self, req: PreparedRequest, *args, **kwargs):
request_id = str(uuid4())
LOGGER.debug(
"HTTP request sent",
uid=request_id,
url=req.url,
method=req.method,
headers=req.headers,
body=req.body,
)
LOGGER.debug("HTTP request sent", uid=request_id, path=req.path_url, headers=req.headers)
resp = super().send(req, *args, **kwargs)
LOGGER.debug(
"HTTP response received",

View File

@ -9,7 +9,7 @@ from uuid import uuid4
from dacite.core import from_dict
from django.contrib.auth.models import Permission
from django.core.cache import cache
from django.db import IntegrityError, models, transaction
from django.db import models, transaction
from django.db.models.base import Model
from django.utils.translation import gettext_lazy as _
from guardian.models import UserObjectPermission
@ -380,26 +380,22 @@ class Outpost(SerializerModel, ManagedModel):
"""Get/create token for auto-generated user"""
managed = f"goauthentik.io/outpost/{self.token_identifier}"
tokens = Token.filter_not_expired(
delete_expired=True,
identifier=self.token_identifier,
intent=TokenIntents.INTENT_API,
managed=managed,
)
if tokens.exists():
return tokens.first()
try:
return Token.objects.create(
user=self.user,
identifier=self.token_identifier,
intent=TokenIntents.INTENT_API,
description=f"Autogenerated by authentik for Outpost {self.name}",
expiring=False,
managed=managed,
)
except IntegrityError:
# Integrity error happens mostly when managed is reused
Token.objects.filter(managed=managed).delete()
Token.objects.filter(identifier=self.token_identifier).delete()
return self.token
token: Token | None = tokens.first()
if token:
return token
return Token.objects.create(
user=self.user,
identifier=self.token_identifier,
intent=TokenIntents.INTENT_API,
description=f"Autogenerated by authentik for Outpost {self.name}",
expiring=False,
managed=managed,
)
def get_required_objects(self) -> Iterable[models.Model | str]:
"""Get an iterator of all objects the user needs read access to"""

View File

@ -108,7 +108,7 @@ class EventMatcherPolicy(Policy):
result=result,
)
matches.append(result)
passing = all(x.passing for x in matches)
passing = any(x.passing for x in matches)
messages = chain(*[x.messages for x in matches])
result = PolicyResult(passing, *messages)
result.source_results = matches
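
The one-line difference above changes how per-attribute match results are combined: on the base (2024.10.0) side every configured matcher must pass, while on the compared (2024.8.3-based) side a single passing matcher is enough. A minimal illustration in plain Python, with hypothetical match results for a policy where only one of two configured attributes matched:

    matches = [True, False]  # hypothetical: client_ip matched, app did not

    all(matches)  # False -> base side: all configured matchers must pass
    any(matches)  # True  -> compared side: one passing matcher is enough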

View File

@ -77,24 +77,11 @@ class TestEventMatcherPolicy(TestCase):
request = PolicyRequest(get_anonymous_user())
request.context["event"] = event
policy: EventMatcherPolicy = EventMatcherPolicy.objects.create(
client_ip="1.2.3.5", app="foo"
client_ip="1.2.3.5", app="bar"
)
response = policy.passes(request)
self.assertFalse(response.passing)
def test_multiple(self):
"""Test multiple"""
event = Event.new(EventAction.LOGIN)
event.app = "foo"
event.client_ip = "1.2.3.4"
request = PolicyRequest(get_anonymous_user())
request.context["event"] = event
policy: EventMatcherPolicy = EventMatcherPolicy.objects.create(
client_ip="1.2.3.4", app="foo"
)
response = policy.passes(request)
self.assertTrue(response.passing)
def test_invalid(self):
"""Test passing event"""
request = PolicyRequest(get_anonymous_user())

View File

@ -39,7 +39,6 @@ class OAuth2ProviderSerializer(ProviderSerializer):
"refresh_token_validity",
"include_claims_in_id_token",
"signing_key",
"encryption_key",
"redirect_uris",
"sub_mode",
"property_mappings",

View File

@ -1,7 +1,6 @@
"""id_token utils"""
from dataclasses import asdict, dataclass, field
from hashlib import sha256
from typing import TYPE_CHECKING, Any
from django.db import models
@ -24,13 +23,8 @@ if TYPE_CHECKING:
from authentik.providers.oauth2.models import BaseGrantModel, OAuth2Provider
def hash_session_key(session_key: str) -> str:
"""Hash the session key for inclusion in JWTs as `sid`"""
return sha256(session_key.encode("ascii")).hexdigest()
class SubModes(models.TextChoices):
"""Mode after which 'sub' attribute is generated, for compatibility reasons"""
"""Mode after which 'sub' attribute is generateed, for compatibility reasons"""
HASHED_USER_ID = "hashed_user_id", _("Based on the Hashed User ID")
USER_ID = "user_id", _("Based on user ID")
@ -57,8 +51,7 @@ class IDToken:
and potentially other requested Claims. The ID Token is represented as a
JSON Web Token (JWT) [JWT].
https://openid.net/specs/openid-connect-core-1_0.html#IDToken
https://www.iana.org/assignments/jwt/jwt.xhtml"""
https://openid.net/specs/openid-connect-core-1_0.html#IDToken"""
# Issuer, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.1
iss: str | None = None
@ -86,8 +79,6 @@ class IDToken:
nonce: str | None = None
# Access Token hash value, http://openid.net/specs/openid-connect-core-1_0.html
at_hash: str | None = None
# Session ID, https://openid.net/specs/openid-connect-frontchannel-1_0.html#ClaimsContents
sid: str | None = None
claims: dict[str, Any] = field(default_factory=dict)
@ -125,11 +116,9 @@ class IDToken:
now = timezone.now()
id_token.iat = int(now.timestamp())
id_token.auth_time = int(token.auth_time.timestamp())
if token.session:
id_token.sid = hash_session_key(token.session.session_key)
# We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time
auth_event = get_login_event(token.session)
auth_event = get_login_event(request)
if auth_event:
# Also check which method was used for authentication
method = auth_event.context.get(PLAN_CONTEXT_METHOD, "")

View File

@ -3,7 +3,6 @@
import django.db.models.deletion
from django.apps.registry import Apps
from django.db import migrations, models
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
import authentik.lib.utils.time
@ -15,7 +14,7 @@ scope_uid_map = {
}
def set_managed_flag(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
def set_managed_flag(apps: Apps, schema_editor):
ScopeMapping = apps.get_model("authentik_providers_oauth2", "ScopeMapping")
db_alias = schema_editor.connection.alias
for mapping in ScopeMapping.objects.using(db_alias).filter(name__startswith="Autogenerated "):

View File

@ -1,42 +0,0 @@
# Generated by Django 5.0.9 on 2024-10-16 14:53
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_crypto", "0004_alter_certificatekeypair_name"),
(
"authentik_providers_oauth2",
"0020_remove_accesstoken_authentik_p_token_4bc870_idx_and_more",
),
]
operations = [
migrations.AddField(
model_name="oauth2provider",
name="encryption_key",
field=models.ForeignKey(
help_text="Key used to encrypt the tokens. When set, tokens will be encrypted and returned as JWEs.",
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="oauth2provider_encryption_key_set",
to="authentik_crypto.certificatekeypair",
verbose_name="Encryption Key",
),
),
migrations.AlterField(
model_name="oauth2provider",
name="signing_key",
field=models.ForeignKey(
help_text="Key used to sign the tokens.",
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="oauth2provider_signing_key_set",
to="authentik_crypto.certificatekeypair",
verbose_name="Signing Key",
),
),
]

View File

@ -1,113 +0,0 @@
# Generated by Django 5.0.9 on 2024-10-23 13:38
from hashlib import sha256
import django.db.models.deletion
from django.db import migrations, models
from django.apps.registry import Apps
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from authentik.lib.migrations import progress_bar
def migrate_session(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
AuthenticatedSession = apps.get_model("authentik_core", "authenticatedsession")
AuthorizationCode = apps.get_model("authentik_providers_oauth2", "authorizationcode")
AccessToken = apps.get_model("authentik_providers_oauth2", "accesstoken")
RefreshToken = apps.get_model("authentik_providers_oauth2", "refreshtoken")
db_alias = schema_editor.connection.alias
print(f"\nFetching session keys, this might take a couple of minutes...")
session_ids = {}
for session in progress_bar(AuthenticatedSession.objects.using(db_alias).all()):
session_ids[sha256(session.session_key.encode("ascii")).hexdigest()] = session.session_key
for model in [AuthorizationCode, AccessToken, RefreshToken]:
print(
f"\nAdding session to {model._meta.verbose_name}, this might take a couple of minutes..."
)
for code in progress_bar(model.objects.using(db_alias).all()):
if code.session_id_old not in session_ids:
continue
code.session = (
AuthenticatedSession.objects.using(db_alias)
.filter(session_key=session_ids[code.session_id_old])
.first()
)
code.save()
class Migration(migrations.Migration):
dependencies = [
("authentik_core", "0040_provider_invalidation_flow"),
("authentik_providers_oauth2", "0021_oauth2provider_encryption_key_and_more"),
]
operations = [
migrations.RenameField(
model_name="accesstoken",
old_name="session_id",
new_name="session_id_old",
),
migrations.RenameField(
model_name="authorizationcode",
old_name="session_id",
new_name="session_id_old",
),
migrations.RenameField(
model_name="refreshtoken",
old_name="session_id",
new_name="session_id_old",
),
migrations.AddField(
model_name="accesstoken",
name="session",
field=models.ForeignKey(
default=None,
null=True,
on_delete=django.db.models.deletion.SET_DEFAULT,
to="authentik_core.authenticatedsession",
),
),
migrations.AddField(
model_name="authorizationcode",
name="session",
field=models.ForeignKey(
default=None,
null=True,
on_delete=django.db.models.deletion.SET_DEFAULT,
to="authentik_core.authenticatedsession",
),
),
migrations.AddField(
model_name="devicetoken",
name="session",
field=models.ForeignKey(
default=None,
null=True,
on_delete=django.db.models.deletion.SET_DEFAULT,
to="authentik_core.authenticatedsession",
),
),
migrations.AddField(
model_name="refreshtoken",
name="session",
field=models.ForeignKey(
default=None,
null=True,
on_delete=django.db.models.deletion.SET_DEFAULT,
to="authentik_core.authenticatedsession",
),
),
migrations.RunPython(migrate_session),
migrations.RemoveField(
model_name="accesstoken",
name="session_id_old",
),
migrations.RemoveField(
model_name="authorizationcode",
name="session_id_old",
),
migrations.RemoveField(
model_name="refreshtoken",
name="session_id_old",
),
]

View File

@ -18,21 +18,12 @@ from django.http import HttpRequest
from django.templatetags.static import static
from django.urls import reverse
from django.utils.translation import gettext_lazy as _
from jwcrypto.common import json_encode
from jwcrypto.jwe import JWE
from jwcrypto.jwk import JWK
from jwt import encode
from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger
from authentik.brands.models import WebfingerProvider
from authentik.core.models import (
AuthenticatedSession,
ExpiringModel,
PropertyMapping,
Provider,
User,
)
from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User
from authentik.crypto.models import CertificateKeyPair
from authentik.lib.generators import generate_code_fixed_length, generate_id, generate_key
from authentik.lib.models import SerializerModel
@ -215,19 +206,9 @@ class OAuth2Provider(WebfingerProvider, Provider):
verbose_name=_("Signing Key"),
on_delete=models.SET_NULL,
null=True,
help_text=_("Key used to sign the tokens."),
related_name="oauth2provider_signing_key_set",
)
encryption_key = models.ForeignKey(
CertificateKeyPair,
verbose_name=_("Encryption Key"),
on_delete=models.SET_NULL,
null=True,
help_text=_(
"Key used to encrypt the tokens. When set, "
"tokens will be encrypted and returned as JWEs."
"Key used to sign the tokens. Only required when JWT Algorithm is set to RS256."
),
related_name="oauth2provider_encryption_key_set",
)
jwks_sources = models.ManyToManyField(
@ -306,27 +287,7 @@ class OAuth2Provider(WebfingerProvider, Provider):
if self.signing_key:
headers["kid"] = self.signing_key.kid
key, alg = self.jwt_key
encoded = encode(payload, key, algorithm=alg, headers=headers)
if self.encryption_key:
return self.encrypt(encoded)
return encoded
def encrypt(self, raw: str) -> str:
"""Encrypt JWT"""
key = JWK.from_pem(self.encryption_key.certificate_data.encode())
jwe = JWE(
raw,
json_encode(
{
"alg": "RSA-OAEP-256",
"enc": "A256CBC-HS512",
"typ": "JWE",
"kid": self.encryption_key.kid,
}
),
)
jwe.add_recipient(key)
return jwe.serialize(compact=True)
return encode(payload, key, algorithm=alg, headers=headers)
def webfinger(self, resource: str, request: HttpRequest):
return {
@ -359,9 +320,7 @@ class BaseGrantModel(models.Model):
revoked = models.BooleanField(default=False)
_scope = models.TextField(default="", verbose_name=_("Scopes"))
auth_time = models.DateTimeField(verbose_name="Authentication time")
session = models.ForeignKey(
AuthenticatedSession, null=True, on_delete=models.SET_DEFAULT, default=None
)
session_id = models.CharField(default="", blank=True)
class Meta:
abstract = True
@ -499,9 +458,6 @@ class DeviceToken(ExpiringModel):
device_code = models.TextField(default=generate_key)
user_code = models.TextField(default=generate_code_fixed_length)
_scope = models.TextField(default="", verbose_name=_("Scopes"))
session = models.ForeignKey(
AuthenticatedSession, null=True, on_delete=models.SET_DEFAULT, default=None
)
@property
def scope(self) -> list[str]:
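
On the base (2024.10.0) side of this file, providers with an encryption_key wrap the signed JWT from encode() in a JWE (alg RSA-OAEP-256, enc A256CBC-HS512), as implemented by the encrypt() helper shown above (present on the base side only). A minimal decryption sketch with jwcrypto and PyJWT, assuming the recipient holds the private key matching the provider's encryption_key; the function name, key handling, and verification options are assumptions, not part of this diff:

    from jwcrypto.jwe import JWE
    from jwcrypto.jwk import JWK
    from jwt import decode  # PyJWT, the same library the provider uses for signing

    def decrypt_provider_token(compact_jwe: str, private_key_pem: bytes, signing_public_key_pem: bytes) -> dict:
        # Decrypt the outer JWE with the recipient's private key (hypothetical input).
        jwe = JWE()
        jwe.deserialize(compact_jwe, key=JWK.from_pem(private_key_pem))
        inner_jwt = jwe.payload.decode()
        # Verify the inner signed JWT; algorithm and options are assumptions, adjust to the provider.
        return decode(inner_jwt, signing_public_key_pem, algorithms=["RS256"], options={"verify_aud": False})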

View File

@ -1,3 +1,5 @@
from hashlib import sha256
from django.contrib.auth.signals import user_logged_out
from django.dispatch import receiver
from django.http import HttpRequest
@ -11,4 +13,5 @@ def user_logged_out_oauth_access_token(sender, request: HttpRequest, user: User,
"""Revoke access tokens upon user logout"""
if not request.session or not request.session.session_key:
return
AccessToken.objects.filter(user=user, session__session_key=request.session.session_key).delete()
hashed_session_key = sha256(request.session.session_key.encode("ascii")).hexdigest()
AccessToken.objects.filter(user=user, session_id=hashed_session_key).delete()
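The revocation above matches tokens by a SHA-256 hex digest of the Django session key instead of a foreign key to the session. A minimal standalone sketch of that hashing convention (plain hashlib, no Django required; the sample key is made up):

from hashlib import sha256

def hash_session_key(session_key: str) -> str:
    """Hash a session key the same way the token records above store it."""
    return sha256(session_key.encode("ascii")).hexdigest()

# Example with a made-up key: the result is a 64-character hex digest
assert len(hash_session_key("abc123sessionkey")) == 64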

View File

@ -412,73 +412,6 @@ class TestAuthorize(OAuthTestCase):
delta=5,
)
@apply_blueprint("system/providers-oauth2.yaml")
def test_full_implicit_enc(self):
"""Test full authorization with encryption"""
flow = create_test_flow()
provider: OAuth2Provider = OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=flow,
redirect_uris="http://localhost",
signing_key=self.keypair,
encryption_key=self.keypair,
)
provider.property_mappings.set(
ScopeMapping.objects.filter(
managed__in=[
"goauthentik.io/providers/oauth2/scope-openid",
"goauthentik.io/providers/oauth2/scope-email",
"goauthentik.io/providers/oauth2/scope-profile",
]
)
)
provider.property_mappings.add(
ScopeMapping.objects.create(
name=generate_id(), scope_name="test", expression="""return {"sub": "foo"}"""
)
)
Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider)
state = generate_id()
user = create_test_admin_user()
self.client.force_login(user)
with patch(
"authentik.providers.oauth2.id_token.get_login_event",
MagicMock(
return_value=Event(
action=EventAction.LOGIN,
context={PLAN_CONTEXT_METHOD: "password"},
created=now(),
)
),
):
# Step 1, initiate params and get redirect to flow
self.client.get(
reverse("authentik_providers_oauth2:authorize"),
data={
"response_type": "id_token",
"client_id": "test",
"state": state,
"scope": "openid test",
"redirect_uri": "http://localhost",
"nonce": generate_id(),
},
)
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
)
self.assertEqual(response.status_code, 200)
token: AccessToken = AccessToken.objects.filter(user=user).first()
expires = timedelta_from_string(provider.access_token_validity).total_seconds()
jwt = self.validate_jwe(token, provider)
self.assertEqual(jwt["amr"], ["pwd"])
self.assertEqual(jwt["sub"], "foo")
self.assertAlmostEqual(
jwt["exp"] - now().timestamp(),
expires,
delta=5,
)
def test_full_fragment_code(self):
"""Test full authorization"""
flow = create_test_flow()

View File

@ -93,24 +93,6 @@ class TestJWKS(OAuthTestCase):
self.assertEqual(len(body["keys"]), 1)
PyJWKSet.from_dict(body)
def test_enc(self):
"""Test with JWE"""
provider = OAuth2Provider.objects.create(
name="test",
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris="http://local.invalid",
signing_key=create_test_cert(PrivateKeyAlg.ECDSA),
encryption_key=create_test_cert(PrivateKeyAlg.ECDSA),
)
app = Application.objects.create(name="test", slug="test", provider=provider)
response = self.client.get(
reverse("authentik_providers_oauth2:jwks", kwargs={"application_slug": app.slug})
)
body = json.loads(response.content.decode())
self.assertEqual(len(body["keys"]), 2)
PyJWKSet.from_dict(body)
def test_ecdsa_coords_mismatched(self):
"""Test JWKS request with ES256"""
cert = CertificateKeyPair.objects.create(

View File

@ -152,36 +152,6 @@ class TestToken(OAuthTestCase):
)
self.validate_jwt(access, provider)
def test_auth_code_enc(self):
"""test request param"""
provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
redirect_uris="http://local.invalid",
signing_key=self.keypair,
encryption_key=self.keypair,
)
# Needs to be assigned to an application for iss to be set
self.app.provider = provider
self.app.save()
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
user = create_test_admin_user()
code = AuthorizationCode.objects.create(
code="foobar", provider=provider, user=user, auth_time=timezone.now()
)
response = self.client.post(
reverse("authentik_providers_oauth2:token"),
data={
"grant_type": GRANT_TYPE_AUTHORIZATION_CODE,
"code": code.code,
"redirect_uri": "http://local.invalid",
},
HTTP_AUTHORIZATION=f"Basic {header}",
)
self.assertEqual(response.status_code, 200)
access: AccessToken = AccessToken.objects.filter(user=user, provider=provider).first()
self.validate_jwe(access, provider)
@apply_blueprint("system/providers-oauth2.yaml")
def test_refresh_token_view(self):
"""test request param"""

View File

@ -34,7 +34,7 @@ class TestTokenClientCredentialsJWTSource(OAuthTestCase):
self.factory = RequestFactory()
self.cert = create_test_cert()
jwk = JWKSView().get_jwk_for_key(self.cert, "sig")
jwk = JWKSView().get_jwk_for_key(self.cert)
self.source: OAuthSource = OAuthSource.objects.create(
name=generate_id(),
slug=generate_id(),

View File

@ -3,8 +3,6 @@
from typing import Any
from django.test import TestCase
from jwcrypto.jwe import JWE
from jwcrypto.jwk import JWK
from jwt import decode
from authentik.core.tests.utils import create_test_cert
@ -34,15 +32,6 @@ class OAuthTestCase(TestCase):
if key in container:
self.assertIsNotNone(container[key])
def validate_jwe(self, token: AccessToken, provider: OAuth2Provider) -> dict[str, Any]:
"""Validate JWEs"""
private_key = JWK.from_pem(provider.encryption_key.key_data.encode())
jwetoken = JWE()
jwetoken.deserialize(token.token, key=private_key)
token.token = jwetoken.payload.decode()
return self.validate_jwt(token, provider)
def validate_jwt(self, token: AccessToken, provider: OAuth2Provider) -> dict[str, Any]:
"""Validate that all required fields are set"""
key, alg = provider.jwt_key

View File

@ -2,6 +2,7 @@
from dataclasses import InitVar, dataclass, field
from datetime import timedelta
from hashlib import sha256
from json import dumps
from re import error as RegexError
from re import fullmatch
@ -15,7 +16,7 @@ from django.utils import timezone
from django.utils.translation import gettext as _
from structlog.stdlib import get_logger
from authentik.core.models import Application, AuthenticatedSession
from authentik.core.models import Application
from authentik.events.models import Event, EventAction
from authentik.events.signals import get_login_event
from authentik.flows.challenge import (
@ -317,9 +318,7 @@ class OAuthAuthorizationParams:
expires=now + timedelta_from_string(self.provider.access_code_validity),
scope=self.scope,
nonce=self.nonce,
session=AuthenticatedSession.objects.filter(
session_key=request.session.session_key
).first(),
session_id=sha256(request.session.session_key.encode("ascii")).hexdigest(),
)
if self.code_challenge and self.code_challenge_method:
@ -611,9 +610,7 @@ class OAuthFulfillmentStage(StageView):
expires=access_token_expiry,
provider=self.provider,
auth_time=auth_event.created if auth_event else now,
session=AuthenticatedSession.objects.filter(
session_key=self.request.session.session_key
).first(),
session_id=sha256(self.request.session.session_key.encode("ascii")).hexdigest(),
)
id_token = IDToken.new(self.provider, token, self.request)

View File

@ -64,42 +64,36 @@ def to_base64url_uint(val: int, min_length: int = 0) -> bytes:
class JWKSView(View):
"""Show RSA Key data for Provider"""
def get_jwk_for_key(self, key: CertificateKeyPair, use: str) -> dict | None:
def get_jwk_for_key(self, key: CertificateKeyPair) -> dict | None:
"""Convert a certificate-key pair into JWK"""
private_key = key.private_key
key_data = None
if not private_key:
return key_data
key_data = {}
if use == "sig":
if isinstance(private_key, RSAPrivateKey):
key_data["alg"] = JWTAlgorithms.RS256
elif isinstance(private_key, EllipticCurvePrivateKey):
key_data["alg"] = JWTAlgorithms.ES256
elif use == "enc":
key_data["alg"] = "RSA-OAEP-256"
key_data["enc"] = "A256CBC-HS512"
if isinstance(private_key, RSAPrivateKey):
public_key: RSAPublicKey = private_key.public_key()
public_numbers = public_key.public_numbers()
key_data["kid"] = key.kid
key_data["kty"] = "RSA"
key_data["use"] = use
key_data["n"] = to_base64url_uint(public_numbers.n).decode()
key_data["e"] = to_base64url_uint(public_numbers.e).decode()
key_data = {
"kid": key.kid,
"kty": "RSA",
"alg": JWTAlgorithms.RS256,
"use": "sig",
"n": to_base64url_uint(public_numbers.n).decode(),
"e": to_base64url_uint(public_numbers.e).decode(),
}
elif isinstance(private_key, EllipticCurvePrivateKey):
public_key: EllipticCurvePublicKey = private_key.public_key()
public_numbers = public_key.public_numbers()
curve_type = type(public_key.curve)
key_data["kid"] = key.kid
key_data["kty"] = "EC"
key_data["use"] = use
key_data["x"] = to_base64url_uint(public_numbers.x, min_length_map[curve_type]).decode()
key_data["y"] = to_base64url_uint(public_numbers.y, min_length_map[curve_type]).decode()
key_data["crv"] = ec_crv_map.get(curve_type, public_key.curve.name)
key_data = {
"kid": key.kid,
"kty": "EC",
"alg": JWTAlgorithms.ES256,
"use": "sig",
"x": to_base64url_uint(public_numbers.x, min_length_map[curve_type]).decode(),
"y": to_base64url_uint(public_numbers.y, min_length_map[curve_type]).decode(),
"crv": ec_crv_map.get(curve_type, public_key.curve.name),
}
else:
return key_data
key_data["x5c"] = [b64encode(key.certificate.public_bytes(Encoding.DER)).decode("utf-8")]
@ -119,19 +113,14 @@ class JWKSView(View):
"""Show JWK Key data for Provider"""
application = get_object_or_404(Application, slug=application_slug)
provider: OAuth2Provider = get_object_or_404(OAuth2Provider, pk=application.provider_id)
signing_key: CertificateKeyPair = provider.signing_key
response_data = {}
if signing_key := provider.signing_key:
jwk = self.get_jwk_for_key(signing_key, "sig")
if signing_key:
jwk = self.get_jwk_for_key(signing_key)
if jwk:
response_data.setdefault("keys", [])
response_data["keys"].append(jwk)
if encryption_key := provider.encryption_key:
jwk = self.get_jwk_for_key(encryption_key, "enc")
if jwk:
response_data.setdefault("keys", [])
response_data["keys"].append(jwk)
response_data["keys"] = [jwk]
response = JsonResponse(response_data)
response["Access-Control-Allow-Origin"] = "*"

View File

@ -46,7 +46,7 @@ class ProviderInfoView(View):
if SCOPE_OPENID not in scopes:
scopes.append(SCOPE_OPENID)
_, supported_alg = provider.jwt_key
config = {
return {
"issuer": provider.get_issuer(self.request),
"authorization_endpoint": self.request.build_absolute_uri(
reverse("authentik_providers_oauth2:authorize")
@ -114,10 +114,6 @@ class ProviderInfoView(View):
"claims_parameter_supported": False,
"code_challenge_methods_supported": [PKCE_METHOD_PLAIN, PKCE_METHOD_S256],
}
if provider.encryption_key:
config["id_token_encryption_alg_values_supported"] = ["RSA-OAEP-256"]
config["id_token_encryption_enc_values_supported"] = ["A256CBC-HS512"]
return config
def get_claims(self, provider: OAuth2Provider) -> list[str]:
"""Get a list of supported claims based on configured scope mappings"""

View File

@ -439,14 +439,15 @@ class TokenParams:
# (22 chars being the length of the "template")
username=f"ak-{self.provider.name[:150-22]}-client_credentials",
defaults={
"attributes": {
USER_ATTRIBUTE_GENERATED: True,
},
"last_login": timezone.now(),
"name": f"Autogenerated user from application {app.name} (client credentials)",
"path": f"{USER_PATH_SYSTEM_PREFIX}/apps/{app.slug}",
"type": UserTypes.SERVICE_ACCOUNT,
},
)
self.user.attributes[USER_ATTRIBUTE_GENERATED] = True
self.user.save()
self.__check_policy_access(app, request)
Event.new(
@ -470,6 +471,9 @@ class TokenParams:
self.user, created = User.objects.update_or_create(
username=f"{self.provider.name}-{token.get('sub')}",
defaults={
"attributes": {
USER_ATTRIBUTE_GENERATED: True,
},
"last_login": timezone.now(),
"name": (
f"Autogenerated user from application {app.name} (client credentials JWT)"
@ -478,8 +482,6 @@ class TokenParams:
"type": UserTypes.SERVICE_ACCOUNT,
},
)
self.user.attributes[USER_ATTRIBUTE_GENERATED] = True
self.user.save()
exp = token.get("exp")
if created and exp:
self.user.attributes[USER_ATTRIBUTE_EXPIRES] = exp
@ -550,7 +552,7 @@ class TokenView(View):
# Keep same scopes as previous token
scope=self.params.authorization_code.scope,
auth_time=self.params.authorization_code.auth_time,
session=self.params.authorization_code.session,
session_id=self.params.authorization_code.session_id,
)
access_id_token = IDToken.new(
self.provider,
@ -578,7 +580,7 @@ class TokenView(View):
expires=refresh_token_expiry,
provider=self.provider,
auth_time=self.params.authorization_code.auth_time,
session=self.params.authorization_code.session,
session_id=self.params.authorization_code.session_id,
)
id_token = IDToken.new(
self.provider,
@ -611,7 +613,7 @@ class TokenView(View):
# Keep same scopes as previous token
scope=self.params.refresh_token.scope,
auth_time=self.params.refresh_token.auth_time,
session=self.params.refresh_token.session,
session_id=self.params.refresh_token.session_id,
)
access_token.id_token = IDToken.new(
self.provider,
@ -627,7 +629,7 @@ class TokenView(View):
expires=refresh_token_expiry,
provider=self.provider,
auth_time=self.params.refresh_token.auth_time,
session=self.params.refresh_token.session,
session_id=self.params.refresh_token.session_id,
)
id_token = IDToken.new(
self.provider,
@ -685,14 +687,13 @@ class TokenView(View):
raise DeviceCodeError("authorization_pending")
now = timezone.now()
access_token_expiry = now + timedelta_from_string(self.provider.access_token_validity)
auth_event = get_login_event(self.params.device_code.session)
auth_event = get_login_event(self.request)
access_token = AccessToken(
provider=self.provider,
user=self.params.device_code.user,
expires=access_token_expiry,
scope=self.params.device_code.scope,
auth_time=auth_event.created if auth_event else now,
session=self.params.device_code.session,
)
access_token.id_token = IDToken.new(
self.provider,

View File

@ -1,12 +1,13 @@
"""proxy provider tasks"""
from hashlib import sha256
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.db import DatabaseError, InternalError, ProgrammingError
from authentik.outposts.consumer import OUTPOST_GROUP
from authentik.outposts.models import Outpost, OutpostType
from authentik.providers.oauth2.id_token import hash_session_key
from authentik.providers.proxy.models import ProxyProvider
from authentik.root.celery import CELERY_APP
@ -25,7 +26,7 @@ def proxy_set_defaults():
def proxy_on_logout(session_id: str):
"""Update outpost instances connected to a single outpost"""
layer = get_channel_layer()
hashed_session_id = hash_session_key(session_id)
hashed_session_id = sha256(session_id.encode("ascii")).hexdigest()
for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
async_to_sync(layer.group_send)(

View File

@ -50,7 +50,6 @@ class AssertionProcessor:
_issue_instant: str
_assertion_id: str
_response_id: str
_valid_not_before: str
_session_not_on_or_after: str
@ -63,7 +62,6 @@ class AssertionProcessor:
self._issue_instant = get_time_string()
self._assertion_id = get_random_id()
self._response_id = get_random_id()
self._valid_not_before = get_time_string(
timedelta_from_string(self.provider.assertion_valid_not_before)
@ -132,9 +130,7 @@ class AssertionProcessor:
"""Generate AuthnStatement with AuthnContext and ContextClassRef Elements."""
auth_n_statement = Element(f"{{{NS_SAML_ASSERTION}}}AuthnStatement")
auth_n_statement.attrib["AuthnInstant"] = self._valid_not_before
auth_n_statement.attrib["SessionIndex"] = sha256(
self.http_request.session.session_key.encode("ascii")
).hexdigest()
auth_n_statement.attrib["SessionIndex"] = self._assertion_id
auth_n_statement.attrib["SessionNotOnOrAfter"] = self._session_not_on_or_after
auth_n_context = SubElement(auth_n_statement, f"{{{NS_SAML_ASSERTION}}}AuthnContext")
@ -289,7 +285,7 @@ class AssertionProcessor:
response.attrib["Version"] = "2.0"
response.attrib["IssueInstant"] = self._issue_instant
response.attrib["Destination"] = self.provider.acs_url
response.attrib["ID"] = self._response_id
response.attrib["ID"] = get_random_id()
if self.auth_n_request.id:
response.attrib["InResponseTo"] = self.auth_n_request.id
@ -312,7 +308,7 @@ class AssertionProcessor:
ref = xmlsec.template.add_reference(
signature_node,
digest_algorithm_transform,
uri="#" + element.attrib["ID"],
uri="#" + self._assertion_id,
)
xmlsec.template.add_transform(ref, xmlsec.constants.TransformEnveloped)
xmlsec.template.add_transform(ref, xmlsec.constants.TransformExclC14N)

View File

@ -180,10 +180,6 @@ class TestAuthNRequest(TestCase):
# Now create a response and convert it to string (provider)
response_proc = AssertionProcessor(self.provider, http_request, parsed_request)
response = response_proc.build_response()
# Ensure both response and assertion ID are in the response twice (once as ID attribute,
# once as ds:Reference URI)
self.assertEqual(response.count(response_proc._assertion_id), 2)
self.assertEqual(response.count(response_proc._response_id), 2)
# Now parse the response (source)
http_request.POST = QueryDict(mutable=True)

View File

@ -2,10 +2,9 @@
from itertools import batched
from django.db import transaction
from pydantic import ValidationError
from pydanticscim.group import GroupMember
from pydanticscim.responses import PatchOp
from pydanticscim.responses import PatchOp, PatchOperation
from authentik.core.models import Group
from authentik.lib.sync.mapper import PropertyMappingManager
@ -20,7 +19,7 @@ from authentik.providers.scim.clients.base import SCIMClient
from authentik.providers.scim.clients.exceptions import (
SCIMRequestException,
)
from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOperation, PatchRequest
from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchRequest
from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema
from authentik.providers.scim.models import (
SCIMMapping,
@ -105,47 +104,13 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
provider=self.provider, group=group, scim_id=scim_id
)
users = list(group.users.order_by("id").values_list("id", flat=True))
self._patch_add_users(connection, users)
self._patch_add_users(group, users)
return connection
def update(self, group: Group, connection: SCIMProviderGroup):
"""Update existing group"""
scim_group = self.to_schema(group, connection)
scim_group.id = connection.scim_id
try:
if self._config.patch.supported:
return self._update_patch(group, scim_group, connection)
return self._update_put(group, scim_group, connection)
except NotFoundSyncException:
# Resource missing is handled by self.write, which will re-create the group
raise
def _update_patch(
self, group: Group, scim_group: SCIMGroupSchema, connection: SCIMProviderGroup
):
"""Update a group via PATCH request"""
# Patch group's attributes instead of replacing it and re-adding users if we can
self._request(
"PATCH",
f"/Groups/{connection.scim_id}",
json=PatchRequest(
Operations=[
PatchOperation(
op=PatchOp.replace,
path=None,
value=scim_group.model_dump(mode="json", exclude_unset=True),
)
]
).model_dump(
mode="json",
exclude_unset=True,
exclude_none=True,
),
)
return self.patch_compare_users(group)
def _update_put(self, group: Group, scim_group: SCIMGroupSchema, connection: SCIMProviderGroup):
"""Update a group via PUT request"""
try:
self._request(
"PUT",
@ -155,25 +120,33 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
exclude_unset=True,
),
)
return self.patch_compare_users(group)
users = list(group.users.order_by("id").values_list("id", flat=True))
return self._patch_add_users(group, users)
except NotFoundSyncException:
# Resource missing is handled by self.write, which will re-create the group
raise
except (SCIMRequestException, ObjectExistsSyncException):
# Some providers don't support PUT on groups, so this is mainly a fix for the initial
# sync; send PATCH add requests for all the users the group currently has
return self._update_patch(group, scim_group, connection)
users = list(group.users.order_by("id").values_list("id", flat=True))
self._patch_add_users(group, users)
# Also update the group name
return self._patch(
scim_group.id,
PatchOperation(
op=PatchOp.replace,
path="displayName",
value=scim_group.displayName,
),
)
def update_group(self, group: Group, action: Direction, users_set: set[int]):
"""Update a group, either using PUT to replace it or PATCH if supported"""
scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first()
if not scim_group:
self.logger.warning(
"could not sync group membership, group does not exist", group=group
)
return
if self._config.patch.supported:
if action == Direction.add:
return self._patch_add_users(scim_group, users_set)
return self._patch_add_users(group, users_set)
if action == Direction.remove:
return self._patch_remove_users(scim_group, users_set)
return self._patch_remove_users(group, users_set)
try:
return self.write(group)
except SCIMRequestException as exc:
@ -181,24 +154,19 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
# Assume that provider does not support PUT and also doesn't support
# ServiceProviderConfig, so try PATCH as a fallback
if action == Direction.add:
return self._patch_add_users(scim_group, users_set)
return self._patch_add_users(group, users_set)
if action == Direction.remove:
return self._patch_remove_users(scim_group, users_set)
return self._patch_remove_users(group, users_set)
raise exc
def _patch_chunked(
def _patch(
self,
group_id: str,
*ops: PatchOperation,
):
"""Helper function that chunks patch requests based on the maxOperations attribute.
This is not strictly according to specs but there's nothing in the schema that allows the
us to know what the maximum patch operations per request should be."""
chunk_size = self._config.bulk.maxOperations
if chunk_size < 1:
chunk_size = len(ops)
if len(ops) < 1:
return
for chunk in batched(ops, chunk_size):
req = PatchRequest(Operations=list(chunk))
self._request(
@ -209,70 +177,16 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
),
)
@transaction.atomic
def patch_compare_users(self, group: Group):
"""Compare users with a SCIM group and add/remove any differences"""
# Get scim group first
def _patch_add_users(self, group: Group, users_set: set[int]):
"""Add users in users_set to group"""
if len(users_set) < 1:
return
scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first()
if not scim_group:
self.logger.warning(
"could not sync group membership, group does not exist", group=group
)
return
# Get a list of all users in the authentik group
raw_users_should = list(group.users.order_by("id").values_list("id", flat=True))
# Lookup the SCIM IDs of the users
users_should: list[str] = list(
SCIMProviderUser.objects.filter(
user__pk__in=raw_users_should, provider=self.provider
).values_list("scim_id", flat=True)
)
if len(raw_users_should) != len(users_should):
self.logger.warning(
"User count mismatch, not all users in the group are synced to SCIM yet.",
group=group,
)
# Get current group status
current_group = SCIMGroupSchema.model_validate(
self._request("GET", f"/Groups/{scim_group.scim_id}")
)
users_to_add = []
users_to_remove = []
# Check users currently in group and if they shouldn't be in the group and remove them
for user in current_group.members or []:
if user.value not in users_should:
users_to_remove.append(user.value)
# Check users that should be in the group and add them
for user in users_should:
if len([x for x in current_group.members if x.value == user]) < 1:
users_to_add.append(user)
# Only send request if we need to make changes
if len(users_to_add) < 1 and len(users_to_remove) < 1:
return
return self._patch_chunked(
scim_group.scim_id,
*[
PatchOperation(
op=PatchOp.add,
path="members",
value=[{"value": x}],
)
for x in users_to_add
],
*[
PatchOperation(
op=PatchOp.remove,
path="members",
value=[{"value": x}],
)
for x in users_to_remove
],
)
def _patch_add_users(self, scim_group: SCIMProviderGroup, users_set: set[int]):
"""Add users in users_set to group"""
if len(users_set) < 1:
return
user_ids = list(
SCIMProviderUser.objects.filter(
user__pk__in=users_set, provider=self.provider
@ -280,7 +194,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
)
if len(user_ids) < 1:
return
self._patch_chunked(
self._patch(
scim_group.scim_id,
*[
PatchOperation(
@ -292,10 +206,16 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
],
)
def _patch_remove_users(self, scim_group: SCIMProviderGroup, users_set: set[int]):
def _patch_remove_users(self, group: Group, users_set: set[int]):
"""Remove users in users_set from group"""
if len(users_set) < 1:
return
scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first()
if not scim_group:
self.logger.warning(
"could not sync group membership, group does not exist", group=group
)
return
user_ids = list(
SCIMProviderUser.objects.filter(
user__pk__in=users_set, provider=self.provider
@ -303,7 +223,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
)
if len(user_ids) < 1:
return
self._patch_chunked(
self._patch(
scim_group.scim_id,
*[
PatchOperation(

View File

@ -2,7 +2,6 @@
from pydantic import Field
from pydanticscim.group import Group as BaseGroup
from pydanticscim.responses import PatchOperation as BasePatchOperation
from pydanticscim.responses import PatchRequest as BasePatchRequest
from pydanticscim.responses import SCIMError as BaseSCIMError
from pydanticscim.service_provider import Bulk as BaseBulk
@ -69,12 +68,6 @@ class PatchRequest(BasePatchRequest):
schemas: tuple[str] = ("urn:ietf:params:scim:api:messages:2.0:PatchOp",)
class PatchOperation(BasePatchOperation):
"""PatchOperation with optional path"""
path: str | None
class SCIMError(BaseSCIMError):
"""SCIM error with optional status code"""

View File

@ -252,118 +252,3 @@ class SCIMMembershipTests(TestCase):
],
},
)
def test_member_add_save(self):
"""Test member add + save"""
config = ServiceProviderConfiguration.default()
config.patch.supported = True
user_scim_id = generate_id()
group_scim_id = generate_id()
uid = generate_id()
group = Group.objects.create(
name=uid,
)
user = User.objects.create(username=generate_id())
# Test initial sync of group creation
with Mocker() as mocker:
mocker.get(
"https://localhost/ServiceProviderConfig",
json=config.model_dump(),
)
mocker.post(
"https://localhost/Users",
json={
"id": user_scim_id,
},
)
mocker.post(
"https://localhost/Groups",
json={
"id": group_scim_id,
},
)
self.configure()
sync_tasks.trigger_single_task(self.provider, scim_sync).get()
self.assertEqual(mocker.call_count, 6)
self.assertEqual(mocker.request_history[0].method, "GET")
self.assertEqual(mocker.request_history[1].method, "GET")
self.assertEqual(mocker.request_history[2].method, "GET")
self.assertEqual(mocker.request_history[3].method, "POST")
self.assertEqual(mocker.request_history[4].method, "GET")
self.assertEqual(mocker.request_history[5].method, "POST")
self.assertJSONEqual(
mocker.request_history[3].body,
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"emails": [],
"active": True,
"externalId": user.uid,
"name": {"familyName": " ", "formatted": " ", "givenName": ""},
"displayName": "",
"userName": user.username,
},
)
self.assertJSONEqual(
mocker.request_history[5].body,
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"externalId": str(group.pk),
"displayName": group.name,
},
)
with Mocker() as mocker:
mocker.get(
"https://localhost/ServiceProviderConfig",
json=config.model_dump(),
)
mocker.get(
f"https://localhost/Groups/{group_scim_id}",
json={},
)
mocker.patch(
f"https://localhost/Groups/{group_scim_id}",
json={},
)
group.users.add(user)
group.save()
self.assertEqual(mocker.call_count, 5)
self.assertEqual(mocker.request_history[0].method, "GET")
self.assertEqual(mocker.request_history[1].method, "PATCH")
self.assertEqual(mocker.request_history[2].method, "GET")
self.assertEqual(mocker.request_history[3].method, "PATCH")
self.assertEqual(mocker.request_history[4].method, "GET")
self.assertJSONEqual(
mocker.request_history[1].body,
{
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
"Operations": [
{
"op": "add",
"path": "members",
"value": [{"value": user_scim_id}],
}
],
},
)
self.assertJSONEqual(
mocker.request_history[3].body,
{
"Operations": [
{
"op": "replace",
"value": {
"id": group_scim_id,
"displayName": group.name,
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"externalId": str(group.pk),
},
}
]
},
)

View File

@ -41,9 +41,7 @@ class SessionMiddleware(UpstreamSessionMiddleware):
# Since go does not consider localhost with http a secure origin
# we can't set the secure flag.
user_agent = request.META.get("HTTP_USER_AGENT", "")
if user_agent.startswith("goauthentik.io/outpost/") or (
"safari" in user_agent.lower() and "chrome" not in user_agent.lower()
):
if user_agent.startswith("goauthentik.io/outpost/") or "safari" in user_agent.lower():
return False
return True
return False

View File

@ -38,7 +38,6 @@ LANGUAGE_COOKIE_NAME = "authentik_language"
SESSION_COOKIE_NAME = "authentik_session"
SESSION_COOKIE_DOMAIN = CONFIG.get("cookie_domain", None)
APPEND_SLASH = False
X_FRAME_OPTIONS = "SAMEORIGIN"
AUTHENTICATION_BACKENDS = [
"django.contrib.auth.backends.ModelBackend",
@ -91,7 +90,6 @@ TENANT_APPS = [
"authentik.providers.scim",
"authentik.rbac",
"authentik.recovery",
"authentik.sources.kerberos",
"authentik.sources.ldap",
"authentik.sources.oauth",
"authentik.sources.plex",

View File

@ -1,31 +0,0 @@
"""Kerberos Property Mapping API"""
from rest_framework.viewsets import ModelViewSet
from authentik.core.api.property_mappings import PropertyMappingFilterSet, PropertyMappingSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.sources.kerberos.models import KerberosSourcePropertyMapping
class KerberosSourcePropertyMappingSerializer(PropertyMappingSerializer):
"""Kerberos PropertyMapping Serializer"""
class Meta(PropertyMappingSerializer.Meta):
model = KerberosSourcePropertyMapping
class KerberosSourcePropertyMappingFilter(PropertyMappingFilterSet):
"""Filter for KerberosSourcePropertyMapping"""
class Meta(PropertyMappingFilterSet.Meta):
model = KerberosSourcePropertyMapping
class KerberosSourcePropertyMappingViewSet(UsedByMixin, ModelViewSet):
"""KerberosSource PropertyMapping Viewset"""
queryset = KerberosSourcePropertyMapping.objects.all()
serializer_class = KerberosSourcePropertyMappingSerializer
filterset_class = KerberosSourcePropertyMappingFilter
search_fields = ["name"]
ordering = ["name"]

View File

@ -1,114 +0,0 @@
"""Source API Views"""
from django.core.cache import cache
from drf_spectacular.utils import extend_schema
from guardian.shortcuts import get_objects_for_user
from rest_framework.decorators import action
from rest_framework.fields import BooleanField, SerializerMethodField
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet
from authentik.core.api.sources import SourceSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import PassiveSerializer
from authentik.events.api.tasks import SystemTaskSerializer
from authentik.sources.kerberos.models import KerberosSource
from authentik.sources.kerberos.tasks import CACHE_KEY_STATUS
class KerberosSourceSerializer(SourceSerializer):
"""Kerberos Source Serializer"""
connectivity = SerializerMethodField()
def get_connectivity(self, source: KerberosSource) -> dict[str, str] | None:
"""Get cached source connectivity"""
return cache.get(CACHE_KEY_STATUS + source.slug, None)
class Meta:
model = KerberosSource
fields = SourceSerializer.Meta.fields + [
"group_matching_mode",
"realm",
"krb5_conf",
"sync_users",
"sync_users_password",
"sync_principal",
"sync_password",
"sync_keytab",
"sync_ccache",
"connectivity",
"spnego_server_name",
"spnego_keytab",
"spnego_ccache",
"password_login_update_internal_password",
]
extra_kwargs = {
"sync_password": {"write_only": True},
"sync_keytab": {"write_only": True},
"spnego_keytab": {"write_only": True},
}
class KerberosSyncStatusSerializer(PassiveSerializer):
"""Kerberos Source sync status"""
is_running = BooleanField(read_only=True)
tasks = SystemTaskSerializer(many=True, read_only=True)
class KerberosSourceViewSet(UsedByMixin, ModelViewSet):
"""Kerberos Source Viewset"""
queryset = KerberosSource.objects.all()
serializer_class = KerberosSourceSerializer
lookup_field = "slug"
filterset_fields = [
"name",
"slug",
"enabled",
"realm",
"sync_users",
"sync_users_password",
"sync_principal",
"spnego_server_name",
"password_login_update_internal_password",
]
search_fields = [
"name",
"slug",
"realm",
"krb5_conf",
"sync_principal",
"spnego_server_name",
]
ordering = ["name"]
@extend_schema(
responses={
200: KerberosSyncStatusSerializer(),
}
)
@action(
methods=["GET"],
detail=True,
pagination_class=None,
url_path="sync/status",
filter_backends=[],
)
def sync_status(self, request: Request, slug: str) -> Response:
"""Get source's sync status"""
source: KerberosSource = self.get_object()
tasks = list(
get_objects_for_user(request.user, "authentik_events.view_systemtask").filter(
name="kerberos_sync",
uid__startswith=source.slug,
)
)
with source.sync_lock as lock_acquired:
status = {
"tasks": tasks,
"is_running": not lock_acquired,
}
return Response(KerberosSyncStatusSerializer(status).data)

View File

@ -1,51 +0,0 @@
"""Kerberos Source Serializer"""
from django_filters.rest_framework import DjangoFilterBackend
from rest_framework.filters import OrderingFilter, SearchFilter
from rest_framework.viewsets import ModelViewSet
from authentik.api.authorization import OwnerFilter, OwnerSuperuserPermissions
from authentik.core.api.sources import (
GroupSourceConnectionSerializer,
GroupSourceConnectionViewSet,
UserSourceConnectionSerializer,
)
from authentik.core.api.used_by import UsedByMixin
from authentik.sources.kerberos.models import (
GroupKerberosSourceConnection,
UserKerberosSourceConnection,
)
class UserKerberosSourceConnectionSerializer(UserSourceConnectionSerializer):
"""Kerberos Source Serializer"""
class Meta:
model = UserKerberosSourceConnection
fields = UserSourceConnectionSerializer.Meta.fields + ["identifier"]
class UserKerberosSourceConnectionViewSet(UsedByMixin, ModelViewSet):
"""Source Viewset"""
queryset = UserKerberosSourceConnection.objects.all()
serializer_class = UserKerberosSourceConnectionSerializer
filterset_fields = ["source__slug"]
search_fields = ["source__slug"]
permission_classes = [OwnerSuperuserPermissions]
filter_backends = [OwnerFilter, DjangoFilterBackend, OrderingFilter, SearchFilter]
ordering = ["source__slug"]
class GroupKerberosSourceConnectionSerializer(GroupSourceConnectionSerializer):
"""OAuth Group-Source connection Serializer"""
class Meta(GroupSourceConnectionSerializer.Meta):
model = GroupKerberosSourceConnection
class GroupKerberosSourceConnectionViewSet(GroupSourceConnectionViewSet):
"""Group-source connection Viewset"""
queryset = GroupKerberosSourceConnection.objects.all()
serializer_class = GroupKerberosSourceConnectionSerializer

View File

@ -1,13 +0,0 @@
"""authentik kerberos source config"""
from authentik.blueprints.apps import ManagedAppConfig
class AuthentikSourceKerberosConfig(ManagedAppConfig):
"""Authentik source kerberos app config"""
name = "authentik.sources.kerberos"
label = "authentik_sources_kerberos"
verbose_name = "authentik Sources.Kerberos"
mountpoint = "source/kerberos/"
default = True

View File

@ -1,116 +0,0 @@
"""authentik Kerberos Authentication Backend"""
import gssapi
from django.http import HttpRequest
from structlog.stdlib import get_logger
from authentik.core.auth import InbuiltBackend
from authentik.core.models import User
from authentik.lib.generators import generate_id
from authentik.sources.kerberos.models import (
KerberosSource,
Krb5ConfContext,
UserKerberosSourceConnection,
)
LOGGER = get_logger()
class KerberosBackend(InbuiltBackend):
"""Authenticate users against Kerberos realm"""
def authenticate(self, request: HttpRequest, **kwargs):
"""Try to authenticate a user via kerberos"""
if "password" not in kwargs or "username" not in kwargs:
return None
username = kwargs.pop("username")
realm = None
if "@" in username:
username, realm = username.rsplit("@", 1)
user, source = self.auth_user(username, realm, **kwargs)
if user:
self.set_method("kerberos", request, source=source)
return user
return None
def auth_user(
self, username: str, realm: str | None, password: str, **filters
) -> tuple[User | None, KerberosSource | None]:
sources = KerberosSource.objects.filter(enabled=True)
user = User.objects.filter(usersourceconnection__source__in=sources, **filters).first()
if user is not None:
# User found, let's get its connections for the sources that are available
user_source_connections = UserKerberosSourceConnection.objects.filter(
user=user, source__in=sources
)
elif realm is not None:
user_source_connections = UserKerberosSourceConnection.objects.filter(
source__in=sources, identifier=f"{username}@{realm}"
)
# no realm specified, we can't do anything
else:
user_source_connections = UserKerberosSourceConnection.objects.none()
if not user_source_connections.exists():
LOGGER.debug("no kerberos source found for user", username=username)
return None, None
for user_source_connection in user_source_connections.prefetch_related().select_related(
"source__kerberossource"
):
# User either has an unusable password,
# or has a password, but couldn't be authenticated by ModelBackend
# This means we check with a kinit to see if the Kerberos password has changed
if self.auth_user_by_kinit(user_source_connection, password):
# Password was successful in kinit to Kerberos, so we save it in database
if (
user_source_connection.source.kerberossource.password_login_update_internal_password
):
LOGGER.debug(
"Updating user's password in DB",
source=user_source_connection.source,
user=user_source_connection.user,
)
user_source_connection.user.set_password(
password, sender=user_source_connection.source
)
user_source_connection.user.save()
return user, user_source_connection.source
# Password doesn't match, onto next source
LOGGER.debug(
"failed to kinit, password invalid",
source=user_source_connection.source,
user=user_source_connection.user,
)
# No source with valid password found
LOGGER.debug("no valid kerberos source found for user", user=user)
return None, None
def auth_user_by_kinit(
self, user_source_connection: UserKerberosSourceConnection, password: str
) -> bool:
"""Attempt authentication by kinit to the source."""
LOGGER.debug(
"Attempting to kinit as user",
user=user_source_connection.user,
source=user_source_connection.source,
principal=user_source_connection.identifier,
)
with Krb5ConfContext(user_source_connection.source.kerberossource):
name = gssapi.raw.import_name(
user_source_connection.identifier.encode(), gssapi.raw.NameType.kerberos_principal
)
try:
# Use a temporary credentials cache to not interfere with whatever is defined
# elsewhere
gssapi.raw.ext_krb5.krb5_ccache_name(f"MEMORY:{generate_id(12)}".encode())
gssapi.raw.ext_password.acquire_cred_with_password(name, password.encode())
# Restore the credentials cache to what it was before
gssapi.raw.ext_krb5.krb5_ccache_name(None)
return True
except gssapi.exceptions.GSSError as exc:
LOGGER.warning("failed to kinit", exc=exc)
return False

View File

@ -1,4 +0,0 @@
[libdefaults]
dns_canonicalize_hostname = false
dns_fallback = true
rdns = false

View File

@ -1,25 +0,0 @@
"""Kerberos Connection check"""
from json import dumps
from structlog.stdlib import get_logger
from authentik.sources.kerberos.models import KerberosSource
from authentik.tenants.management import TenantCommand
LOGGER = get_logger()
class Command(TenantCommand):
"""Check connectivity to Kerberos servers for a source"""
def add_arguments(self, parser):
parser.add_argument("source_slugs", nargs="?", type=str)
def handle_per_tenant(self, **options):
sources = KerberosSource.objects.filter(enabled=True)
if options["source_slugs"]:
sources = KerberosSource.objects.filter(slug__in=options["source_slugs"])
for source in sources.order_by("slug"):
status = source.check_connection()
self.stdout.write(dumps(status, indent=4))

View File

@ -1,25 +0,0 @@
"""Kerberos Sync"""
from structlog.stdlib import get_logger
from authentik.sources.kerberos.models import KerberosSource
from authentik.sources.kerberos.sync import KerberosSync
from authentik.tenants.management import TenantCommand
LOGGER = get_logger()
class Command(TenantCommand):
"""Run sync for an Kerberos Source"""
def add_arguments(self, parser):
parser.add_argument("source_slugs", nargs="+", type=str)
def handle_per_tenant(self, **options):
for source_slug in options["source_slugs"]:
source = KerberosSource.objects.filter(slug=source_slug).first()
if not source:
LOGGER.warning("Source does not exist", slug=source_slug)
continue
user_count = KerberosSync(source).sync()
LOGGER.info(f"Synced {user_count} users", slug=source_slug)

View File

@ -1,179 +0,0 @@
# Generated by Django 5.0.9 on 2024-09-23 11:27
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
("authentik_core", "0039_source_group_matching_mode_alter_group_name_and_more"),
]
operations = [
migrations.CreateModel(
name="GroupKerberosSourceConnection",
fields=[
(
"groupsourceconnection_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_core.groupsourceconnection",
),
),
],
options={
"verbose_name": "Group Kerberos Source Connection",
"verbose_name_plural": "Group Kerberos Source Connections",
},
bases=("authentik_core.groupsourceconnection",),
),
migrations.CreateModel(
name="KerberosSource",
fields=[
(
"source_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_core.source",
),
),
("realm", models.TextField(help_text="Kerberos realm", unique=True)),
(
"krb5_conf",
models.TextField(
blank=True,
help_text="Custom krb5.conf to use. Uses the system one by default",
),
),
(
"sync_users",
models.BooleanField(
db_index=True,
default=False,
help_text="Sync users from Kerberos into authentik",
),
),
(
"sync_users_password",
models.BooleanField(
db_index=True,
default=True,
help_text="When a user changes their password, sync it back to Kerberos",
),
),
(
"sync_principal",
models.TextField(
blank=True, help_text="Principal to authenticate to kadmin for sync."
),
),
(
"sync_password",
models.TextField(
blank=True, help_text="Password to authenticate to kadmin for sync"
),
),
(
"sync_keytab",
models.TextField(
blank=True,
help_text="Keytab to authenticate to kadmin for sync. Must be base64-encoded or in the form TYPE:residual",
),
),
(
"sync_ccache",
models.TextField(
blank=True,
help_text="Credentials cache to authenticate to kadmin for sync. Must be in the form TYPE:residual",
),
),
(
"spnego_server_name",
models.TextField(
blank=True,
help_text="Force the use of a specific server name for SPNEGO. Must be in the form HTTP@hostname",
),
),
(
"spnego_keytab",
models.TextField(
blank=True,
help_text="SPNEGO keytab base64-encoded or path to keytab in the form FILE:path",
),
),
(
"spnego_ccache",
models.TextField(
blank=True,
help_text="Credential cache to use for SPNEGO in form type:residual",
),
),
(
"password_login_update_internal_password",
models.BooleanField(
default=False,
help_text="If enabled, the authentik-stored password will be updated upon login with the Kerberos password backend",
),
),
],
options={
"verbose_name": "Kerberos Source",
"verbose_name_plural": "Kerberos Sources",
},
bases=("authentik_core.source",),
),
migrations.CreateModel(
name="KerberosSourcePropertyMapping",
fields=[
(
"propertymapping_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_core.propertymapping",
),
),
],
options={
"verbose_name": "Kerberos Source Property Mapping",
"verbose_name_plural": "Kerberos Source Property Mappings",
},
bases=("authentik_core.propertymapping",),
),
migrations.CreateModel(
name="UserKerberosSourceConnection",
fields=[
(
"usersourceconnection_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_core.usersourceconnection",
),
),
("identifier", models.TextField()),
],
options={
"verbose_name": "User Kerberos Source Connection",
"verbose_name_plural": "User Kerberos Source Connections",
},
bases=("authentik_core.usersourceconnection",),
),
]

View File

@ -1,376 +0,0 @@
"""authentik Kerberos Source Models"""
import os
from pathlib import Path
from tempfile import gettempdir
from typing import Any
import gssapi
import kadmin
import pglock
from django.db import connection, models
from base64 import b64decode
from django.http import HttpRequest
from django.shortcuts import reverse
from django.templatetags.static import static
from django.utils.translation import gettext_lazy as _
from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger
from authentik.core.models import (
GroupSourceConnection,
PropertyMapping,
Source,
UserSourceConnection,
UserTypes,
)
from authentik.core.types import UILoginButton, UserSettingSerializer
from authentik.flows.challenge import RedirectChallenge
LOGGER = get_logger()
# python-kadmin leaks file descriptors. As such, this global is used to reuse
# existing kadmin connections instead of creating new ones, which results in fewer or no
# file descriptor leaks
_kadmin_connections: dict[str, Any] = {}
class KerberosSource(Source):
"""Federate Kerberos realm with authentik"""
realm = models.TextField(help_text=_("Kerberos realm"), unique=True)
krb5_conf = models.TextField(
blank=True,
help_text=_("Custom krb5.conf to use. Uses the system one by default"),
)
sync_users = models.BooleanField(
default=False, help_text=_("Sync users from Kerberos into authentik"), db_index=True
)
sync_users_password = models.BooleanField(
default=True,
help_text=_("When a user changes their password, sync it back to Kerberos"),
db_index=True,
)
sync_principal = models.TextField(
help_text=_("Principal to authenticate to kadmin for sync."), blank=True
)
sync_password = models.TextField(
help_text=_("Password to authenticate to kadmin for sync"), blank=True
)
sync_keytab = models.TextField(
help_text=_(
"Keytab to authenticate to kadmin for sync. "
"Must be base64-encoded or in the form TYPE:residual"
),
blank=True,
)
sync_ccache = models.TextField(
help_text=_(
"Credentials cache to authenticate to kadmin for sync. "
"Must be in the form TYPE:residual"
),
blank=True,
)
spnego_server_name = models.TextField(
help_text=_(
"Force the use of a specific server name for SPNEGO. Must be in the form HTTP@hostname"
),
blank=True,
)
spnego_keytab = models.TextField(
help_text=_("SPNEGO keytab base64-encoded or path to keytab in the form FILE:path"),
blank=True,
)
spnego_ccache = models.TextField(
help_text=_("Credential cache to use for SPNEGO in form type:residual"),
blank=True,
)
password_login_update_internal_password = models.BooleanField(
default=False,
help_text=_(
"If enabled, the authentik-stored password will be updated upon "
"login with the Kerberos password backend"
),
)
class Meta:
verbose_name = _("Kerberos Source")
verbose_name_plural = _("Kerberos Sources")
def __str__(self):
return f"Kerberos Source {self.name}"
@property
def component(self) -> str:
return "ak-source-kerberos-form"
@property
def serializer(self) -> type[Serializer]:
from authentik.sources.kerberos.api.source import KerberosSourceSerializer
return KerberosSourceSerializer
@property
def property_mapping_type(self) -> type[PropertyMapping]:
return KerberosSourcePropertyMapping
@property
def icon_url(self) -> str:
icon = super().icon_url
if not icon:
return static("authentik/sources/kerberos.png")
return icon
def ui_login_button(self, request: HttpRequest) -> UILoginButton:
return UILoginButton(
challenge=RedirectChallenge(
data={
"to": reverse(
"authentik_sources_kerberos:spnego-login",
kwargs={"source_slug": self.slug},
),
}
),
name=self.name,
icon_url=self.icon_url,
)
def ui_user_settings(self) -> UserSettingSerializer | None:
return UserSettingSerializer(
data={
"title": self.name,
"component": "ak-user-settings-source-kerberos",
"configure_url": reverse(
"authentik_sources_kerberos:spnego-login",
kwargs={"source_slug": self.slug},
),
"icon_url": self.icon_url,
}
)
@property
def sync_lock(self) -> pglock.advisory:
"""Redis lock for syncing Kerberos to prevent multiple parallel syncs happening"""
return pglock.advisory(
lock_id=f"goauthentik.io/{connection.schema_name}/sources/kerberos/sync/{self.slug}",
timeout=0,
side_effect=pglock.Return,
)
def get_base_user_properties(self, principal: str, **kwargs):
localpart, _ = principal.rsplit("@", 1)
return {
"username": localpart,
"type": UserTypes.INTERNAL,
"path": self.get_user_path(),
}
def get_base_group_properties(self, group_id: str, **kwargs):
return {
"name": group_id,
}
@property
def tempdir(self) -> Path:
"""Get temporary storage for Kerberos files"""
path = (
Path(gettempdir())
/ "authentik"
/ connection.schema_name
/ "sources"
/ "kerberos"
/ str(self.pk)
)
path.mkdir(mode=0o700, parents=True, exist_ok=True)
return path
@property
def krb5_conf_path(self) -> str | None:
"""Get krb5.conf path"""
if not self.krb5_conf:
return None
conf_path = self.tempdir / "krb5.conf"
conf_path.write_text(self.krb5_conf)
return str(conf_path)
def _kadmin_init(self) -> "kadmin.KAdmin | None":
# kadmin doesn't use a ccache for its connection
# as such, we don't need to create a separate ccache for each source
if not self.sync_principal:
return None
if self.sync_password:
return kadmin.init_with_password(
self.sync_principal,
self.sync_password,
)
if self.sync_keytab:
keytab = self.sync_keytab
if ":" not in keytab:
keytab_path = self.tempdir / "kadmin_keytab"
keytab_path.touch(mode=0o600)
keytab_path.write_bytes(b64decode(self.sync_keytab))
keytab = f"FILE:{keytab_path}"
return kadmin.init_with_keytab(
self.sync_principal,
keytab,
)
if self.sync_ccache:
return kadmin.init_with_ccache(
self.sync_principal,
self.sync_ccache,
)
return None
def connection(self) -> "kadmin.KAdmin | None":
"""Get kadmin connection"""
if str(self.pk) not in _kadmin_connections:
kadm = self._kadmin_init()
if kadm is not None:
_kadmin_connections[str(self.pk)] = kadm
return _kadmin_connections.get(str(self.pk), None)
def check_connection(self) -> dict[str, str]:
"""Check Kerberos Connection"""
status = {"status": "ok"}
if not self.sync_users:
return status
with Krb5ConfContext(self):
try:
kadm = self.connection()
if kadm is None:
status["status"] = "no connection"
return status
status["principal_exists"] = kadm.principal_exists(self.sync_principal)
except kadmin.KAdminError as exc:
status["status"] = str(exc)
return status
def get_gssapi_store(self) -> dict[str, str]:
"""Get GSSAPI credentials store for this source"""
ccache = self.spnego_ccache
keytab = None
if not ccache:
ccache_path = self.tempdir / "spnego_ccache"
ccache_path.touch(mode=0o600)
ccache = f"FILE:{ccache_path}"
if self.spnego_keytab:
# Keytab is of the form type:residual, use as-is
if ":" in self.spnego_keytab:
keytab = self.spnego_keytab
# Parse the keytab and write it in the file
else:
keytab_path = self.tempdir / "spnego_keytab"
keytab_path.touch(mode=0o600)
keytab_path.write_bytes(b64decode(self.spnego_keytab))
keytab = f"FILE:{keytab_path}"
store = {"ccache": ccache}
if keytab is not None:
store["keytab"] = keytab
return store
def get_gssapi_creds(self) -> gssapi.creds.Credentials | None:
"""Get GSSAPI credentials for this source"""
try:
name = None
if self.spnego_server_name:
# pylint: disable=c-extension-no-member
name = gssapi.names.Name(
base=self.spnego_server_name,
name_type=gssapi.raw.types.NameType.hostbased_service,
)
return gssapi.creds.Credentials(
usage="accept", name=name, store=self.get_gssapi_store()
)
except gssapi.exceptions.GSSError as exc:
LOGGER.warn("GSSAPI credentials failure", exc=exc)
return None
class Krb5ConfContext:
"""
Context manager to set the path to the krb5.conf config file.
"""
def __init__(self, source: KerberosSource):
self._source = source
self._path = self._source.krb5_conf_path
self._previous = None
def __enter__(self):
if not self._path:
return
self._previous = os.environ.get("KRB5_CONFIG", None)
os.environ["KRB5_CONFIG"] = self._path
def __exit__(self, *args, **kwargs):
if not self._path:
return
if self._previous:
os.environ["KRB5_CONFIG"] = self._previous
else:
del os.environ["KRB5_CONFIG"]
class KerberosSourcePropertyMapping(PropertyMapping):
"""Map Kerberos Property to User object attribute"""
@property
def component(self) -> str:
return "ak-property-mapping-source-kerberos-form"
@property
def serializer(self) -> type[Serializer]:
from authentik.sources.kerberos.api.property_mappings import (
KerberosSourcePropertyMappingSerializer,
)
return KerberosSourcePropertyMappingSerializer
def __str__(self):
return str(self.name)
class Meta:
verbose_name = _("Kerberos Source Property Mapping")
verbose_name_plural = _("Kerberos Source Property Mappings")
class UserKerberosSourceConnection(UserSourceConnection):
"""Connection to configured Kerberos Sources."""
identifier = models.TextField()
@property
def serializer(self) -> type[Serializer]:
from authentik.sources.kerberos.api.source_connection import (
UserKerberosSourceConnectionSerializer,
)
return UserKerberosSourceConnectionSerializer
class Meta:
verbose_name = _("User Kerberos Source Connection")
verbose_name_plural = _("User Kerberos Source Connections")
class GroupKerberosSourceConnection(GroupSourceConnection):
"""Connection to configured Kerberos Sources."""
@property
def serializer(self) -> type[Serializer]:
from authentik.sources.kerberos.api.source_connection import (
GroupKerberosSourceConnectionSerializer,
)
return GroupKerberosSourceConnectionSerializer
class Meta:
verbose_name = _("Group Kerberos Source Connection")
verbose_name_plural = _("Group Kerberos Source Connections")

View File

@ -1,18 +0,0 @@
"""LDAP Settings"""
from celery.schedules import crontab
from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"sources_kerberos_sync": {
"task": "authentik.sources.kerberos.tasks.kerberos_sync_all",
"schedule": crontab(minute=fqdn_rand("sources_kerberos_sync"), hour="*/2"),
"options": {"queue": "authentik_scheduled"},
},
"sources_kerberos_connectivity_check": {
"task": "authentik.sources.kerberos.tasks.kerberos_connectivity_check",
"schedule": crontab(minute=fqdn_rand("sources_kerberos_connectivity_check"), hour="*"),
"options": {"queue": "authentik_scheduled"},
},
}

View File

@ -1,61 +0,0 @@
"""authentik kerberos source signals"""
import kadmin
from django.db.models.signals import post_save
from django.dispatch import receiver
from rest_framework.serializers import ValidationError
from structlog.stdlib import get_logger
from authentik.core.models import User
from authentik.core.signals import password_changed
from authentik.events.models import Event, EventAction
from authentik.sources.kerberos.models import (
KerberosSource,
Krb5ConfContext,
UserKerberosSourceConnection,
)
from authentik.sources.kerberos.tasks import kerberos_connectivity_check, kerberos_sync_single
LOGGER = get_logger()
@receiver(post_save, sender=KerberosSource)
def sync_kerberos_source_on_save(sender, instance: KerberosSource, **_):
"""Ensure that source is synced on save (if enabled)"""
if not instance.enabled or not instance.sync_users:
return
kerberos_sync_single.delay(instance.pk)
kerberos_connectivity_check.delay(instance.pk)
@receiver(password_changed)
def kerberos_sync_password(sender, user: User, password: str, **_):
"""Connect to kerberos and update password."""
user_source_connections = UserKerberosSourceConnection.objects.select_related(
"source__kerberossource"
).filter(
user=user,
source__enabled=True,
source__kerberossource__sync_users=True,
source__kerberossource__sync_users_password=True,
)
for user_source_connection in user_source_connections:
source = user_source_connection.source.kerberossource
if source.pk == getattr(sender, "pk", None):
continue
with Krb5ConfContext(source):
try:
source.connection().getprinc(user_source_connection.identifier).change_password(
password
)
except kadmin.KAdminError as exc:
LOGGER.warning("failed to set Kerberos password", exc=exc, source=source)
Event.new(
EventAction.CONFIGURATION_ERROR,
message=(
"Failed to change password in Kerberos source due to remote error: "
f"{exc}"
),
source=source,
).set_user(user).save()
raise ValidationError("Failed to set password") from exc

View File

@ -1,167 +0,0 @@
"""Sync Kerberos users into authentik"""
from typing import Any
import kadmin
from django.core.exceptions import FieldError
from django.db import IntegrityError, transaction
from structlog.stdlib import BoundLogger, get_logger
from authentik.core.expression.exceptions import (
PropertyMappingExpressionException,
SkipObjectException,
)
from authentik.core.models import Group, User, UserTypes
from authentik.core.sources.mapper import SourceMapper
from authentik.core.sources.matcher import Action, SourceMatcher
from authentik.events.models import Event, EventAction
from authentik.lib.sync.mapper import PropertyMappingManager
from authentik.lib.sync.outgoing.exceptions import StopSync
from authentik.sources.kerberos.models import (
GroupKerberosSourceConnection,
KerberosSource,
Krb5ConfContext,
UserKerberosSourceConnection,
)
class KerberosSync:
"""Sync Kerberos users into authentik"""
_source: KerberosSource
_logger: BoundLogger
_connection: "kadmin.KAdmin"
mapper: SourceMapper
user_manager: PropertyMappingManager
group_manager: PropertyMappingManager
matcher: SourceMatcher
def __init__(self, source: KerberosSource):
self._source = source
with Krb5ConfContext(self._source):
self._connection = self._source.connection()
self._messages = []
self._logger = get_logger().bind(source=self._source, syncer=self.__class__.__name__)
self.mapper = SourceMapper(self._source)
self.user_manager = self.mapper.get_manager(User, ["principal"])
self.group_manager = self.mapper.get_manager(Group, ["group_id", "principal"])
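# Property mapping expressions for this source receive "principal" (and, for groups,
# "group_id") as extra variables; the sync tests further down rely on this.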
self.matcher = SourceMatcher(
self._source, UserKerberosSourceConnection, GroupKerberosSourceConnection
)
@staticmethod
def name() -> str:
"""UI name for the type of object this class synchronizes"""
return "users"
@property
def messages(self) -> list[str]:
"""Get all UI messages"""
return self._messages
def message(self, *args, **kwargs):
"""Add message that is later added to the System Task and shown to the user"""
formatted_message = " ".join(args)
self._messages.append(formatted_message)
self._logger.warning(*args, **kwargs)
def _handle_principal(self, principal: str) -> bool:
try:
defaults = self.mapper.build_object_properties(
object_type=User,
manager=self.user_manager,
user=None,
request=None,
principal=principal,
)
self._logger.debug("Writing user with attributes", **defaults)
if "username" not in defaults:
raise IntegrityError("Username was not set by propertymappings")
action, connection = self.matcher.get_user_action(principal, defaults)
self._logger.debug("Action returned", action=action, connection=connection)
if action == Action.DENY:
return False
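# Build the property dict for each group the user mappings returned under "groups",
# so the groups can be created or linked after the user itself is handled.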
group_properties = {
group_id: self.mapper.build_object_properties(
object_type=Group,
manager=self.group_manager,
user=None,
request=None,
group_id=group_id,
principal=principal,
)
for group_id in defaults.pop("groups", [])
}
if action == Action.ENROLL:
user = User.objects.create(**defaults)
if user.type == UserTypes.INTERNAL_SERVICE_ACCOUNT:
user.set_unusable_password()
user.save()
connection.user = user
connection.save()
elif action in (Action.AUTH, Action.LINK):
user = connection.user
user.update_attributes(defaults)
else:
return False
groups: list[Group] = []
for group_id, properties in group_properties.items():
group = self._handle_group(group_id, properties)
if group:
groups.append(group)
with transaction.atomic():
user.ak_groups.remove(
*user.ak_groups.filter(groupsourceconnection__source=self._source)
)
user.ak_groups.add(*groups)
except PropertyMappingExpressionException as exc:
raise StopSync(exc, None, exc.mapping) from exc
except SkipObjectException:
return False
except (IntegrityError, FieldError, TypeError, AttributeError) as exc:
Event.new(
EventAction.CONFIGURATION_ERROR,
message=(f"Failed to create user: {str(exc)} "),
source=self._source,
principal=principal,
).save()
return False
self._logger.debug("Synced User", user=user.username)
return True
def _handle_group(
self, group_id: str, defaults: dict[str, Any | dict[str, Any]]
) -> Group | None:
action, connection = self.matcher.get_group_action(group_id, defaults)
if action == Action.DENY:
return None
if action == Action.ENROLL:
group = Group.objects.create(**defaults)
connection.group = group
connection.save()
return group
if action in (Action.AUTH, Action.LINK):
group = connection.group
group.update_attributes(defaults)
connection.save()
return group
return None
def sync(self) -> int:
"""Iterate over all Kerberos users and create authentik_core.User instances"""
if not self._source.enabled or not self._source.sync_users:
self.message("Source is disabled or user syncing is disabled for this Source")
return -1
user_count = 0
with Krb5ConfContext(self._source):
for principal in self._connection.principals():
if self._handle_principal(principal):
user_count += 1
return user_count

View File

@ -1,68 +0,0 @@
"""Kerberos Sync tasks"""
from django.core.cache import cache
from structlog.stdlib import get_logger
from authentik.events.models import SystemTask as DBSystemTask
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask
from authentik.lib.config import CONFIG
from authentik.lib.sync.outgoing.exceptions import StopSync
from authentik.lib.utils.errors import exception_to_string
from authentik.root.celery import CELERY_APP
from authentik.sources.kerberos.models import KerberosSource
from authentik.sources.kerberos.sync import KerberosSync
LOGGER = get_logger()
CACHE_KEY_STATUS = "goauthentik.io/sources/kerberos/status/"
@CELERY_APP.task()
def kerberos_sync_all():
"""Sync all sources"""
for source in KerberosSource.objects.filter(enabled=True, sync_users=True):
kerberos_sync_single.delay(str(source.pk))
@CELERY_APP.task()
def kerberos_connectivity_check(pk: str | None = None):
"""Check connectivity for Kerberos Sources"""
# 2 hour timeout, this task should run every hour
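# (the beat schedule triggers this check hourly, so a cached status survives one missed run)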
timeout = 60 * 60 * 2
sources = KerberosSource.objects.filter(enabled=True)
if pk:
sources = sources.filter(pk=pk)
for source in sources:
status = source.check_connection()
cache.set(CACHE_KEY_STATUS + source.slug, status, timeout=timeout)
@CELERY_APP.task(
bind=True,
base=SystemTask,
# We multiply the configured timeout hours by 2.5: user and group sync run in
# parallel, followed by membership, so 2x covers the serial tasks and the extra
# 0.5x gives some more leeway
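# For example, with sources.kerberos.task_timeout_hours set to 2, both limits
# evaluate to 60 * 60 * 2 * 2.5 = 18000 seconds (5 hours).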
soft_time_limit=(60 * 60 * CONFIG.get_int("sources.kerberos.task_timeout_hours")) * 2.5,
task_time_limit=(60 * 60 * CONFIG.get_int("sources.kerberos.task_timeout_hours")) * 2.5,
)
def kerberos_sync_single(self, source_pk: str):
"""Sync a single source"""
source: KerberosSource = KerberosSource.objects.filter(pk=source_pk).first()
if not source or not source.enabled:
return
try:
with source.sync_lock as lock_acquired:
if not lock_acquired:
LOGGER.debug(
"Failed to acquire lock for Kerberos sync, skipping task", source=source.slug
)
return
# Delete all sync tasks from the cache
DBSystemTask.objects.filter(name="kerberos_sync", uid__startswith=source.slug).delete()
syncer = KerberosSync(source)
syncer.sync()
self.set_status(TaskStatus.SUCCESSFUL, *syncer.messages)
except StopSync as exc:
LOGGER.warning(exception_to_string(exc))
self.set_error(exc)

View File

@ -1,57 +0,0 @@
"""Kerberos Source Auth tests"""
from django.contrib.auth.hashers import is_password_usable
from authentik.core.models import User
from authentik.lib.generators import generate_id
from authentik.sources.kerberos.auth import KerberosBackend
from authentik.sources.kerberos.models import KerberosSource, UserKerberosSourceConnection
from authentik.sources.kerberos.tests.utils import KerberosTestCase
class TestKerberosAuth(KerberosTestCase):
"""Kerberos Auth tests"""
def setUp(self):
self.source = KerberosSource.objects.create(
name="kerberos",
slug="kerberos",
realm=self.realm.realm,
sync_users=False,
sync_users_password=False,
password_login_update_internal_password=True,
)
self.user = User.objects.create(username=generate_id())
self.user.set_unusable_password()
UserKerberosSourceConnection.objects.create(
source=self.source, user=self.user, identifier=self.realm.user_princ
)
def test_auth_username(self):
"""Test auth username"""
backend = KerberosBackend()
self.assertEqual(
backend.authenticate(
None, username=self.user.username, password=self.realm.password("user")
),
self.user,
)
def test_auth_principal(self):
"""Test auth principal"""
backend = KerberosBackend()
self.assertEqual(
backend.authenticate(
None, username=self.realm.user_princ, password=self.realm.password("user")
),
self.user,
)
def test_internal_password_update(self):
"""Test internal password update"""
backend = KerberosBackend()
backend.authenticate(
None, username=self.realm.user_princ, password=self.realm.password("user")
)
self.user.refresh_from_db()
self.assertTrue(is_password_usable(self.user.password))

View File

@ -1,78 +0,0 @@
"""Kerberos Source SPNEGO tests"""
from base64 import b64decode, b64encode
from pathlib import Path
import gssapi
from django.urls import reverse
from authentik.core.tests.utils import create_test_admin_user
from authentik.sources.kerberos.models import KerberosSource
from authentik.sources.kerberos.tests.utils import KerberosTestCase
class TestSPNEGOSource(KerberosTestCase):
"""Kerberos Source SPNEGO tests"""
def setUp(self):
self.source = KerberosSource.objects.create(
name="test",
slug="test",
spnego_keytab=b64encode(Path(self.realm.http_keytab).read_bytes()).decode(),
)
# Force store creation early
self.source.get_gssapi_store()
def test_api_read(self):
"""Test reading a source"""
self.client.force_login(create_test_admin_user())
response = self.client.get(
reverse(
"authentik_api:kerberossource-detail",
kwargs={
"slug": self.source.slug,
},
)
)
self.assertEqual(response.status_code, 200)
def test_source_login(self):
"""test login view"""
response = self.client.get(
reverse(
"authentik_sources_kerberos:spnego-login",
kwargs={"source_slug": self.source.slug},
)
)
self.assertEqual(response.status_code, 302)
endpoint = response.headers["Location"]
response = self.client.get(endpoint)
self.assertEqual(response.status_code, 401)
self.assertEqual(response.headers["WWW-Authenticate"], "Negotiate")
server_name = gssapi.names.Name("HTTP/testserver@")
client_creds = gssapi.creds.Credentials(
usage="initiate", store={"ccache": self.realm.ccache}
)
client_ctx = gssapi.sec_contexts.SecurityContext(
name=server_name, usage="initiate", creds=client_creds
)
status = 401
server_token = None
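# Drive the client side of the Negotiate handshake: keep exchanging tokens with the
# endpoint until the GSSAPI context completes or the server stops answering 401.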
while status == 401 and not client_ctx.complete: # noqa: PLR2004
client_token = client_ctx.step(server_token)
if not client_token:
break
response = self.client.get(
endpoint,
headers={"Authorization": f"Negotiate {b64encode(client_token).decode('ascii')}"},
)
status = response.status_code
if status == 401: # noqa: PLR2004
server_token = b64decode(response.headers["WWW-Authenticate"][9:].strip())
# 400 because no enroll flow
self.assertEqual(status, 400)

View File

@ -1,75 +0,0 @@
"""Kerberos Source sync tests"""
from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import User
from authentik.lib.generators import generate_id
from authentik.sources.kerberos.models import KerberosSource, KerberosSourcePropertyMapping
from authentik.sources.kerberos.sync import KerberosSync
from authentik.sources.kerberos.tasks import kerberos_sync_all
from authentik.sources.kerberos.tests.utils import KerberosTestCase
class TestKerberosSync(KerberosTestCase):
"""Kerberos Sync tests"""
@apply_blueprint("system/sources-kerberos.yaml")
def setUp(self):
self.source: KerberosSource = KerberosSource.objects.create(
name="kerberos",
slug="kerberos",
realm=self.realm.realm,
sync_users=True,
sync_users_password=True,
sync_principal=self.realm.admin_princ,
sync_password=self.realm.password("admin"),
)
self.source.user_property_mappings.set(
KerberosSourcePropertyMapping.objects.filter(
managed__startswith="goauthentik.io/sources/kerberos/user/default/"
)
)
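# The managed default mappings filtered on above are created by the
# system/sources-kerberos.yaml blueprint applied to this test case.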
def test_default_mappings(self):
"""Test default mappings"""
KerberosSync(self.source).sync()
self.assertTrue(
User.objects.filter(username=self.realm.user_princ.rsplit("@", 1)[0]).exists()
)
self.assertFalse(
User.objects.filter(username=self.realm.nfs_princ.rsplit("@", 1)[0]).exists()
)
def test_sync_mapping(self):
"""Test property mappings"""
noop = KerberosSourcePropertyMapping.objects.create(
name=generate_id(), expression="return {}"
)
email = KerberosSourcePropertyMapping.objects.create(
name=generate_id(), expression='return {"email": principal.lower()}'
)
dont_sync_service = KerberosSourcePropertyMapping.objects.create(
name=generate_id(),
expression='if "/" in principal:\n return {"username": None}\nreturn {}',
)
self.source.user_property_mappings.set([noop, email, dont_sync_service])
KerberosSync(self.source).sync()
self.assertTrue(
User.objects.filter(username=self.realm.user_princ.rsplit("@", 1)[0]).exists()
)
self.assertEqual(
User.objects.get(username=self.realm.user_princ.rsplit("@", 1)[0]).email,
self.realm.user_princ.lower(),
)
self.assertFalse(
User.objects.filter(username=self.realm.nfs_princ.rsplit("@", 1)[0]).exists()
)
def test_tasks(self):
"""Test Scheduled tasks"""
kerberos_sync_all.delay().get()
self.assertTrue(
User.objects.filter(username=self.realm.user_princ.rsplit("@", 1)[0]).exists()
)

View File

@ -1,40 +0,0 @@
"""Kerberos Source test utils"""
import os
from copy import deepcopy
from time import sleep
from k5test import realm
from rest_framework.test import APITestCase
class KerberosTestCase(APITestCase):
"""Kerberos Test Case"""
@classmethod
def setUpClass(cls):
cls.realm = realm.K5Realm(start_kadmind=True)
cls.realm.http_princ = f"HTTP/testserver@{cls.realm.realm}"
cls.realm.http_keytab = os.path.join(cls.realm.tmpdir, "http_keytab")
cls.realm.addprinc(cls.realm.http_princ)
cls.realm.extract_keytab(cls.realm.http_princ, cls.realm.http_keytab)
cls._saved_env = deepcopy(os.environ)
for k, v in cls.realm.env.items():
os.environ[k] = v
# Wait for everything to start correctly
# Otherwise leads to flaky tests
sleep(5)
@classmethod
def tearDownClass(cls):
cls.realm.stop()
del cls.realm
for k in deepcopy(os.environ):
if k in cls._saved_env:
os.environ[k] = cls._saved_env[k]
else:
del os.environ[k]
cls._saved_env = None

View File

@ -1,22 +0,0 @@
"""Kerberos Source urls"""
from django.urls import path
from authentik.sources.kerberos.api.property_mappings import KerberosSourcePropertyMappingViewSet
from authentik.sources.kerberos.api.source import KerberosSourceViewSet
from authentik.sources.kerberos.api.source_connection import (
GroupKerberosSourceConnectionViewSet,
UserKerberosSourceConnectionViewSet,
)
from authentik.sources.kerberos.views import SPNEGOView
urlpatterns = [
path("<slug:source_slug>/", SPNEGOView.as_view(), name="spnego-login"),
]
api_urlpatterns = [
("propertymappings/source/kerberos", KerberosSourcePropertyMappingViewSet),
("sources/user_connections/kerberos", UserKerberosSourceConnectionViewSet),
("sources/group_connections/kerberos", GroupKerberosSourceConnectionViewSet),
("sources/kerberos", KerberosSourceViewSet),
]

View File

@ -1,181 +0,0 @@
"""Kerberos source SPNEGO views"""
from base64 import b64decode, b64encode
import gssapi
from django.core.cache import cache
from django.core.exceptions import SuspiciousOperation
from django.http import HttpResponse
from django.shortcuts import get_object_or_404, redirect, render, reverse
from django.utils.crypto import get_random_string
from django.utils.translation import gettext_lazy as _
from django.views import View
from structlog.stdlib import get_logger
from authentik.core.sources.flow_manager import SourceFlowManager
from authentik.sources.kerberos.models import (
GroupKerberosSourceConnection,
KerberosSource,
Krb5ConfContext,
UserKerberosSourceConnection,
)
LOGGER = get_logger()
SPNEGO_REQUEST_STATUS = 401
WWW_AUTHENTICATE = "WWW-Authenticate"
HTTP_AUTHORIZATION = "Authorization"
NEGOTIATE = "Negotiate"
SPNEGO_STATE_CACHE_PREFIX = "goauthentik.io/sources/spnego"
SPNEGO_STATE_CACHE_TIMEOUT = 60 * 5 # 5 minutes
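# The per-request state key lets the exported GSSAPI security context be restored across
# the multiple round trips of a Negotiate exchange (see get_server_ctx/set_server_ctx).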
def add_negotiate_to_response(
response: HttpResponse, token: str | bytes | None = None
) -> HttpResponse:
if isinstance(token, str):
token = token.encode()
response[WWW_AUTHENTICATE] = (
NEGOTIATE if token is None else f"{NEGOTIATE} {b64encode(token).decode('ascii')}"
)
return response
class SPNEGOView(View):
"""SPNEGO login"""
source: KerberosSource
def challenge(self, request, token: str | bytes | None = None) -> HttpResponse:
"""Get SNPEGO challenge response"""
response = render(
request,
"if/error.html",
context={
"title": _("SPNEGO authentication required"),
"message": _(
"""
Make sure you have valid tickets (obtainable via kinit)
and that your browser is configured correctly.
Please contact your administrator.
"""
),
},
status=401,
)
return add_negotiate_to_response(response, token)
def get_authstr(self, request) -> str | None:
"""Get SPNEGO authentication string from headers"""
authorization_header = request.headers.get(HTTP_AUTHORIZATION, "")
if NEGOTIATE.lower() not in authorization_header.lower():
return None
auth_tuple = authorization_header.split(" ", 1)
if not auth_tuple or auth_tuple[0].lower() != NEGOTIATE.lower():
return None
if len(auth_tuple) != 2: # noqa: PLR2004
raise SuspiciousOperation("Malformed authorization header")
return auth_tuple[1]
def new_state(self) -> str:
"""Generate request state"""
return get_random_string(32)
def get_server_ctx(self, key: str) -> gssapi.sec_contexts.SecurityContext | None:
"""Get GSSAPI server context from cache or create it"""
server_creds = self.source.get_gssapi_creds()
if server_creds is None:
return None
state = cache.get(f"{SPNEGO_STATE_CACHE_PREFIX}/{key}", None)
if state:
# pylint: disable=c-extension-no-member
return gssapi.sec_contexts.SecurityContext(
base=gssapi.raw.sec_contexts.import_sec_context(state),
)
return gssapi.sec_contexts.SecurityContext(creds=server_creds, usage="accept")
def set_server_ctx(self, key: str, ctx: gssapi.sec_contexts.SecurityContext):
"""Store the GSSAPI server context in cache"""
cache.set(f"{SPNEGO_STATE_CACHE_PREFIX}/{key}", ctx.export(), SPNEGO_STATE_CACHE_TIMEOUT)
# pylint: disable=too-many-return-statements
def dispatch(self, request, *args, **kwargs) -> HttpResponse:
"""Process SPNEGO request"""
self.source: KerberosSource = get_object_or_404(
KerberosSource,
slug=kwargs.get("source_slug", ""),
enabled=True,
)
qstring = request.GET if request.method == "GET" else request.POST
state = qstring.get("state", None)
if not state:
return redirect(
reverse(
"authentik_sources_kerberos:spnego-login",
kwargs={"source_slug": self.source.slug},
)
+ f"?state={self.new_state()}"
)
authstr = self.get_authstr(request)
if not authstr:
LOGGER.debug("authstr not present, sending challenge")
return self.challenge(request)
try:
in_token = b64decode(authstr)
except (TypeError, ValueError):
return self.challenge(request)
with Krb5ConfContext(self.source):
server_ctx = self.get_server_ctx(state)
if not server_ctx:
return self.challenge(request)
try:
out_token = server_ctx.step(in_token)
except gssapi.exceptions.GSSError as exc:
LOGGER.debug("GSSAPI security context failure", exc=exc)
return self.challenge(request)
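# Context not yet established: store it keyed by the request state and challenge the
# client again with our token so it can send the next leg of the handshake.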
if not server_ctx.complete or server_ctx.initiator_name is None:
self.set_server_ctx(state, server_ctx)
return self.challenge(request, out_token)
def name_to_str(n: gssapi.names.Name) -> str:
return n.display_as(n.name_type)
identifier = name_to_str(server_ctx.initiator_name)
context = {
"spnego_info": {
"initiator_name": name_to_str(server_ctx.initiator_name),
"target_name": name_to_str(server_ctx.target_name),
"mech": str(server_ctx.mech),
"actual_flags": server_ctx.actual_flags,
},
}
response = SPNEGOSourceFlowManager(
source=self.source,
request=request,
identifier=identifier,
user_info={
"principal": identifier,
**context,
},
policy_context=context,
).get_flow()
return add_negotiate_to_response(response, out_token)
class SPNEGOSourceFlowManager(SourceFlowManager):
"""Flow manager for Kerberos SPNEGO sources"""
user_connection_type = UserKerberosSourceConnection
group_connection_type = GroupKerberosSourceConnection

View File

@ -43,7 +43,7 @@ class LDAPBackend(InbuiltBackend):
if source.password_login_update_internal_password:
# Password given successfully binds to LDAP, so we save it in our Database
LOGGER.debug("Updating user's password in DB", user=user)
user.set_password(password, sender=source)
user.set_password(password, signal=False)
user.save()
return user
# Password doesn't match

View File

@ -62,8 +62,6 @@ def ldap_sync_password(sender, user: User, password: str, **_):
if not sources.exists():
return
source = sources.first()
if source.pk == getattr(sender, "pk", None):
return
if not LDAPPasswordChanger.should_check_user(user):
return
try:

View File

@ -1,6 +1,6 @@
# Generated by Django 5.0.9 on 2024-10-10 15:45
from django.db import migrations, models
from django.db import migrations
from django.apps.registry import Apps
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
@ -23,22 +23,4 @@ class Migration(migrations.Migration):
operations = [
migrations.RunPython(fix_X509SubjectName),
migrations.AlterField(
model_name="samlsource",
name="name_id_policy",
field=models.TextField(
choices=[
("urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress", "Email"),
("urn:oasis:names:tc:SAML:2.0:nameid-format:persistent", "Persistent"),
("urn:oasis:names:tc:SAML:1.1:nameid-format:X509SubjectName", "X509"),
(
"urn:oasis:names:tc:SAML:2.0:nameid-format:WindowsDomainQualifiedName",
"Windows",
),
("urn:oasis:names:tc:SAML:2.0:nameid-format:transient", "Transient"),
],
default="urn:oasis:names:tc:SAML:2.0:nameid-format:persistent",
help_text="NameID Policy sent to the IdP. Can be unset, in which case no Policy is sent.",
),
),
]

View File

@ -9,6 +9,7 @@ from rest_framework import mixins
from rest_framework.decorators import action
from rest_framework.fields import CharField, ChoiceField, IntegerField
from rest_framework.filters import OrderingFilter, SearchFilter
from rest_framework.permissions import IsAdminUser
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import GenericViewSet, ModelViewSet
@ -196,6 +197,7 @@ class DuoDeviceViewSet(
class DuoAdminDeviceViewSet(ModelViewSet):
"""Viewset for Duo authenticator devices (for admins)"""
permission_classes = [IsAdminUser]
queryset = DuoDevice.objects.all()
serializer_class = DuoDeviceSerializer
search_fields = ["name"]

View File

@ -3,6 +3,7 @@
from django_filters.rest_framework.backends import DjangoFilterBackend
from rest_framework import mixins
from rest_framework.filters import OrderingFilter, SearchFilter
from rest_framework.permissions import IsAdminUser
from rest_framework.viewsets import GenericViewSet, ModelViewSet
from authentik.api.authorization import OwnerFilter, OwnerPermissions
@ -75,6 +76,7 @@ class SMSDeviceViewSet(
class SMSAdminDeviceViewSet(ModelViewSet):
"""Viewset for sms authenticator devices (for admins)"""
permission_classes = [IsAdminUser]
queryset = SMSDevice.objects.all()
serializer_class = SMSDeviceSerializer
search_fields = ["name"]

View File

@ -3,6 +3,7 @@
from django_filters.rest_framework import DjangoFilterBackend
from rest_framework import mixins
from rest_framework.filters import OrderingFilter, SearchFilter
from rest_framework.permissions import IsAdminUser
from rest_framework.viewsets import GenericViewSet, ModelViewSet
from authentik.api.authorization import OwnerFilter, OwnerPermissions
@ -79,6 +80,7 @@ class StaticDeviceViewSet(
class StaticAdminDeviceViewSet(ModelViewSet):
"""Viewset for static authenticator devices (for admins)"""
permission_classes = [IsAdminUser]
queryset = StaticDevice.objects.all()
serializer_class = StaticDeviceSerializer
search_fields = ["name"]

View File

@ -4,6 +4,7 @@ from django_filters.rest_framework.backends import DjangoFilterBackend
from rest_framework import mixins
from rest_framework.fields import ChoiceField
from rest_framework.filters import OrderingFilter, SearchFilter
from rest_framework.permissions import IsAdminUser
from rest_framework.viewsets import GenericViewSet, ModelViewSet
from authentik.api.authorization import OwnerFilter, OwnerPermissions
@ -71,6 +72,7 @@ class TOTPDeviceViewSet(
class TOTPAdminDeviceViewSet(ModelViewSet):
"""Viewset for totp authenticator devices (for admins)"""
permission_classes = [IsAdminUser]
queryset = TOTPDevice.objects.all()
serializer_class = TOTPDeviceSerializer
search_fields = ["name"]

View File

@ -8,7 +8,7 @@ from django.http.response import Http404
from django.shortcuts import get_object_or_404
from django.utils.translation import gettext as __
from django.utils.translation import gettext_lazy as _
from rest_framework.fields import CharField, DateTimeField
from rest_framework.fields import CharField
from rest_framework.serializers import ValidationError
from structlog.stdlib import get_logger
from webauthn import options_to_json
@ -45,7 +45,6 @@ class DeviceChallenge(PassiveSerializer):
device_class = CharField()
device_uid = CharField()
challenge = JSONDictField()
last_used = DateTimeField(allow_null=True)
def get_challenge_for_device(

View File

@ -217,7 +217,6 @@ class AuthenticatorValidateStageView(ChallengeStageView):
"device_class": device_class,
"device_uid": device.pk,
"challenge": get_challenge_for_device(self.request, stage, device),
"last_used": device.last_used,
}
)
challenge.is_valid()
@ -238,7 +237,6 @@ class AuthenticatorValidateStageView(ChallengeStageView):
self.request,
self.executor.current_stage,
),
"last_used": None,
}
)
challenge.is_valid()

View File

@ -107,7 +107,6 @@ class AuthenticatorValidateStageSMSTests(FlowTestCase):
"device_class": "sms",
"device_uid": str(device.pk),
"challenge": {},
"last_used": None,
},
},
)

View File

@ -169,7 +169,6 @@ class AuthenticatorValidateStageTests(FlowTestCase):
"device_class": "baz",
"device_uid": "quox",
"challenge": {},
"last_used": None,
}
},
)
@ -189,7 +188,6 @@ class AuthenticatorValidateStageTests(FlowTestCase):
"device_class": "static",
"device_uid": "1",
"challenge": {},
"last_used": None,
},
},
)

Some files were not shown because too many files have changed in this diff