sources: add property mappings for all oauth and saml sources (#8771)

Co-authored-by: Jens L. <jens@goauthentik.io>
This commit is contained in:
Marc 'risson' Schmitt
2024-08-07 19:14:22 +02:00
committed by GitHub
parent 78bae556d0
commit 83b02a17d5
64 changed files with 3631 additions and 314 deletions

View File

@ -33,6 +33,7 @@ from authentik.blueprints.v1.common import (
from authentik.blueprints.v1.meta.registry import BaseMetaModel, registry
from authentik.core.models import (
AuthenticatedSession,
GroupSourceConnection,
PropertyMapping,
Provider,
Source,
@ -91,6 +92,7 @@ def excluded_models() -> list[type[Model]]:
Source,
PropertyMapping,
UserSourceConnection,
GroupSourceConnection,
Stage,
OutpostServiceConnection,
Policy,

View File

@ -19,7 +19,7 @@ from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT
from authentik.core.api.object_types import TypesMixin
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import MetaNameSerializer, ModelSerializer
from authentik.core.models import Source, UserSourceConnection
from authentik.core.models import GroupSourceConnection, Source, UserSourceConnection
from authentik.core.types import UserSettingSerializer
from authentik.lib.utils.file import (
FilePathSerializer,
@ -194,3 +194,43 @@ class UserSourceConnectionViewSet(
search_fields = ["source__slug"]
filter_backends = [OwnerFilter, DjangoFilterBackend, OrderingFilter, SearchFilter]
ordering = ["source__slug", "pk"]
class GroupSourceConnectionSerializer(SourceSerializer):
"""Group Source Connection Serializer"""
source = SourceSerializer(read_only=True)
class Meta:
model = GroupSourceConnection
fields = [
"pk",
"group",
"source",
"identifier",
"created",
]
extra_kwargs = {
"group": {"read_only": True},
"identifier": {"read_only": True},
"created": {"read_only": True},
}
class GroupSourceConnectionViewSet(
mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
mixins.DestroyModelMixin,
UsedByMixin,
mixins.ListModelMixin,
GenericViewSet,
):
"""Group-source connection Viewset"""
queryset = GroupSourceConnection.objects.all()
serializer_class = GroupSourceConnectionSerializer
permission_classes = [OwnerSuperuserPermissions]
filterset_fields = ["group", "source__slug"]
search_fields = ["source__slug"]
filter_backends = [OwnerFilter, DjangoFilterBackend, OrderingFilter, SearchFilter]
ordering = ["source__slug", "pk"]

View File

@ -0,0 +1,67 @@
# Generated by Django 5.0.7 on 2024-08-01 18:52
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_core", "0038_source_authentik_c_enabled_d72365_idx"),
]
operations = [
migrations.AddField(
model_name="source",
name="group_matching_mode",
field=models.TextField(
choices=[
("identifier", "Use the source-specific identifier"),
(
"name_link",
"Link to a group with identical name. Can have security implications when a group name is used with another source.",
),
(
"name_deny",
"Use the group name, but deny enrollment when the name already exists.",
),
],
default="identifier",
help_text="How the source determines if an existing group should be used or a new group created.",
),
),
migrations.AlterField(
model_name="group",
name="name",
field=models.TextField(verbose_name="name"),
),
migrations.CreateModel(
name="GroupSourceConnection",
fields=[
(
"id",
models.AutoField(
auto_created=True, primary_key=True, serialize=False, verbose_name="ID"
),
),
("created", models.DateTimeField(auto_now_add=True)),
("last_updated", models.DateTimeField(auto_now=True)),
("identifier", models.TextField()),
(
"group",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="authentik_core.group"
),
),
(
"source",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="authentik_core.source"
),
),
],
options={
"unique_together": {("group", "source")},
},
),
]

View File

@ -173,7 +173,7 @@ class Group(SerializerModel, AttributesMixin):
group_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
name = models.CharField(_("name"), max_length=80)
name = models.TextField(_("name"))
is_superuser = models.BooleanField(
default=False, help_text=_("Users added to this group will be superusers.")
)
@ -583,6 +583,19 @@ class SourceUserMatchingModes(models.TextChoices):
)
class SourceGroupMatchingModes(models.TextChoices):
"""Different modes a source can handle new/returning groups"""
IDENTIFIER = "identifier", _("Use the source-specific identifier")
NAME_LINK = "name_link", _(
"Link to a group with identical name. Can have security implications "
"when a group name is used with another source."
)
NAME_DENY = "name_deny", _(
"Use the group name, but deny enrollment when the name already exists."
)
class Source(ManagedModel, SerializerModel, PolicyBindingModel):
"""Base Authentication source, i.e. an OAuth Provider, SAML Remote or LDAP Server"""
@ -632,6 +645,14 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel):
"a new user enrolled."
),
)
group_matching_mode = models.TextField(
choices=SourceGroupMatchingModes.choices,
default=SourceGroupMatchingModes.IDENTIFIER,
help_text=_(
"How the source determines if an existing group should be used or "
"a new group created."
),
)
objects = InheritanceManager()
@ -727,6 +748,27 @@ class UserSourceConnection(SerializerModel, CreatedUpdatedModel):
unique_together = (("user", "source"),)
class GroupSourceConnection(SerializerModel, CreatedUpdatedModel):
"""Connection between Group and Source."""
group = models.ForeignKey(Group, on_delete=models.CASCADE)
source = models.ForeignKey(Source, on_delete=models.CASCADE)
identifier = models.TextField()
objects = InheritanceManager()
@property
def serializer(self) -> type[Serializer]:
"""Get serializer for this model"""
raise NotImplementedError
def __str__(self) -> str:
return f"Group-source connection (group={self.group_id}, source={self.source_id})"
class Meta:
unique_together = (("group", "source"),)
class ExpiringModel(models.Model):
"""Base Model which can expire, and is automatically cleaned up."""

View File

@ -4,7 +4,7 @@ from enum import Enum
from typing import Any
from django.contrib import messages
from django.db import IntegrityError
from django.db import IntegrityError, transaction
from django.db.models.query_utils import Q
from django.http import HttpRequest, HttpResponse
from django.shortcuts import redirect
@ -12,8 +12,20 @@ from django.urls import reverse
from django.utils.translation import gettext as _
from structlog.stdlib import get_logger
from authentik.core.models import Source, SourceUserMatchingModes, User, UserSourceConnection
from authentik.core.sources.stage import PLAN_CONTEXT_SOURCES_CONNECTION, PostSourceStage
from authentik.core.models import (
Group,
GroupSourceConnection,
Source,
SourceGroupMatchingModes,
SourceUserMatchingModes,
User,
UserSourceConnection,
)
from authentik.core.sources.mapper import SourceMapper
from authentik.core.sources.stage import (
PLAN_CONTEXT_SOURCES_CONNECTION,
PostSourceStage,
)
from authentik.events.models import Event, EventAction
from authentik.flows.exceptions import FlowNonApplicableException
from authentik.flows.models import Flow, FlowToken, Stage, in_memory_stage
@ -36,7 +48,10 @@ from authentik.stages.password.stage import PLAN_CONTEXT_AUTHENTICATION_BACKEND
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
from authentik.stages.user_write.stage import PLAN_CONTEXT_USER_PATH
LOGGER = get_logger()
SESSION_KEY_OVERRIDE_FLOW_TOKEN = "authentik/flows/source_override_flow_token" # nosec
PLAN_CONTEXT_SOURCE_GROUPS = "source_groups"
class Action(Enum):
@ -70,48 +85,69 @@ class SourceFlowManager:
or deny the request."""
source: Source
mapper: SourceMapper
request: HttpRequest
identifier: str
connection_type: type[UserSourceConnection] = UserSourceConnection
user_connection_type: type[UserSourceConnection] = UserSourceConnection
group_connection_type: type[GroupSourceConnection] = GroupSourceConnection
enroll_info: dict[str, Any]
user_info: dict[str, Any]
policy_context: dict[str, Any]
user_properties: dict[str, Any | dict[str, Any]]
groups_properties: dict[str, dict[str, Any | dict[str, Any]]]
def __init__(
self,
source: Source,
request: HttpRequest,
identifier: str,
enroll_info: dict[str, Any],
user_info: dict[str, Any],
policy_context: dict[str, Any],
) -> None:
self.source = source
self.mapper = SourceMapper(self.source)
self.request = request
self.identifier = identifier
self.enroll_info = enroll_info
self.user_info = user_info
self._logger = get_logger().bind(source=source, identifier=identifier)
self.policy_context = {}
self.policy_context = policy_context
self.user_properties = self.mapper.build_object_properties(
object_type=User, request=request, user=None, **self.user_info
)
self.groups_properties = {
group_id: self.mapper.build_object_properties(
object_type=Group,
request=request,
user=None,
group_id=group_id,
**self.user_info,
)
for group_id in self.user_properties.setdefault("groups", [])
}
del self.user_properties["groups"]
def get_action(self, **kwargs) -> tuple[Action, UserSourceConnection | None]: # noqa: PLR0911
"""decide which action should be taken"""
new_connection = self.connection_type(source=self.source, identifier=self.identifier)
new_connection = self.user_connection_type(source=self.source, identifier=self.identifier)
# When request is authenticated, always link
if self.request.user.is_authenticated:
new_connection.user = self.request.user
new_connection = self.update_connection(new_connection, **kwargs)
new_connection = self.update_user_connection(new_connection, **kwargs)
return Action.LINK, new_connection
existing_connections = self.connection_type.objects.filter(
existing_connections = self.user_connection_type.objects.filter(
source=self.source, identifier=self.identifier
)
if existing_connections.exists():
connection = existing_connections.first()
return Action.AUTH, self.update_connection(connection, **kwargs)
return Action.AUTH, self.update_user_connection(connection, **kwargs)
# No connection exists, but we match on identifier, so enroll
if self.source.user_matching_mode == SourceUserMatchingModes.IDENTIFIER:
# We don't save the connection here cause it doesn't have a user assigned yet
return Action.ENROLL, self.update_connection(new_connection, **kwargs)
return Action.ENROLL, self.update_user_connection(new_connection, **kwargs)
# Check for existing users with matching attributes
query = Q()
@ -120,24 +156,24 @@ class SourceFlowManager:
SourceUserMatchingModes.EMAIL_LINK,
SourceUserMatchingModes.EMAIL_DENY,
]:
if not self.enroll_info.get("email", None):
self._logger.warning("Refusing to use none email", source=self.source)
if not self.user_properties.get("email", None):
self._logger.warning("Refusing to use none email")
return Action.DENY, None
query = Q(email__exact=self.enroll_info.get("email", None))
query = Q(email__exact=self.user_properties.get("email", None))
if self.source.user_matching_mode in [
SourceUserMatchingModes.USERNAME_LINK,
SourceUserMatchingModes.USERNAME_DENY,
]:
if not self.enroll_info.get("username", None):
self._logger.warning("Refusing to use none username", source=self.source)
if not self.user_properties.get("username", None):
self._logger.warning("Refusing to use none username")
return Action.DENY, None
query = Q(username__exact=self.enroll_info.get("username", None))
query = Q(username__exact=self.user_properties.get("username", None))
self._logger.debug("trying to link with existing user", query=query)
matching_users = User.objects.filter(query)
# No matching users, always enroll
if not matching_users.exists():
self._logger.debug("no matching users found, enrolling")
return Action.ENROLL, self.update_connection(new_connection, **kwargs)
return Action.ENROLL, self.update_user_connection(new_connection, **kwargs)
user = matching_users.first()
if self.source.user_matching_mode in [
@ -145,7 +181,7 @@ class SourceFlowManager:
SourceUserMatchingModes.USERNAME_LINK,
]:
new_connection.user = user
new_connection = self.update_connection(new_connection, **kwargs)
new_connection = self.update_user_connection(new_connection, **kwargs)
return Action.LINK, new_connection
if self.source.user_matching_mode in [
SourceUserMatchingModes.EMAIL_DENY,
@ -156,10 +192,10 @@ class SourceFlowManager:
# Should never get here as default enroll case is returned above.
return Action.DENY, None # pragma: no cover
def update_connection(
def update_user_connection(
self, connection: UserSourceConnection, **kwargs
) -> UserSourceConnection: # pragma: no cover
"""Optionally make changes to the connection after it is looked up/created."""
"""Optionally make changes to the user connection after it is looked up/created."""
return connection
def get_flow(self, **kwargs) -> HttpResponse:
@ -215,25 +251,31 @@ class SourceFlowManager:
flow: Flow | None,
connection: UserSourceConnection,
stages: list[StageView] | None = None,
**kwargs,
**flow_context,
) -> HttpResponse:
"""Prepare Authentication Plan, redirect user FlowExecutor"""
kwargs.update(
# Ensure redirect is carried through when user was trying to
# authorize application
final_redirect = self.request.session.get(SESSION_KEY_GET, {}).get(
NEXT_ARG_NAME, "authentik_core:if-user"
)
flow_context.update(
{
# Since we authenticate the user by their token, they have no backend set
PLAN_CONTEXT_AUTHENTICATION_BACKEND: BACKEND_INBUILT,
PLAN_CONTEXT_SSO: True,
PLAN_CONTEXT_SOURCE: self.source,
PLAN_CONTEXT_SOURCES_CONNECTION: connection,
PLAN_CONTEXT_SOURCE_GROUPS: self.groups_properties,
}
)
kwargs.update(self.policy_context)
flow_context.update(self.policy_context)
if SESSION_KEY_OVERRIDE_FLOW_TOKEN in self.request.session:
token: FlowToken = self.request.session.get(SESSION_KEY_OVERRIDE_FLOW_TOKEN)
self._logger.info("Replacing source flow with overridden flow", flow=token.flow.slug)
plan = token.plan
plan.context[PLAN_CONTEXT_IS_RESTORED] = token
plan.context.update(kwargs)
plan.context.update(flow_context)
for stage in self.get_stages_to_append(flow):
plan.append_stage(stage)
if stages:
@ -252,8 +294,8 @@ class SourceFlowManager:
final_redirect = self.request.session.get(SESSION_KEY_GET, {}).get(
NEXT_ARG_NAME, "authentik_core:if-user"
)
if PLAN_CONTEXT_REDIRECT not in kwargs:
kwargs[PLAN_CONTEXT_REDIRECT] = final_redirect
if PLAN_CONTEXT_REDIRECT not in flow_context:
flow_context[PLAN_CONTEXT_REDIRECT] = final_redirect
if not flow:
return bad_request_message(
@ -265,9 +307,12 @@ class SourceFlowManager:
# We append some stages so the initial flow we get might be empty
planner.allow_empty_flows = True
planner.use_cache = False
plan = planner.plan(self.request, kwargs)
plan = planner.plan(self.request, flow_context)
for stage in self.get_stages_to_append(flow):
plan.append_stage(stage)
plan.append_stage(
in_memory_stage(GroupUpdateStage, group_connection_type=self.group_connection_type)
)
if stages:
for stage in stages:
plan.append_stage(stage)
@ -354,7 +399,123 @@ class SourceFlowManager:
)
],
**{
PLAN_CONTEXT_PROMPT: delete_none_values(self.enroll_info),
PLAN_CONTEXT_PROMPT: delete_none_values(self.user_properties),
PLAN_CONTEXT_USER_PATH: self.source.get_user_path(),
},
)
class GroupUpdateStage(StageView):
"""Dynamically injected stage which updates the user after enrollment/authentication."""
def get_action(
self, group_id: str, group_properties: dict[str, Any | dict[str, Any]]
) -> tuple[Action, GroupSourceConnection | None]:
"""decide which action should be taken"""
new_connection = self.group_connection_type(source=self.source, identifier=group_id)
existing_connections = self.group_connection_type.objects.filter(
source=self.source, identifier=group_id
)
if existing_connections.exists():
return Action.LINK, existing_connections.first()
# No connection exists, but we match on identifier, so enroll
if self.source.group_matching_mode == SourceGroupMatchingModes.IDENTIFIER:
# We don't save the connection here cause it doesn't have a user assigned yet
return Action.ENROLL, new_connection
# Check for existing groups with matching attributes
query = Q()
if self.source.group_matching_mode in [
SourceGroupMatchingModes.NAME_LINK,
SourceGroupMatchingModes.NAME_DENY,
]:
if not group_properties.get("name", None):
LOGGER.warning(
"Refusing to use none group name", source=self.source, group_id=group_id
)
return Action.DENY, None
query = Q(name__exact=group_properties.get("name"))
LOGGER.debug(
"trying to link with existing group", source=self.source, query=query, group_id=group_id
)
matching_groups = Group.objects.filter(query)
# No matching groups, always enroll
if not matching_groups.exists():
LOGGER.debug(
"no matching groups found, enrolling", source=self.source, group_id=group_id
)
return Action.ENROLL, new_connection
group = matching_groups.first()
if self.source.group_matching_mode in [
SourceGroupMatchingModes.NAME_LINK,
]:
new_connection.group = group
return Action.LINK, new_connection
if self.source.group_matching_mode in [
SourceGroupMatchingModes.NAME_DENY,
]:
LOGGER.info(
"denying source because group exists",
source=self.source,
group=group,
group_id=group_id,
)
return Action.DENY, None
# Should never get here as default enroll case is returned above.
return Action.DENY, None # pragma: no cover
def handle_group(
self, group_id: str, group_properties: dict[str, Any | dict[str, Any]]
) -> Group | None:
action, connection = self.get_action(group_id, group_properties)
if action == Action.ENROLL:
group = Group.objects.create(**group_properties)
connection.group = group
connection.save()
return group
elif action == Action.LINK:
group = connection.group
group.update_attributes(group_properties)
connection.save()
return group
return None
def handle_groups(self) -> bool:
self.source: Source = self.executor.plan.context[PLAN_CONTEXT_SOURCE]
self.user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER]
self.group_connection_type: GroupSourceConnection = (
self.executor.current_stage.group_connection_type
)
raw_groups: dict[str, dict[str, Any | dict[str, Any]]] = self.executor.plan.context[
PLAN_CONTEXT_SOURCE_GROUPS
]
groups: list[Group] = []
for group_id, group_properties in raw_groups.items():
group = self.handle_group(group_id, group_properties)
if not group:
return False
groups.append(group)
with transaction.atomic():
self.user.ak_groups.remove(
*self.user.ak_groups.filter(groupsourceconnection__source=self.source)
)
self.user.ak_groups.add(*groups)
return True
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
"""Stage used after the user has been enrolled to sync their groups from source data"""
if self.handle_groups():
return self.executor.stage_ok()
else:
return self.executor.stage_invalid("Failed to update groups. Please try again later.")
def post(self, request: HttpRequest) -> HttpResponse:
"""Wrapper for post requests"""
return self.get(request)

View File

@ -38,7 +38,9 @@ class TestSourceFlowManager(TestCase):
def test_unauthenticated_enroll(self):
"""Test un-authenticated user enrolling"""
request = get_request("/", user=AnonymousUser())
flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {})
flow_manager = OAuthSourceFlowManager(
self.source, request, self.identifier, {"info": {}}, {}
)
action, _ = flow_manager.get_action()
self.assertEqual(action, Action.ENROLL)
response = flow_manager.get_flow()
@ -52,7 +54,9 @@ class TestSourceFlowManager(TestCase):
user=get_anonymous_user(), source=self.source, identifier=self.identifier
)
request = get_request("/", user=AnonymousUser())
flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {})
flow_manager = OAuthSourceFlowManager(
self.source, request, self.identifier, {"info": {}}, {}
)
action, _ = flow_manager.get_action()
self.assertEqual(action, Action.AUTH)
response = flow_manager.get_flow()
@ -64,7 +68,9 @@ class TestSourceFlowManager(TestCase):
"""Test authenticated user linking"""
user = User.objects.create(username="foo", email="foo@bar.baz")
request = get_request("/", user=user)
flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {})
flow_manager = OAuthSourceFlowManager(
self.source, request, self.identifier, {"info": {}}, {}
)
action, connection = flow_manager.get_action()
self.assertEqual(action, Action.LINK)
self.assertIsNone(connection.pk)
@ -77,7 +83,9 @@ class TestSourceFlowManager(TestCase):
def test_unauthenticated_link(self):
"""Test un-authenticated user linking"""
flow_manager = OAuthSourceFlowManager(self.source, get_request("/"), self.identifier, {})
flow_manager = OAuthSourceFlowManager(
self.source, get_request("/"), self.identifier, {"info": {}}, {}
)
action, connection = flow_manager.get_action()
self.assertEqual(action, Action.LINK)
self.assertIsNone(connection.pk)
@ -90,7 +98,7 @@ class TestSourceFlowManager(TestCase):
# Without email, deny
flow_manager = OAuthSourceFlowManager(
self.source, get_request("/", user=AnonymousUser()), self.identifier, {}
self.source, get_request("/", user=AnonymousUser()), self.identifier, {"info": {}}, {}
)
action, _ = flow_manager.get_action()
self.assertEqual(action, Action.DENY)
@ -100,7 +108,12 @@ class TestSourceFlowManager(TestCase):
self.source,
get_request("/", user=AnonymousUser()),
self.identifier,
{"email": "foo@bar.baz"},
{
"info": {
"email": "foo@bar.baz",
},
},
{},
)
action, _ = flow_manager.get_action()
self.assertEqual(action, Action.LINK)
@ -113,7 +126,7 @@ class TestSourceFlowManager(TestCase):
# Without username, deny
flow_manager = OAuthSourceFlowManager(
self.source, get_request("/", user=AnonymousUser()), self.identifier, {}
self.source, get_request("/", user=AnonymousUser()), self.identifier, {"info": {}}, {}
)
action, _ = flow_manager.get_action()
self.assertEqual(action, Action.DENY)
@ -123,7 +136,10 @@ class TestSourceFlowManager(TestCase):
self.source,
get_request("/", user=AnonymousUser()),
self.identifier,
{"username": "foo"},
{
"info": {"username": "foo"},
},
{},
)
action, _ = flow_manager.get_action()
self.assertEqual(action, Action.LINK)
@ -140,8 +156,11 @@ class TestSourceFlowManager(TestCase):
get_request("/", user=AnonymousUser()),
self.identifier,
{
"username": "bar",
"info": {
"username": "bar",
},
},
{},
)
action, _ = flow_manager.get_action()
self.assertEqual(action, Action.ENROLL)
@ -151,7 +170,10 @@ class TestSourceFlowManager(TestCase):
self.source,
get_request("/", user=AnonymousUser()),
self.identifier,
{"username": "foo"},
{
"info": {"username": "foo"},
},
{},
)
action, _ = flow_manager.get_action()
self.assertEqual(action, Action.DENY)
@ -165,7 +187,10 @@ class TestSourceFlowManager(TestCase):
self.source,
get_request("/", user=AnonymousUser()),
self.identifier,
{"username": "foo"},
{
"info": {"username": "foo"},
},
{},
)
action, _ = flow_manager.get_action()
self.assertEqual(action, Action.ENROLL)
@ -191,7 +216,10 @@ class TestSourceFlowManager(TestCase):
self.source,
get_request("/", user=AnonymousUser()),
self.identifier,
{"username": "foo"},
{
"info": {"username": "foo"},
},
{},
)
action, _ = flow_manager.get_action()
self.assertEqual(action, Action.ENROLL)

View File

@ -0,0 +1,237 @@
"""Test Source flow_manager group update stage"""
from django.test import RequestFactory
from authentik.core.models import Group, SourceGroupMatchingModes
from authentik.core.sources.flow_manager import PLAN_CONTEXT_SOURCE_GROUPS, GroupUpdateStage
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.flows.models import in_memory_stage
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, PLAN_CONTEXT_SOURCE, FlowPlan
from authentik.flows.tests import FlowTestCase
from authentik.flows.views.executor import FlowExecutorView
from authentik.lib.generators import generate_id
from authentik.sources.oauth.models import GroupOAuthSourceConnection, OAuthSource
class TestSourceFlowManager(FlowTestCase):
"""Test Source flow_manager group update stage"""
def setUp(self) -> None:
super().setUp()
self.factory = RequestFactory()
self.authentication_flow = create_test_flow()
self.enrollment_flow = create_test_flow()
self.source: OAuthSource = OAuthSource.objects.create(
name=generate_id(),
slug=generate_id(),
authentication_flow=self.authentication_flow,
enrollment_flow=self.enrollment_flow,
)
self.identifier = generate_id()
self.user = create_test_admin_user()
def test_nonexistant_group(self):
request = self.factory.get("/")
stage = GroupUpdateStage(
FlowExecutorView(
current_stage=in_memory_stage(
GroupUpdateStage, group_connection_type=GroupOAuthSourceConnection
),
plan=FlowPlan(
flow_pk=generate_id(),
context={
PLAN_CONTEXT_SOURCE: self.source,
PLAN_CONTEXT_PENDING_USER: self.user,
PLAN_CONTEXT_SOURCE_GROUPS: {
"group 1": {
"name": "group 1",
},
},
},
),
),
request=request,
)
self.assertTrue(stage.handle_groups())
self.assertTrue(Group.objects.filter(name="group 1").exists())
self.assertTrue(self.user.ak_groups.filter(name="group 1").exists())
self.assertTrue(
GroupOAuthSourceConnection.objects.filter(
group=Group.objects.get(name="group 1"), source=self.source
).exists()
)
def test_nonexistant_group_name_link(self):
self.source.group_matching_mode = SourceGroupMatchingModes.NAME_LINK
self.source.save()
request = self.factory.get("/")
stage = GroupUpdateStage(
FlowExecutorView(
current_stage=in_memory_stage(
GroupUpdateStage, group_connection_type=GroupOAuthSourceConnection
),
plan=FlowPlan(
flow_pk=generate_id(),
context={
PLAN_CONTEXT_SOURCE: self.source,
PLAN_CONTEXT_PENDING_USER: self.user,
PLAN_CONTEXT_SOURCE_GROUPS: {
"group 1": {
"name": "group 1",
},
},
},
),
),
request=request,
)
self.assertTrue(stage.handle_groups())
self.assertTrue(Group.objects.filter(name="group 1").exists())
self.assertTrue(self.user.ak_groups.filter(name="group 1").exists())
self.assertTrue(
GroupOAuthSourceConnection.objects.filter(
group=Group.objects.get(name="group 1"), source=self.source
).exists()
)
def test_existant_group_name_link(self):
self.source.group_matching_mode = SourceGroupMatchingModes.NAME_LINK
self.source.save()
group = Group.objects.create(name="group 1")
request = self.factory.get("/")
stage = GroupUpdateStage(
FlowExecutorView(
current_stage=in_memory_stage(
GroupUpdateStage, group_connection_type=GroupOAuthSourceConnection
),
plan=FlowPlan(
flow_pk=generate_id(),
context={
PLAN_CONTEXT_SOURCE: self.source,
PLAN_CONTEXT_PENDING_USER: self.user,
PLAN_CONTEXT_SOURCE_GROUPS: {
"group 1": {
"name": "group 1",
},
},
},
),
),
request=request,
)
self.assertTrue(stage.handle_groups())
self.assertTrue(Group.objects.filter(name="group 1").exists())
self.assertTrue(self.user.ak_groups.filter(name="group 1").exists())
self.assertTrue(
GroupOAuthSourceConnection.objects.filter(group=group, source=self.source).exists()
)
def test_nonexistant_group_name_deny(self):
self.source.group_matching_mode = SourceGroupMatchingModes.NAME_DENY
self.source.save()
request = self.factory.get("/")
stage = GroupUpdateStage(
FlowExecutorView(
current_stage=in_memory_stage(
GroupUpdateStage, group_connection_type=GroupOAuthSourceConnection
),
plan=FlowPlan(
flow_pk=generate_id(),
context={
PLAN_CONTEXT_SOURCE: self.source,
PLAN_CONTEXT_PENDING_USER: self.user,
PLAN_CONTEXT_SOURCE_GROUPS: {
"group 1": {
"name": "group 1",
},
},
},
),
),
request=request,
)
self.assertTrue(stage.handle_groups())
self.assertTrue(Group.objects.filter(name="group 1").exists())
self.assertTrue(self.user.ak_groups.filter(name="group 1").exists())
self.assertTrue(
GroupOAuthSourceConnection.objects.filter(
group=Group.objects.get(name="group 1"), source=self.source
).exists()
)
def test_existant_group_name_deny(self):
self.source.group_matching_mode = SourceGroupMatchingModes.NAME_DENY
self.source.save()
group = Group.objects.create(name="group 1")
request = self.factory.get("/")
stage = GroupUpdateStage(
FlowExecutorView(
current_stage=in_memory_stage(
GroupUpdateStage, group_connection_type=GroupOAuthSourceConnection
),
plan=FlowPlan(
flow_pk=generate_id(),
context={
PLAN_CONTEXT_SOURCE: self.source,
PLAN_CONTEXT_PENDING_USER: self.user,
PLAN_CONTEXT_SOURCE_GROUPS: {
"group 1": {
"name": "group 1",
},
},
},
),
),
request=request,
)
self.assertFalse(stage.handle_groups())
self.assertFalse(self.user.ak_groups.filter(name="group 1").exists())
self.assertFalse(
GroupOAuthSourceConnection.objects.filter(group=group, source=self.source).exists()
)
def test_group_updates(self):
self.source.group_matching_mode = SourceGroupMatchingModes.NAME_LINK
self.source.save()
other_group = Group.objects.create(name="other group")
old_group = Group.objects.create(name="old group")
new_group = Group.objects.create(name="new group")
self.user.ak_groups.set([other_group, old_group])
GroupOAuthSourceConnection.objects.create(
group=old_group, source=self.source, identifier=old_group.name
)
GroupOAuthSourceConnection.objects.create(
group=new_group, source=self.source, identifier=new_group.name
)
request = self.factory.get("/")
stage = GroupUpdateStage(
FlowExecutorView(
current_stage=in_memory_stage(
GroupUpdateStage, group_connection_type=GroupOAuthSourceConnection
),
plan=FlowPlan(
flow_pk=generate_id(),
context={
PLAN_CONTEXT_SOURCE: self.source,
PLAN_CONTEXT_PENDING_USER: self.user,
PLAN_CONTEXT_SOURCE_GROUPS: {
"new group": {
"name": "new group",
},
},
},
),
),
request=request,
)
self.assertTrue(stage.handle_groups())
self.assertFalse(self.user.ak_groups.filter(name="old group").exists())
self.assertTrue(self.user.ak_groups.filter(name="other group").exists())
self.assertTrue(self.user.ak_groups.filter(name="new group").exists())
self.assertEqual(self.user.ak_groups.count(), 2)

View File

@ -0,0 +1,31 @@
"""OAuth source property mappings API"""
from rest_framework.viewsets import ModelViewSet
from authentik.core.api.property_mappings import PropertyMappingFilterSet, PropertyMappingSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.sources.oauth.models import OAuthSourcePropertyMapping
class OAuthSourcePropertyMappingSerializer(PropertyMappingSerializer):
"""OAuthSourcePropertyMapping Serializer"""
class Meta(PropertyMappingSerializer.Meta):
model = OAuthSourcePropertyMapping
class OAuthSourcePropertyMappingFilter(PropertyMappingFilterSet):
"""Filter for OAuthSourcePropertyMapping"""
class Meta(PropertyMappingFilterSet.Meta):
model = OAuthSourcePropertyMapping
class OAuthSourcePropertyMappingViewSet(UsedByMixin, ModelViewSet):
"""OAuthSourcePropertyMapping Viewset"""
queryset = OAuthSourcePropertyMapping.objects.all()
serializer_class = OAuthSourcePropertyMappingSerializer
filterset_class = OAuthSourcePropertyMappingFilter
search_fields = ["name"]
ordering = ["name"]

View File

@ -116,6 +116,7 @@ class OAuthSourceSerializer(SourceSerializer):
class Meta:
model = OAuthSource
fields = SourceSerializer.Meta.fields + [
"group_matching_mode",
"provider_type",
"request_token_url",
"authorization_url",
@ -158,6 +159,7 @@ class OAuthSourceFilter(FilterSet):
"enrollment_flow",
"policy_engine_mode",
"user_matching_mode",
"group_matching_mode",
"provider_type",
"request_token_url",
"authorization_url",

View File

@ -3,10 +3,12 @@
from rest_framework.viewsets import ModelViewSet
from authentik.core.api.sources import (
GroupSourceConnectionSerializer,
GroupSourceConnectionViewSet,
UserSourceConnectionSerializer,
UserSourceConnectionViewSet,
)
from authentik.sources.oauth.models import UserOAuthSourceConnection
from authentik.sources.oauth.models import GroupOAuthSourceConnection, UserOAuthSourceConnection
class UserOAuthSourceConnectionSerializer(UserSourceConnectionSerializer):
@ -26,3 +28,17 @@ class UserOAuthSourceConnectionViewSet(UserSourceConnectionViewSet, ModelViewSet
queryset = UserOAuthSourceConnection.objects.all()
serializer_class = UserOAuthSourceConnectionSerializer
class GroupOAuthSourceConnectionSerializer(GroupSourceConnectionSerializer):
"""OAuth Group-Source connection Serializer"""
class Meta(GroupSourceConnectionSerializer.Meta):
model = GroupOAuthSourceConnection
class GroupOAuthSourceConnectionViewSet(GroupSourceConnectionViewSet, ModelViewSet):
"""Group-source connection Viewset"""
queryset = GroupOAuthSourceConnection.objects.all()
serializer_class = GroupOAuthSourceConnectionSerializer

View File

@ -0,0 +1,60 @@
# Generated by Django 5.0.7 on 2024-08-01 18:52
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_core", "0039_source_group_matching_mode_alter_group_name_and_more"),
(
"authentik_sources_oauth",
"0007_oauthsource_oidc_jwks_oauthsource_oidc_jwks_url_and_more",
),
]
operations = [
migrations.CreateModel(
name="GroupOAuthSourceConnection",
fields=[
(
"groupsourceconnection_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_core.groupsourceconnection",
),
),
],
options={
"verbose_name": "Group OAuth Source Connection",
"verbose_name_plural": "Group OAuth Source Connections",
},
bases=("authentik_core.groupsourceconnection",),
),
migrations.CreateModel(
name="OAuthSourcePropertyMapping",
fields=[
(
"propertymapping_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_core.propertymapping",
),
),
],
options={
"verbose_name": "OAuth Source Property Mapping",
"verbose_name_plural": "OAuth Source Property Mappings",
},
bases=("authentik_core.propertymapping",),
),
]

View File

@ -9,7 +9,12 @@ from django.utils.translation import gettext_lazy as _
from rest_framework.serializers import Serializer
from authentik.core.api.object_types import CreatableType, NonCreatableType
from authentik.core.models import Source, UserSourceConnection
from authentik.core.models import (
GroupSourceConnection,
PropertyMapping,
Source,
UserSourceConnection,
)
from authentik.core.types import UILoginButton, UserSettingSerializer
if TYPE_CHECKING:
@ -73,6 +78,16 @@ class OAuthSource(NonCreatableType, Source):
return OAuthSourceSerializer
@property
def property_mapping_type(self) -> type[PropertyMapping]:
return OAuthSourcePropertyMapping
def get_base_user_properties(self, **kwargs):
return self.source_type().get_base_user_properties(source=self, **kwargs)
def get_base_group_properties(self, **kwargs):
return self.source_type().get_base_group_properties(source=self, **kwargs)
@property
def icon_url(self) -> str | None:
# When listing source types, this property might be retrieved from an abstract
@ -248,6 +263,26 @@ class RedditOAuthSource(CreatableType, OAuthSource):
verbose_name_plural = _("Reddit OAuth Sources")
class OAuthSourcePropertyMapping(PropertyMapping):
"""Map OAuth properties to User or Group object attributes"""
@property
def component(self) -> str:
return "ak-property-mapping-oauth-source-form"
@property
def serializer(self) -> type[Serializer]:
from authentik.sources.oauth.api.property_mappings import (
OAuthSourcePropertyMappingSerializer,
)
return OAuthSourcePropertyMappingSerializer
class Meta:
verbose_name = _("OAuth Source Property Mapping")
verbose_name_plural = _("OAuth Source Property Mappings")
class UserOAuthSourceConnection(UserSourceConnection):
"""Authorized remote OAuth provider."""
@ -269,3 +304,19 @@ class UserOAuthSourceConnection(UserSourceConnection):
class Meta:
verbose_name = _("User OAuth Source Connection")
verbose_name_plural = _("User OAuth Source Connections")
class GroupOAuthSourceConnection(GroupSourceConnection):
"""Group-source connection"""
@property
def serializer(self) -> type[Serializer]:
from authentik.sources.oauth.api.source_connection import (
GroupOAuthSourceConnectionSerializer,
)
return GroupOAuthSourceConnectionSerializer
class Meta:
verbose_name = _("Group OAuth Source Connection")
verbose_name_plural = _("Group OAuth Source Connections")

View File

@ -0,0 +1,109 @@
"""Apple Type tests"""
from copy import deepcopy
from django.contrib.auth.models import AnonymousUser
from django.test import TestCase
from authentik.lib.generators import generate_id
from authentik.lib.tests.utils import get_request
from authentik.sources.oauth.models import OAuthSource, OAuthSourcePropertyMapping
from authentik.sources.oauth.views.callback import OAuthSourceFlowManager
INFO = {
"sub": "83692",
"name": "Alice Adams",
"email": "alice@example.com",
"department": "Engineering",
"birthdate": "1975-12-31",
"nickname": "foo",
}
IDENTIFIER = INFO["sub"]
class TestPropertyMappings(TestCase):
"""OAuth Source tests"""
def setUp(self):
self.source = OAuthSource.objects.create(
name="test",
slug="test",
provider_type="openidconnect",
authorization_url="",
profile_url="",
consumer_key=generate_id(),
)
def test_user_base_properties(self):
"""Test user base properties"""
properties = self.source.get_base_user_properties(info=INFO)
self.assertEqual(
properties,
{
"email": "alice@example.com",
"groups": [],
"name": "Alice Adams",
"username": "foo",
},
)
def test_group_base_properties(self):
"""Test group base properties"""
info = deepcopy(INFO)
info["groups"] = ["group 1", "group 2"]
properties = self.source.get_base_user_properties(info=info)
self.assertEqual(properties["groups"], ["group 1", "group 2"])
for group_id in info["groups"]:
properties = self.source.get_base_group_properties(info=info, group_id=group_id)
self.assertEqual(properties, {"name": group_id})
def test_user_property_mappings(self):
self.source.user_property_mappings.add(
OAuthSourcePropertyMapping.objects.create(
name="test",
expression="return {'attributes': {'department': info.get('department')}}",
)
)
request = get_request("/", user=AnonymousUser())
flow_manager = OAuthSourceFlowManager(self.source, request, IDENTIFIER, {"info": INFO}, {})
self.assertEqual(
flow_manager.user_properties,
{
"attributes": {
"department": "Engineering",
},
"email": "alice@example.com",
"name": "Alice Adams",
"username": "foo",
"path": self.source.get_user_path(),
},
)
def test_grup_property_mappings(self):
info = deepcopy(INFO)
info["groups"] = ["group 1", "group 2"]
self.source.group_property_mappings.add(
OAuthSourcePropertyMapping.objects.create(
name="test",
expression="return {'attributes': {'id': group_id}}",
)
)
request = get_request("/", user=AnonymousUser())
flow_manager = OAuthSourceFlowManager(self.source, request, IDENTIFIER, {"info": info}, {})
self.assertEqual(
flow_manager.groups_properties,
{
"group 1": {
"name": "group 1",
"attributes": {
"id": "group 1",
},
},
"group 2": {
"name": "group 2",
"attributes": {
"id": "group 2",
},
},
},
)

View File

@ -3,7 +3,7 @@
from django.test import TestCase
from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.azure_ad import AzureADOAuthCallback
from authentik.sources.oauth.types.azure_ad import AzureADOAuthCallback, AzureADType
# https://docs.microsoft.com/en-us/graph/api/user-get?view=graph-rest-1.0&tabs=http#response-2
AAD_USER = {
@ -41,7 +41,7 @@ class TestTypeAzureAD(TestCase):
def test_enroll_context(self):
"""Test azure_ad Enrollment context"""
ak_context = AzureADOAuthCallback().get_user_enroll_context(AAD_USER)
ak_context = AzureADType().get_base_user_properties(source=self.source, info=AAD_USER)
self.assertEqual(ak_context["username"], AAD_USER["userPrincipalName"])
self.assertEqual(ak_context["email"], AAD_USER["mail"])
self.assertEqual(ak_context["name"], AAD_USER["displayName"])

View File

@ -3,7 +3,7 @@
from django.test import TestCase
from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.discord import DiscordOAuth2Callback
from authentik.sources.oauth.types.discord import DiscordType
# https://discord.com/developers/docs/resources/user#user-object
DISCORD_USER = {
@ -34,7 +34,7 @@ class TestTypeDiscord(TestCase):
def test_enroll_context(self):
"""Test discord Enrollment context"""
ak_context = DiscordOAuth2Callback().get_user_enroll_context(DISCORD_USER)
ak_context = DiscordType().get_base_user_properties(source=self.source, info=DISCORD_USER)
self.assertEqual(ak_context["username"], DISCORD_USER["username"])
self.assertEqual(ak_context["email"], DISCORD_USER["email"])
self.assertEqual(ak_context["name"], DISCORD_USER["username"])

View File

@ -7,7 +7,10 @@ from requests_mock import Mocker
from authentik.lib.generators import generate_id
from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.github import GitHubOAuth2Callback
from authentik.sources.oauth.types.github import (
GitHubOAuth2Callback,
GitHubType,
)
# https://developer.github.com/v3/users/#get-the-authenticated-user
GITHUB_USER = {
@ -66,7 +69,9 @@ class TestTypeGitHub(TestCase):
def test_enroll_context(self):
"""Test GitHub Enrollment context"""
ak_context = GitHubOAuth2Callback().get_user_enroll_context(GITHUB_USER)
ak_context = GitHubType().get_base_user_properties(
source=self.source, info=GITHUB_USER, client=None, token={}
)
self.assertEqual(ak_context["username"], GITHUB_USER["login"])
self.assertEqual(ak_context["email"], GITHUB_USER["email"])
self.assertEqual(ak_context["name"], GITHUB_USER["name"])
@ -86,14 +91,18 @@ class TestTypeGitHub(TestCase):
}
],
)
ak_context = GitHubOAuth2Callback(
token = {
"access_token": generate_id(),
"token_type": generate_id(),
}
callback = GitHubOAuth2Callback(
source=self.source,
request=self.factory.get("/"),
token={
"access_token": generate_id(),
"token_type": generate_id(),
},
).get_user_enroll_context(user)
token=token,
)
ak_context = GitHubType().get_base_user_properties(
source=self.source, info=user, client=callback.get_client(self.source), token=token
)
self.assertEqual(ak_context["username"], GITHUB_USER["login"])
self.assertEqual(ak_context["email"], email)
self.assertEqual(ak_context["name"], GITHUB_USER["name"])

View File

@ -3,7 +3,7 @@
from django.test import TestCase
from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.gitlab import GitLabOAuthCallback
from authentik.sources.oauth.types.gitlab import GitLabType
GITLAB_USER = {
"preferred_username": "dev_gitlab",
@ -24,7 +24,7 @@ class TestTypeGitLab(TestCase):
def test_enroll_context(self):
"""Test GitLab Enrollment context"""
ak_context = GitLabOAuthCallback().get_user_enroll_context(GITLAB_USER)
ak_context = GitLabType().get_base_user_properties(source=self.source, info=GITLAB_USER)
self.assertEqual(ak_context["username"], GITLAB_USER["preferred_username"])
self.assertEqual(ak_context["email"], GITLAB_USER["email"])
self.assertEqual(ak_context["name"], GITLAB_USER["name"])

View File

@ -6,7 +6,10 @@ from django.test.client import RequestFactory
from authentik.lib.tests.utils import dummy_get_response
from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.google import GoogleOAuth2Callback, GoogleOAuthRedirect
from authentik.sources.oauth.types.google import (
GoogleOAuthRedirect,
GoogleType,
)
# https://developers.google.com/identity/protocols/oauth2/openid-connect?hl=en
GOOGLE_USER = {
@ -37,7 +40,7 @@ class TestTypeGoogle(TestCase):
def test_enroll_context(self):
"""Test Google Enrollment context"""
ak_context = GoogleOAuth2Callback().get_user_enroll_context(GOOGLE_USER)
ak_context = GoogleType().get_base_user_properties(source=self.source, info=GOOGLE_USER)
self.assertEqual(ak_context["email"], GOOGLE_USER["email"])
self.assertEqual(ak_context["name"], GOOGLE_USER["name"])

View File

@ -3,7 +3,7 @@
from django.test import TestCase
from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.mailcow import MailcowOAuth2Callback
from authentik.sources.oauth.types.mailcow import MailcowType
# https://community.mailcow.email/d/13-mailcow-oauth-json-format/2
MAILCOW_USER = {
@ -34,6 +34,6 @@ class TestTypeMailcow(TestCase):
def test_enroll_context(self):
"""Test mailcow Enrollment context"""
ak_context = MailcowOAuth2Callback().get_user_enroll_context(MAILCOW_USER)
ak_context = MailcowType().get_base_user_properties(source=self.source, info=MAILCOW_USER)
self.assertEqual(ak_context["email"], MAILCOW_USER["email"])
self.assertEqual(ak_context["name"], MAILCOW_USER["full_name"])

View File

@ -5,7 +5,7 @@ from requests_mock import Mocker
from authentik.lib.generators import generate_id
from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback, OpenIDConnectType
# https://connect2id.com/products/server/docs/api/userinfo
OPENID_USER = {
@ -34,7 +34,9 @@ class TestTypeOpenID(TestCase):
def test_enroll_context(self):
"""Test OpenID Enrollment context"""
ak_context = OpenIDConnectOAuth2Callback().get_user_enroll_context(OPENID_USER)
ak_context = OpenIDConnectType().get_base_user_properties(
source=self.source, info=OPENID_USER
)
self.assertEqual(ak_context["username"], OPENID_USER["nickname"])
self.assertEqual(ak_context["email"], OPENID_USER["email"])
self.assertEqual(ak_context["name"], OPENID_USER["name"])

View File

@ -1,9 +1,9 @@
"""Patreon Type tests"""
from django.test import RequestFactory, TestCase
from django.test import TestCase
from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.patreon import PatreonOAuthCallback
from authentik.sources.oauth.types.patreon import PatreonType
PATREON_USER = {
"data": {
@ -58,11 +58,10 @@ class TestTypePatreon(TestCase):
slug="test",
provider_type="Patreon",
)
self.factory = RequestFactory()
def test_enroll_context(self):
"""Test Patreon Enrollment context"""
ak_context = PatreonOAuthCallback().get_user_enroll_context(PATREON_USER)
ak_context = PatreonType().get_base_user_properties(source=self.source, info=PATREON_USER)
self.assertEqual(ak_context["username"], PATREON_USER["data"]["attributes"]["vanity"])
self.assertEqual(ak_context["email"], PATREON_USER["data"]["attributes"]["email"])
self.assertEqual(ak_context["name"], PATREON_USER["data"]["attributes"]["full_name"])

View File

@ -3,7 +3,7 @@
from django.test import TestCase
from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.twitch import TwitchOAuth2Callback
from authentik.sources.oauth.types.twitch import TwitchType
# https://dev.twitch.tv/docs/authentication/getting-tokens-oidc/#getting-claims-information-from-an-access-token
TWITCH_USER = {
@ -32,7 +32,7 @@ class TestTypeTwitch(TestCase):
def test_enroll_context(self):
"""Test twitch Enrollment context"""
ak_context = TwitchOAuth2Callback().get_user_enroll_context(TWITCH_USER)
ak_context = TwitchType().get_base_user_properties(source=self.source, info=TWITCH_USER)
self.assertEqual(ak_context["username"], TWITCH_USER["preferred_username"])
self.assertEqual(ak_context["email"], TWITCH_USER["email"])
self.assertEqual(ak_context["name"], TWITCH_USER["preferred_username"])

View File

@ -3,7 +3,7 @@
from django.test import TestCase
from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.twitter import TwitterOAuthCallback
from authentik.sources.oauth.types.twitter import TwitterType
# https://developer.twitter.com/en/docs/twitter-api/users/lookup/api-reference/get-users-me
TWITTER_USER = {"data": {"id": "2244994945", "name": "TwitterDev", "username": "Twitter Dev"}}
@ -24,7 +24,7 @@ class TestTypeGitHub(TestCase):
def test_enroll_context(self):
"""Test Twitter Enrollment context"""
ak_context = TwitterOAuthCallback().get_user_enroll_context(TWITTER_USER)
ak_context = TwitterType().get_base_user_properties(source=self.source, info=TWITTER_USER)
self.assertEqual(ak_context["username"], TWITTER_USER["data"]["username"])
self.assertEqual(ak_context["email"], None)
self.assertEqual(ak_context["name"], TWITTER_USER["data"]["name"])

View File

@ -90,15 +90,6 @@ class AppleOAuth2Callback(OAuthCallback):
def get_user_id(self, info: dict[str, Any]) -> str | None:
return info["sub"]
def get_user_enroll_context(
self,
info: dict[str, Any],
) -> dict[str, Any]:
return {
"email": info.get("email"),
"name": info.get("name"),
}
@registry.register()
class AppleType(SourceType):
@ -132,3 +123,9 @@ class AppleType(SourceType):
"state": args["state"],
}
)
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
return {
"email": info.get("email"),
"name": info.get("name"),
}

View File

@ -31,17 +31,6 @@ class AzureADOAuthCallback(OpenIDConnectOAuth2Callback):
# fallback to OpenID logic in case the profile URL was changed
return info.get("id", super().get_user_id(info))
def get_user_enroll_context(
self,
info: dict[str, Any],
) -> dict[str, Any]:
mail = info.get("mail", None) or info.get("otherMails", [None])[0]
return {
"username": info.get("userPrincipalName"),
"email": mail,
"name": info.get("displayName"),
}
@registry.register()
class AzureADType(SourceType):
@ -61,3 +50,11 @@ class AzureADType(SourceType):
"https://login.microsoftonline.com/common/.well-known/openid-configuration"
)
oidc_jwks_url = "https://login.microsoftonline.com/common/discovery/keys"
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
mail = info.get("mail", None) or info.get("otherMails", [None])[0]
return {
"username": info.get("userPrincipalName"),
"email": mail,
"name": info.get("displayName"),
}

View File

@ -20,16 +20,6 @@ class DiscordOAuthRedirect(OAuthRedirect):
class DiscordOAuth2Callback(OAuthCallback):
"""Discord OAuth2 Callback"""
def get_user_enroll_context(
self,
info: dict[str, Any],
) -> dict[str, Any]:
return {
"username": info.get("username"),
"email": info.get("email", None),
"name": info.get("username"),
}
@registry.register()
class DiscordType(SourceType):
@ -43,3 +33,10 @@ class DiscordType(SourceType):
authorization_url = "https://discord.com/api/oauth2/authorize"
access_token_url = "https://discord.com/api/oauth2/token" # nosec
profile_url = "https://discord.com/api/users/@me"
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
return {
"username": info.get("username"),
"email": info.get("email", None),
"name": info.get("username"),
}

View File

@ -19,16 +19,6 @@ class FacebookOAuthRedirect(OAuthRedirect):
class FacebookOAuth2Callback(OAuthCallback):
"""Facebook OAuth2 Callback"""
def get_user_enroll_context(
self,
info: dict[str, Any],
) -> dict[str, Any]:
return {
"username": info.get("name"),
"email": info.get("email"),
"name": info.get("name"),
}
@registry.register()
class FacebookType(SourceType):
@ -42,3 +32,10 @@ class FacebookType(SourceType):
authorization_url = "https://www.facebook.com/v7.0/dialog/oauth"
access_token_url = "https://graph.facebook.com/v7.0/oauth/access_token" # nosec
profile_url = "https://graph.facebook.com/v7.0/me?fields=id,name,email"
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
return {
"username": info.get("name"),
"email": info.get("email"),
"name": info.get("name"),
}

View File

@ -5,6 +5,7 @@ from typing import Any
from requests.exceptions import RequestException
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.registry import SourceType, registry
from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -42,26 +43,6 @@ class GitHubOAuth2Callback(OAuthCallback):
client_class = GitHubOAuth2Client
def get_user_enroll_context(
self,
info: dict[str, Any],
) -> dict[str, Any]:
chosen_email = info.get("email")
if not chosen_email:
# The GitHub Userprofile API only returns an email address if the profile
# has a public email address set (despite us asking for user:email, this behaviour
# doesn't change.). So we fetch all the user's email addresses
client: GitHubOAuth2Client = self.get_client(self.source)
emails = client.get_github_emails(self.token)
for email in emails:
if email.get("primary", False):
chosen_email = email.get("email", None)
return {
"username": info.get("login"),
"email": chosen_email,
"name": info.get("name"),
}
@registry.register()
class GitHubType(SourceType):
@ -81,3 +62,26 @@ class GitHubType(SourceType):
"https://token.actions.githubusercontent.com/.well-known/openid-configuration"
)
oidc_jwks_url = "https://token.actions.githubusercontent.com/.well-known/jwks"
def get_base_user_properties(
self,
source: OAuthSource,
client: GitHubOAuth2Client,
token: dict[str, str],
info: dict[str, Any],
**kwargs,
) -> dict[str, Any]:
chosen_email = info.get("email")
if not chosen_email:
# The GitHub Userprofile API only returns an email address if the profile
# has a public email address set (despite us asking for user:email, this behaviour
# doesn't change.). So we fetch all the user's email addresses
emails = client.get_github_emails(token)
for email in emails:
if email.get("primary", False):
chosen_email = email.get("email", None)
return {
"username": info.get("login"),
"email": chosen_email,
"name": info.get("name"),
}

View File

@ -25,16 +25,6 @@ class GitLabOAuthRedirect(OAuthRedirect):
class GitLabOAuthCallback(OAuthCallback):
"""GitLab OAuth2 Callback"""
def get_user_enroll_context(
self,
info: dict[str, Any],
) -> dict[str, Any]:
return {
"username": info.get("preferred_username"),
"email": info.get("email"),
"name": info.get("name"),
}
@registry.register()
class GitLabType(SourceType):
@ -52,3 +42,10 @@ class GitLabType(SourceType):
profile_url = "https://gitlab.com/oauth/userinfo"
oidc_well_known_url = "https://gitlab.com/.well-known/openid-configuration"
oidc_jwks_url = "https://gitlab.com/oauth/discovery/keys"
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
return {
"username": info.get("preferred_username"),
"email": info.get("email"),
"name": info.get("name"),
}

View File

@ -19,15 +19,6 @@ class GoogleOAuthRedirect(OAuthRedirect):
class GoogleOAuth2Callback(OAuthCallback):
"""Google OAuth2 Callback"""
def get_user_enroll_context(
self,
info: dict[str, Any],
) -> dict[str, Any]:
return {
"email": info.get("email"),
"name": info.get("name"),
}
@registry.register()
class GoogleType(SourceType):
@ -43,3 +34,9 @@ class GoogleType(SourceType):
profile_url = "https://www.googleapis.com/oauth2/v1/userinfo"
oidc_well_known_url = "https://accounts.google.com/.well-known/openid-configuration"
oidc_jwks_url = "https://www.googleapis.com/oauth2/v3/certs"
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
return {
"email": info.get("email"),
"name": info.get("name"),
}

View File

@ -47,16 +47,6 @@ class MailcowOAuth2Callback(OAuthCallback):
client_class = MailcowOAuth2Client
def get_user_enroll_context(
self,
info: dict[str, Any],
) -> dict[str, Any]:
return {
"username": info.get("full_name"),
"email": info.get("email"),
"name": info.get("full_name"),
}
@registry.register()
class MailcowType(SourceType):
@ -68,3 +58,10 @@ class MailcowType(SourceType):
name = "mailcow"
urls_customizable = True
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
return {
"username": info.get("full_name"),
"email": info.get("email"),
"name": info.get("full_name"),
}

View File

@ -26,16 +26,6 @@ class OpenIDConnectOAuth2Callback(OAuthCallback):
def get_user_id(self, info: dict[str, str]) -> str:
return info.get("sub", None)
def get_user_enroll_context(
self,
info: dict[str, Any],
) -> dict[str, Any]:
return {
"username": info.get("nickname", info.get("preferred_username")),
"email": info.get("email"),
"name": info.get("name"),
}
@registry.register()
class OpenIDConnectType(SourceType):
@ -47,3 +37,11 @@ class OpenIDConnectType(SourceType):
name = "openidconnect"
urls_customizable = True
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
return {
"username": info.get("nickname", info.get("preferred_username")),
"email": info.get("email"),
"name": info.get("name"),
"groups": info.get("groups", []),
}

View File

@ -26,16 +26,6 @@ class OktaOAuth2Callback(OpenIDConnectOAuth2Callback):
# see https://github.com/goauthentik/authentik/issues/1910
client_class = UserprofileHeaderAuthClient
def get_user_enroll_context(
self,
info: dict[str, Any],
) -> dict[str, Any]:
return {
"username": info.get("nickname"),
"email": info.get("email"),
"name": info.get("name"),
}
@registry.register()
class OktaType(SourceType):
@ -47,3 +37,11 @@ class OktaType(SourceType):
name = "okta"
urls_customizable = True
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
return {
"username": info.get("nickname"),
"email": info.get("email"),
"name": info.get("name"),
"groups": info.get("groups", []),
}

View File

@ -27,16 +27,6 @@ class PatreonOAuthCallback(OAuthCallback):
def get_user_id(self, info: dict[str, str]) -> str:
return info.get("data", {}).get("id")
def get_user_enroll_context(
self,
info: dict[str, Any],
) -> dict[str, Any]:
return {
"username": info.get("data", {}).get("attributes", {}).get("vanity"),
"email": info.get("data", {}).get("attributes", {}).get("email"),
"name": info.get("data", {}).get("attributes", {}).get("full_name"),
}
@registry.register()
class PatreonType(SourceType):
@ -50,3 +40,10 @@ class PatreonType(SourceType):
authorization_url = "https://www.patreon.com/oauth2/authorize"
access_token_url = "https://www.patreon.com/api/oauth2/token" # nosec
profile_url = "https://www.patreon.com/api/oauth2/api/current_user"
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
return {
"username": info.get("data", {}).get("attributes", {}).get("vanity"),
"email": info.get("data", {}).get("attributes", {}).get("email"),
"name": info.get("data", {}).get("attributes", {}).get("full_name"),
}

View File

@ -34,17 +34,6 @@ class RedditOAuth2Callback(OAuthCallback):
client_class = RedditOAuth2Client
def get_user_enroll_context(
self,
info: dict[str, Any],
) -> dict[str, Any]:
return {
"username": info.get("name"),
"email": None,
"name": info.get("name"),
"password": None,
}
@registry.register()
class RedditType(SourceType):
@ -58,3 +47,10 @@ class RedditType(SourceType):
authorization_url = "https://www.reddit.com/api/v1/authorize"
access_token_url = "https://www.reddit.com/api/v1/access_token" # nosec
profile_url = "https://oauth.reddit.com/api/v1/me"
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
return {
"username": info.get("name"),
"email": None,
"name": info.get("name"),
}

View File

@ -2,6 +2,7 @@
from collections.abc import Callable
from enum import Enum
from typing import Any
from django.http.request import HttpRequest
from django.templatetags.static import static
@ -55,6 +56,20 @@ class SourceType:
}
)
def get_base_user_properties(
self, source: OAuthSource, info: dict[str, Any], **kwargs
) -> dict[str, Any | dict[str, Any]]:
"""Get base user properties for enrollment/update"""
return info
def get_base_group_properties(
self, source: OAuthSource, group_id: str, **kwargs
) -> dict[str, Any | dict[str, Any]]:
"""Get base group properties for creation/update"""
return {
"name": group_id,
}
class SourceTypeRegistry:
"""Registry to hold all Source types."""

View File

@ -33,16 +33,6 @@ class TwitchOAuth2Callback(OpenIDConnectOAuth2Callback):
client_class = TwitchClient
def get_user_enroll_context(
self,
info: dict[str, Any],
) -> dict[str, Any]:
return {
"username": info.get("preferred_username"),
"email": info.get("email"),
"name": info.get("preferred_username"),
}
@registry.register()
class TwitchType(SourceType):
@ -56,3 +46,10 @@ class TwitchType(SourceType):
authorization_url = "https://id.twitch.tv/oauth2/authorize"
access_token_url = "https://id.twitch.tv/oauth2/token" # nosec
profile_url = "https://id.twitch.tv/oauth2/userinfo"
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
return {
"username": info.get("preferred_username"),
"email": info.get("email"),
"name": info.get("preferred_username"),
}

View File

@ -49,17 +49,6 @@ class TwitterOAuthCallback(OAuthCallback):
def get_user_id(self, info: dict[str, str]) -> str:
return info.get("data", {}).get("id", "")
def get_user_enroll_context(
self,
info: dict[str, Any],
) -> dict[str, Any]:
data = info.get("data", {})
return {
"username": data.get("username"),
"email": None,
"name": data.get("name"),
}
@registry.register()
class TwitterType(SourceType):
@ -73,3 +62,11 @@ class TwitterType(SourceType):
authorization_url = "https://twitter.com/i/oauth2/authorize"
access_token_url = "https://api.twitter.com/2/oauth2/token" # nosec
profile_url = "https://api.twitter.com/2/users/me"
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
data = info.get("data", {})
return {
"username": data.get("username"),
"email": None,
"name": data.get("name"),
}

View File

@ -2,8 +2,12 @@
from django.urls import path
from authentik.sources.oauth.api.property_mappings import OAuthSourcePropertyMappingViewSet
from authentik.sources.oauth.api.source import OAuthSourceViewSet
from authentik.sources.oauth.api.source_connection import UserOAuthSourceConnectionViewSet
from authentik.sources.oauth.api.source_connection import (
GroupOAuthSourceConnectionViewSet,
UserOAuthSourceConnectionViewSet,
)
from authentik.sources.oauth.types.registry import RequestKind
from authentik.sources.oauth.views.dispatcher import DispatcherView
@ -21,6 +25,8 @@ urlpatterns = [
]
api_urlpatterns = [
("propertymappings/source/oauth", OAuthSourcePropertyMappingViewSet),
("sources/user_connections/oauth", UserOAuthSourceConnectionViewSet),
("sources/group_connections/oauth", GroupOAuthSourceConnectionViewSet),
("sources/oauth", OAuthSourceViewSet),
]

View File

@ -13,7 +13,11 @@ from structlog.stdlib import get_logger
from authentik.core.sources.flow_manager import SourceFlowManager
from authentik.events.models import Event, EventAction
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.models import (
GroupOAuthSourceConnection,
OAuthSource,
UserOAuthSourceConnection,
)
from authentik.sources.oauth.views.base import OAuthClientMixin
LOGGER = get_logger()
@ -57,15 +61,19 @@ class OAuthCallback(OAuthClientMixin, View):
identifier = self.get_user_id(info=raw_info)
if identifier is None:
return self.handle_login_failure("Could not determine id.")
# Get or create access record
enroll_info = self.get_user_enroll_context(raw_info)
sfm = OAuthSourceFlowManager(
source=self.source,
request=self.request,
identifier=identifier,
enroll_info=enroll_info,
user_info={
"info": raw_info,
"client": client,
"token": self.token,
},
policy_context={
"oauth_userinfo": raw_info,
},
)
sfm.policy_context = {"oauth_userinfo": raw_info}
return sfm.get_flow(
raw_info=raw_info,
access_token=self.token.get("access_token"),
@ -79,13 +87,6 @@ class OAuthCallback(OAuthClientMixin, View):
"Return url to redirect on login failure."
return settings.LOGIN_URL
def get_user_enroll_context(
self,
info: dict[str, Any],
) -> dict[str, Any]:
"""Create a dict of User data"""
raise NotImplementedError()
def get_user_id(self, info: dict[str, Any]) -> str | None:
"""Return unique identifier from the profile info."""
if "id" in info:
@ -111,9 +112,10 @@ class OAuthCallback(OAuthClientMixin, View):
class OAuthSourceFlowManager(SourceFlowManager):
"""Flow manager for oauth sources"""
connection_type = UserOAuthSourceConnection
user_connection_type = UserOAuthSourceConnection
group_connection_type = GroupOAuthSourceConnection
def update_connection(
def update_user_connection(
self,
connection: UserOAuthSourceConnection,
access_token: str | None = None,

View File

@ -109,7 +109,8 @@ class PlexSourceViewSet(UsedByMixin, ModelViewSet):
source=source,
request=request,
identifier=str(identifier),
enroll_info=user_info,
user_info=user_info,
policy_context={},
)
return to_stage_response(request, sfm.get_flow(plex_token=plex_token))
LOGGER.warning(

View File

@ -113,9 +113,11 @@ class PlexAuth:
class PlexSourceFlowManager(SourceFlowManager):
"""Flow manager for plex sources"""
connection_type = PlexSourceConnection
user_connection_type = PlexSourceConnection
def update_connection(self, connection: PlexSourceConnection, **kwargs) -> PlexSourceConnection:
def update_user_connection(
self, connection: PlexSourceConnection, **kwargs
) -> PlexSourceConnection:
"""Set the access_token on the connection"""
connection.plex_token = kwargs.get("plex_token")
return connection

View File

@ -0,0 +1,31 @@
"""SAML source property mappings API"""
from rest_framework.viewsets import ModelViewSet
from authentik.core.api.property_mappings import PropertyMappingFilterSet, PropertyMappingSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.sources.saml.models import SAMLSourcePropertyMapping
class SAMLSourcePropertyMappingSerializer(PropertyMappingSerializer):
"""SAMLSourcePropertyMapping Serializer"""
class Meta(PropertyMappingSerializer.Meta):
model = SAMLSourcePropertyMapping
class SAMLSourcePropertyMappingFilter(PropertyMappingFilterSet):
"""Filter for SAMLSourcePropertyMapping"""
class Meta(PropertyMappingFilterSet.Meta):
model = SAMLSourcePropertyMapping
class SAMLSourcePropertyMappingViewSet(UsedByMixin, ModelViewSet):
"""SAMLSourcePropertyMapping Viewset"""
queryset = SAMLSourcePropertyMapping.objects.all()
serializer_class = SAMLSourcePropertyMappingSerializer
filterset_class = SAMLSourcePropertyMappingFilter
search_fields = ["name"]
ordering = ["name"]

View File

@ -20,6 +20,7 @@ class SAMLSourceSerializer(SourceSerializer):
class Meta:
model = SAMLSource
fields = SourceSerializer.Meta.fields + [
"group_matching_mode",
"pre_authentication_flow",
"issuer",
"sso_url",

View File

@ -3,10 +3,12 @@
from rest_framework.viewsets import ModelViewSet
from authentik.core.api.sources import (
GroupSourceConnectionSerializer,
GroupSourceConnectionViewSet,
UserSourceConnectionSerializer,
UserSourceConnectionViewSet,
)
from authentik.sources.saml.models import UserSAMLSourceConnection
from authentik.sources.saml.models import GroupSAMLSourceConnection, UserSAMLSourceConnection
class UserSAMLSourceConnectionSerializer(UserSourceConnectionSerializer):
@ -22,3 +24,17 @@ class UserSAMLSourceConnectionViewSet(UserSourceConnectionViewSet, ModelViewSet)
queryset = UserSAMLSourceConnection.objects.all()
serializer_class = UserSAMLSourceConnectionSerializer
class GroupSAMLSourceConnectionSerializer(GroupSourceConnectionSerializer):
"""OAuth Group-Source connection Serializer"""
class Meta(GroupSourceConnectionSerializer.Meta):
model = GroupSAMLSourceConnection
class GroupSAMLSourceConnectionViewSet(GroupSourceConnectionViewSet):
"""Group-source connection Viewset"""
queryset = GroupSAMLSourceConnection.objects.all()
serializer_class = GroupSAMLSourceConnectionSerializer

View File

@ -0,0 +1,57 @@
# Generated by Django 5.0.7 on 2024-08-01 18:52
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_core", "0039_source_group_matching_mode_alter_group_name_and_more"),
("authentik_sources_saml", "0014_alter_samlsource_digest_algorithm_and_more"),
]
operations = [
migrations.CreateModel(
name="GroupSAMLSourceConnection",
fields=[
(
"groupsourceconnection_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_core.groupsourceconnection",
),
),
],
options={
"verbose_name": "Group SAML Source Connection",
"verbose_name_plural": "Group SAML Source Connections",
},
bases=("authentik_core.groupsourceconnection",),
),
migrations.CreateModel(
name="SAMLSourcePropertyMapping",
fields=[
(
"propertymapping_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_core.propertymapping",
),
),
],
options={
"verbose_name": "SAML Source Property Mapping",
"verbose_name_plural": "SAML Source Property Mappings",
},
bases=("authentik_core.propertymapping",),
),
]

View File

@ -1,5 +1,7 @@
"""saml sp models"""
from typing import Any
from django.db import models
from django.http import HttpRequest
from django.templatetags.static import static
@ -7,11 +9,17 @@ from django.urls import reverse
from django.utils.translation import gettext_lazy as _
from rest_framework.serializers import Serializer
from authentik.core.models import Source, UserSourceConnection
from authentik.core.models import (
GroupSourceConnection,
PropertyMapping,
Source,
UserSourceConnection,
)
from authentik.core.types import UILoginButton, UserSettingSerializer
from authentik.crypto.models import CertificateKeyPair
from authentik.flows.challenge import RedirectChallenge
from authentik.flows.models import Flow
from authentik.lib.expression.evaluator import BaseEvaluator
from authentik.lib.utils.time import timedelta_string_validator
from authentik.sources.saml.processors.constants import (
DSA_SHA1,
@ -19,10 +27,12 @@ from authentik.sources.saml.processors.constants import (
ECDSA_SHA256,
ECDSA_SHA384,
ECDSA_SHA512,
NS_SAML_ASSERTION,
RSA_SHA1,
RSA_SHA256,
RSA_SHA384,
RSA_SHA512,
SAML_ATTRIBUTES_GROUP,
SAML_BINDING_POST,
SAML_BINDING_REDIRECT,
SAML_NAME_ID_FORMAT_EMAIL,
@ -182,11 +192,39 @@ class SAMLSource(Source):
return SAMLSourceSerializer
@property
def icon_url(self) -> str:
icon = super().icon_url
if not icon:
return static("authentik/sources/saml.png")
return icon
def property_mapping_type(self) -> type[PropertyMapping]:
return SAMLSourcePropertyMapping
def get_base_user_properties(self, root: Any, name_id: Any, **kwargs):
attributes = {}
assertion = root.find(f"{{{NS_SAML_ASSERTION}}}Assertion")
if assertion is None:
raise ValueError("Assertion element not found")
attribute_statement = assertion.find(f"{{{NS_SAML_ASSERTION}}}AttributeStatement")
if attribute_statement is None:
raise ValueError("Attribute statement element not found")
# Get all attributes and their values into a dict
for attribute in attribute_statement.iterchildren():
key = attribute.attrib["Name"]
attributes.setdefault(key, [])
for value in attribute.iterchildren():
attributes[key].append(value.text)
if SAML_ATTRIBUTES_GROUP in attributes:
attributes["groups"] = attributes[SAML_ATTRIBUTES_GROUP]
del attributes[SAML_ATTRIBUTES_GROUP]
# Flatten all lists in the dict
for key, value in attributes.items():
if key == "groups":
continue
attributes[key] = BaseEvaluator.expr_flatten(value)
attributes["username"] = name_id.text
return attributes
def get_base_group_properties(self, group_id: str, **kwargs):
return {
"name": group_id,
}
def get_issuer(self, request: HttpRequest) -> str:
"""Get Source's Issuer, falling back to our Metadata URL if none is set"""
@ -200,6 +238,13 @@ class SAMLSource(Source):
reverse(f"authentik_sources_saml:{view}", kwargs={"source_slug": self.slug})
)
@property
def icon_url(self) -> str:
icon = super().icon_url
if not icon:
return static("authentik/sources/saml.png")
return icon
def ui_login_button(self, request: HttpRequest) -> UILoginButton:
return UILoginButton(
challenge=RedirectChallenge(
@ -235,6 +280,24 @@ class SAMLSource(Source):
verbose_name_plural = _("SAML Sources")
class SAMLSourcePropertyMapping(PropertyMapping):
"""Map SAML properties to User or Group object attributes"""
@property
def component(self) -> str:
return "ak-property-mapping-saml-source-form"
@property
def serializer(self) -> type[Serializer]:
from authentik.sources.saml.api.property_mappings import SAMLSourcePropertyMappingSerializer
return SAMLSourcePropertyMappingSerializer
class Meta:
verbose_name = _("SAML Source Property Mapping")
verbose_name_plural = _("SAML Source Property Mappings")
class UserSAMLSourceConnection(UserSourceConnection):
"""Connection to configured SAML Sources."""
@ -249,3 +312,19 @@ class UserSAMLSourceConnection(UserSourceConnection):
class Meta:
verbose_name = _("User SAML Source Connection")
verbose_name_plural = _("User SAML Source Connections")
class GroupSAMLSourceConnection(GroupSourceConnection):
"""Group-source connection"""
@property
def serializer(self) -> type[Serializer]:
from authentik.sources.saml.api.source_connection import (
GroupSAMLSourceConnectionSerializer,
)
return GroupSAMLSourceConnectionSerializer
class Meta:
verbose_name = _("Group SAML Source Connection")
verbose_name_plural = _("Group SAML Source Connections")

View File

@ -21,6 +21,8 @@ SAML_NAME_ID_FORMAT_X509 = "urn:oasis:names:tc:SAML:2.0:nameid-format:X509Subjec
SAML_NAME_ID_FORMAT_WINDOWS = "urn:oasis:names:tc:SAML:2.0:nameid-format:WindowsDomainQualifiedName"
SAML_NAME_ID_FORMAT_TRANSIENT = "urn:oasis:names:tc:SAML:2.0:nameid-format:transient"
SAML_ATTRIBUTES_GROUP = "http://schemas.xmlsoap.org/claims/Group"
SAML_BINDING_POST = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
SAML_BINDING_REDIRECT = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"

View File

@ -21,16 +21,18 @@ from authentik.core.models import (
User,
)
from authentik.core.sources.flow_manager import SourceFlowManager
from authentik.lib.expression.evaluator import BaseEvaluator
from authentik.lib.utils.time import timedelta_from_string
from authentik.policies.utils import delete_none_values
from authentik.sources.saml.exceptions import (
InvalidSignature,
MismatchedRequestID,
MissingSAMLResponse,
UnsupportedNameIDFormat,
)
from authentik.sources.saml.models import SAMLSource, UserSAMLSourceConnection
from authentik.sources.saml.models import (
GroupSAMLSourceConnection,
SAMLSource,
UserSAMLSourceConnection,
)
from authentik.sources.saml.processors.constants import (
NS_MAP,
NS_SAML_ASSERTION,
@ -138,12 +140,12 @@ class ResponseProcessor:
user has an attribute that refers to our Source for cleanup. The user is also deleted
on logout and periodically."""
# Create a temporary User
name_id = self._get_name_id().text
name_id = self._get_name_id()
expiry = mktime(
(now() + timedelta_from_string(self._source.temporary_user_delete_after)).timetuple()
)
user: User = User.objects.create(
username=name_id,
username=name_id.text,
attributes={
USER_ATTRIBUTE_GENERATED: True,
USER_ATTRIBUTE_SOURCES: [
@ -154,15 +156,21 @@ class ResponseProcessor:
},
path=self._source.get_user_path(),
)
LOGGER.debug("Created temporary user for NameID Transient", username=name_id)
LOGGER.debug("Created temporary user for NameID Transient", username=name_id.text)
user.set_unusable_password()
user.save()
UserSAMLSourceConnection.objects.create(source=self._source, user=user, identifier=name_id)
UserSAMLSourceConnection.objects.create(
source=self._source, user=user, identifier=name_id.text
)
return SAMLSourceFlowManager(
self._source,
self._http_request,
name_id,
delete_none_values(self.get_attributes()),
source=self._source,
request=self._http_request,
identifier=str(name_id.text),
user_info={
"root": self._root,
"name_id": name_id,
},
policy_context={},
)
def _get_name_id(self) -> "Element":
@ -200,27 +208,6 @@ class ResponseProcessor:
f"Assertion contains NameID with unsupported format {_format}."
)
def get_attributes(self) -> dict[str, list[str] | str]:
"""Get all attributes sent"""
attributes = {}
assertion = self._root.find(f"{{{NS_SAML_ASSERTION}}}Assertion")
if assertion is None:
raise ValueError("Assertion element not found")
attribute_statement = assertion.find(f"{{{NS_SAML_ASSERTION}}}AttributeStatement")
if attribute_statement is None:
raise ValueError("Attribute statement element not found")
# Get all attributes and their values into a dict
for attribute in attribute_statement.iterchildren():
key = attribute.attrib["Name"]
attributes.setdefault(key, [])
for value in attribute.iterchildren():
attributes[key].append(value.text)
# Flatten all lists in the dict
for key, value in attributes.items():
attributes[key] = BaseEvaluator.expr_flatten(value)
attributes["username"] = self._get_name_id().text
return attributes
def prepare_flow_manager(self) -> SourceFlowManager:
"""Prepare flow plan depending on whether or not the user exists"""
name_id = self._get_name_id()
@ -235,17 +222,22 @@ class ResponseProcessor:
if name_id.attrib["Format"] == SAML_NAME_ID_FORMAT_TRANSIENT:
return self._handle_name_id_transient()
flow_manager = SAMLSourceFlowManager(
self._source,
self._http_request,
name_id.text,
delete_none_values(self.get_attributes()),
return SAMLSourceFlowManager(
source=self._source,
request=self._http_request,
identifier=str(name_id.text),
user_info={
"root": self._root,
"name_id": name_id,
},
policy_context={
"saml_response": etree.tostring(self._root),
},
)
flow_manager.policy_context["saml_response"] = etree.tostring(self._root)
return flow_manager
class SAMLSourceFlowManager(SourceFlowManager):
"""Source flow manager for SAML Sources"""
connection_type = UserSAMLSourceConnection
user_connection_type = UserSAMLSourceConnection
group_connection_type = GroupSAMLSourceConnection

View File

@ -0,0 +1,46 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<saml2p:Response xmlns:saml2p="urn:oasis:names:tc:SAML:2.0:protocol" Destination="https://127.0.0.1:9443/source/saml/google/acs/" ID="_1e17063957f10819a5a8e147971fec22" InResponseTo="_157fb504b59f4ae3919f74896a6b8565" IssueInstant="2022-10-14T14:11:49.590Z" Version="2.0">
<saml2:Issuer xmlns:saml2="urn:oasis:names:tc:SAML:2.0:assertion">https://accounts.google.com/o/saml2?idpid=</saml2:Issuer>
<saml2p:Status>
<saml2p:StatusCode Value="urn:oasis:names:tc:SAML:2.0:status:Success"></saml2p:StatusCode>
</saml2p:Status>
<saml2:Assertion xmlns:saml2="urn:oasis:names:tc:SAML:2.0:assertion" ID="_346001c5708ffd118c40edbc0c72fc60" IssueInstant="2022-10-14T14:11:49.590Z" Version="2.0">
<saml2:Issuer>https://accounts.google.com/o/saml2?idpid=</saml2:Issuer>
<saml2:Subject>
<saml2:NameID Format="urn:oasis:names:tc:SAML:2.0:nameid-format:persistent">jens@goauthentik.io</saml2:NameID>
<saml2:SubjectConfirmation Method="urn:oasis:names:tc:SAML:2.0:cm:bearer">
<saml2:SubjectConfirmationData InResponseTo="_157fb504b59f4ae3919f74896a6b8565" NotOnOrAfter="2022-10-14T14:16:49.590Z" Recipient="https://127.0.0.1:9443/source/saml/google/acs/"></saml2:SubjectConfirmationData>
</saml2:SubjectConfirmation>
</saml2:Subject>
<saml2:Conditions NotBefore="2022-10-14T14:06:49.590Z" NotOnOrAfter="2022-10-14T14:16:49.590Z">
<saml2:AudienceRestriction>
<saml2:Audience>https://accounts.google.com/o/saml2?idpid=</saml2:Audience>
</saml2:AudienceRestriction>
</saml2:Conditions>
<saml2:AttributeStatement>
<saml2:Attribute Name="name">
<saml2:AttributeValue xmlns:xs="http://www.w3.org/2001/XMLSchema"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="xs:anyType">foo</saml2:AttributeValue>
</saml2:Attribute>
<saml2:Attribute Name="sn">
<saml2:AttributeValue xmlns:xs="http://www.w3.org/2001/XMLSchema"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="xs:anyType">bar</saml2:AttributeValue>
</saml2:Attribute>
<saml2:Attribute Name="email">
<saml2:AttributeValue xmlns:xs="http://www.w3.org/2001/XMLSchema"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="xs:anyType">foo@bar.baz</saml2:AttributeValue>
</saml2:Attribute>
<saml2:Attribute Name="http://schemas.xmlsoap.org/claims/Group">
<saml2:AttributeValue xmlns:xs="http://www.w3.org/2001/XMLSchema"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="xs:anyType">group 1</saml2:AttributeValue>
<saml2:AttributeValue xmlns:xs="http://www.w3.org/2001/XMLSchema"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="xs:anyType">group 2</saml2:AttributeValue>
</saml2:Attribute>
</saml2:AttributeStatement>
<saml2:AuthnStatement AuthnInstant="2022-10-14T12:16:21.000Z" SessionIndex="_346001c5708ffd118c40edbc0c72fc60">
<saml2:AuthnContext>
<saml2:AuthnContextClassRef>urn:oasis:names:tc:SAML:2.0:ac:classes:unspecified</saml2:AuthnContextClassRef>
</saml2:AuthnContext>
</saml2:AuthnStatement>
</saml2:Assertion>
</saml2p:Response>

View File

@ -0,0 +1,135 @@
"""SAML Source tests"""
from base64 import b64encode
from defusedxml.lxml import fromstring
from django.contrib.sessions.middleware import SessionMiddleware
from django.test import RequestFactory, TestCase
from authentik.core.tests.utils import create_test_flow
from authentik.lib.generators import generate_id
from authentik.lib.tests.utils import dummy_get_response, load_fixture
from authentik.sources.saml.models import SAMLSource, SAMLSourcePropertyMapping
from authentik.sources.saml.processors.constants import NS_SAML_ASSERTION
from authentik.sources.saml.processors.response import ResponseProcessor
ROOT = fromstring(load_fixture("fixtures/response_success.xml").encode())
ROOT_GROUPS = fromstring(load_fixture("fixtures/response_success_groups.xml").encode())
NAME_ID = (
ROOT.find(f"{{{NS_SAML_ASSERTION}}}Assertion")
.find(f"{{{NS_SAML_ASSERTION}}}Subject")
.find(f"{{{NS_SAML_ASSERTION}}}NameID")
)
class TestPropertyMappings(TestCase):
"""Test Property Mappings"""
def setUp(self):
self.factory = RequestFactory()
self.source = SAMLSource.objects.create(
slug=generate_id(),
issuer="authentik",
allow_idp_initiated=True,
pre_authentication_flow=create_test_flow(),
)
def test_user_base_properties(self):
"""Test user base properties"""
properties = self.source.get_base_user_properties(root=ROOT, name_id=NAME_ID)
self.assertEqual(
properties,
{
"email": "foo@bar.baz",
"name": "foo",
"sn": "bar",
"username": "jens@goauthentik.io",
},
)
def test_group_base_properties(self):
"""Test group base properties"""
properties = self.source.get_base_user_properties(root=ROOT_GROUPS, name_id=NAME_ID)
self.assertEqual(properties["groups"], ["group 1", "group 2"])
for group_id in ["group 1", "group 2"]:
properties = self.source.get_base_group_properties(root=ROOT, group_id=group_id)
self.assertEqual(properties, {"name": group_id})
def test_user_property_mappings(self):
"""Test user property mappings"""
self.source.user_property_mappings.add(
SAMLSourcePropertyMapping.objects.create(
name="test",
expression="return {'attributes': {'department': 'Engineering'}, 'sn': None}",
)
)
request = self.factory.post(
"/",
data={
"SAMLResponse": b64encode(
load_fixture("fixtures/response_success.xml").encode()
).decode()
},
)
middleware = SessionMiddleware(dummy_get_response)
middleware.process_request(request)
request.session.save()
parser = ResponseProcessor(self.source, request)
parser.parse()
sfm = parser.prepare_flow_manager()
self.assertEqual(
sfm.user_properties,
{
"email": "foo@bar.baz",
"name": "foo",
"username": "jens@goauthentik.io",
"attributes": {
"department": "Engineering",
},
"path": self.source.get_user_path(),
},
)
def test_group_property_mappings(self):
"""Test group property mappings"""
self.source.group_property_mappings.add(
SAMLSourcePropertyMapping.objects.create(
name="test",
expression="return {'attributes': {'id': group_id}}",
)
)
request = self.factory.post(
"/",
data={
"SAMLResponse": b64encode(
load_fixture("fixtures/response_success_groups.xml").encode()
).decode()
},
)
middleware = SessionMiddleware(dummy_get_response)
middleware.process_request(request)
request.session.save()
parser = ResponseProcessor(self.source, request)
parser.parse()
sfm = parser.prepare_flow_manager()
self.assertEqual(
sfm.groups_properties,
{
"group 1": {
"name": "group 1",
"attributes": {
"id": "group 1",
},
},
"group 2": {
"name": "group 2",
"attributes": {
"id": "group 2",
},
},
},
)

View File

@ -67,6 +67,13 @@ class TestResponseProcessor(TestCase):
parser.parse()
sfm = parser.prepare_flow_manager()
self.assertEqual(
sfm.enroll_info,
{"email": "foo@bar.baz", "name": "foo", "sn": "bar", "username": "jens@goauthentik.io"},
sfm.user_properties,
{
"email": "foo@bar.baz",
"name": "foo",
"sn": "bar",
"username": "jens@goauthentik.io",
"attributes": {},
"path": self.source.get_user_path(),
},
)

View File

@ -2,8 +2,12 @@
from django.urls import path
from authentik.sources.saml.api.property_mappings import SAMLSourcePropertyMappingViewSet
from authentik.sources.saml.api.source import SAMLSourceViewSet
from authentik.sources.saml.api.source_connection import UserSAMLSourceConnectionViewSet
from authentik.sources.saml.api.source_connection import (
GroupSAMLSourceConnectionViewSet,
UserSAMLSourceConnectionViewSet,
)
from authentik.sources.saml.views import ACSView, InitiateView, MetadataView, SLOView
urlpatterns = [
@ -14,6 +18,8 @@ urlpatterns = [
]
api_urlpatterns = [
("propertymappings/source/saml", SAMLSourcePropertyMappingViewSet),
("sources/user_connections/saml", UserSAMLSourceConnectionViewSet),
("sources/group_connections/saml", GroupSAMLSourceConnectionViewSet),
("sources/saml", SAMLSourceViewSet),
]

View File

@ -1201,6 +1201,46 @@
}
}
},
{
"type": "object",
"required": [
"model",
"identifiers"
],
"properties": {
"model": {
"const": "authentik_sources_oauth.oauthsourcepropertymapping"
},
"id": {
"type": "string"
},
"state": {
"type": "string",
"enum": [
"absent",
"present",
"created",
"must_created"
],
"default": "present"
},
"conditions": {
"type": "array",
"items": {
"type": "boolean"
}
},
"permissions": {
"$ref": "#/$defs/model_authentik_sources_oauth.oauthsourcepropertymapping_permissions"
},
"attrs": {
"$ref": "#/$defs/model_authentik_sources_oauth.oauthsourcepropertymapping"
},
"identifiers": {
"$ref": "#/$defs/model_authentik_sources_oauth.oauthsourcepropertymapping"
}
}
},
{
"type": "object",
"required": [
@ -1241,6 +1281,46 @@
}
}
},
{
"type": "object",
"required": [
"model",
"identifiers"
],
"properties": {
"model": {
"const": "authentik_sources_oauth.groupoauthsourceconnection"
},
"id": {
"type": "string"
},
"state": {
"type": "string",
"enum": [
"absent",
"present",
"created",
"must_created"
],
"default": "present"
},
"conditions": {
"type": "array",
"items": {
"type": "boolean"
}
},
"permissions": {
"$ref": "#/$defs/model_authentik_sources_oauth.groupoauthsourceconnection_permissions"
},
"attrs": {
"$ref": "#/$defs/model_authentik_sources_oauth.groupoauthsourceconnection"
},
"identifiers": {
"$ref": "#/$defs/model_authentik_sources_oauth.groupoauthsourceconnection"
}
}
},
{
"type": "object",
"required": [
@ -1361,6 +1441,46 @@
}
}
},
{
"type": "object",
"required": [
"model",
"identifiers"
],
"properties": {
"model": {
"const": "authentik_sources_saml.samlsourcepropertymapping"
},
"id": {
"type": "string"
},
"state": {
"type": "string",
"enum": [
"absent",
"present",
"created",
"must_created"
],
"default": "present"
},
"conditions": {
"type": "array",
"items": {
"type": "boolean"
}
},
"permissions": {
"$ref": "#/$defs/model_authentik_sources_saml.samlsourcepropertymapping_permissions"
},
"attrs": {
"$ref": "#/$defs/model_authentik_sources_saml.samlsourcepropertymapping"
},
"identifiers": {
"$ref": "#/$defs/model_authentik_sources_saml.samlsourcepropertymapping"
}
}
},
{
"type": "object",
"required": [
@ -1401,6 +1521,46 @@
}
}
},
{
"type": "object",
"required": [
"model",
"identifiers"
],
"properties": {
"model": {
"const": "authentik_sources_saml.groupsamlsourceconnection"
},
"id": {
"type": "string"
},
"state": {
"type": "string",
"enum": [
"absent",
"present",
"created",
"must_created"
],
"default": "present"
},
"conditions": {
"type": "array",
"items": {
"type": "boolean"
}
},
"permissions": {
"$ref": "#/$defs/model_authentik_sources_saml.groupsamlsourceconnection_permissions"
},
"attrs": {
"$ref": "#/$defs/model_authentik_sources_saml.groupsamlsourceconnection"
},
"identifiers": {
"$ref": "#/$defs/model_authentik_sources_saml.groupsamlsourceconnection"
}
}
},
{
"type": "object",
"required": [
@ -4106,11 +4266,15 @@
"authentik_sources_ldap.ldapsource",
"authentik_sources_ldap.ldapsourcepropertymapping",
"authentik_sources_oauth.oauthsource",
"authentik_sources_oauth.oauthsourcepropertymapping",
"authentik_sources_oauth.useroauthsourceconnection",
"authentik_sources_oauth.groupoauthsourceconnection",
"authentik_sources_plex.plexsource",
"authentik_sources_plex.plexsourceconnection",
"authentik_sources_saml.samlsource",
"authentik_sources_saml.samlsourcepropertymapping",
"authentik_sources_saml.usersamlsourceconnection",
"authentik_sources_saml.groupsamlsourceconnection",
"authentik_sources_scim.scimsource",
"authentik_sources_scim.scimsourcepropertymapping",
"authentik_stages_authenticator_duo.authenticatorduostage",
@ -6615,6 +6779,16 @@
"minLength": 1,
"title": "Icon"
},
"group_matching_mode": {
"type": "string",
"enum": [
"identifier",
"name_link",
"name_deny"
],
"title": "Group matching mode",
"description": "How the source determines if an existing group should be used or a new group created."
},
"provider_type": {
"type": "string",
"enum": [
@ -6727,6 +6901,57 @@
}
}
},
"model_authentik_sources_oauth.oauthsourcepropertymapping": {
"type": "object",
"properties": {
"managed": {
"type": [
"string",
"null"
],
"minLength": 1,
"title": "Managed by authentik",
"description": "Objects that are managed by authentik. These objects are created and updated automatically. This flag only indicates that an object can be overwritten by migrations. You can still modify the objects via the API, but expect changes to be overwritten in a later update."
},
"name": {
"type": "string",
"minLength": 1,
"title": "Name"
},
"expression": {
"type": "string",
"minLength": 1,
"title": "Expression"
}
},
"required": []
},
"model_authentik_sources_oauth.oauthsourcepropertymapping_permissions": {
"type": "array",
"items": {
"type": "object",
"required": [
"permission"
],
"properties": {
"permission": {
"type": "string",
"enum": [
"add_oauthsourcepropertymapping",
"change_oauthsourcepropertymapping",
"delete_oauthsourcepropertymapping",
"view_oauthsourcepropertymapping"
]
},
"user": {
"type": "integer"
},
"role": {
"type": "string"
}
}
}
},
"model_authentik_sources_oauth.useroauthsourceconnection": {
"type": "object",
"properties": {
@ -6777,6 +7002,43 @@
}
}
},
"model_authentik_sources_oauth.groupoauthsourceconnection": {
"type": "object",
"properties": {
"icon": {
"type": "string",
"minLength": 1,
"title": "Icon"
}
},
"required": []
},
"model_authentik_sources_oauth.groupoauthsourceconnection_permissions": {
"type": "array",
"items": {
"type": "object",
"required": [
"permission"
],
"properties": {
"permission": {
"type": "string",
"enum": [
"add_groupoauthsourceconnection",
"change_groupoauthsourceconnection",
"delete_groupoauthsourceconnection",
"view_groupoauthsourceconnection"
]
},
"user": {
"type": "integer"
},
"role": {
"type": "string"
}
}
}
},
"model_authentik_sources_plex.plexsource": {
"type": "object",
"properties": {
@ -7038,6 +7300,16 @@
"minLength": 1,
"title": "Icon"
},
"group_matching_mode": {
"type": "string",
"enum": [
"identifier",
"name_link",
"name_deny"
],
"title": "Group matching mode",
"description": "How the source determines if an existing group should be used or a new group created."
},
"pre_authentication_flow": {
"type": "string",
"format": "uuid",
@ -7165,6 +7437,57 @@
}
}
},
"model_authentik_sources_saml.samlsourcepropertymapping": {
"type": "object",
"properties": {
"managed": {
"type": [
"string",
"null"
],
"minLength": 1,
"title": "Managed by authentik",
"description": "Objects that are managed by authentik. These objects are created and updated automatically. This flag only indicates that an object can be overwritten by migrations. You can still modify the objects via the API, but expect changes to be overwritten in a later update."
},
"name": {
"type": "string",
"minLength": 1,
"title": "Name"
},
"expression": {
"type": "string",
"minLength": 1,
"title": "Expression"
}
},
"required": []
},
"model_authentik_sources_saml.samlsourcepropertymapping_permissions": {
"type": "array",
"items": {
"type": "object",
"required": [
"permission"
],
"properties": {
"permission": {
"type": "string",
"enum": [
"add_samlsourcepropertymapping",
"change_samlsourcepropertymapping",
"delete_samlsourcepropertymapping",
"view_samlsourcepropertymapping"
]
},
"user": {
"type": "integer"
},
"role": {
"type": "string"
}
}
}
},
"model_authentik_sources_saml.usersamlsourceconnection": {
"type": "object",
"properties": {
@ -7207,6 +7530,43 @@
}
}
},
"model_authentik_sources_saml.groupsamlsourceconnection": {
"type": "object",
"properties": {
"icon": {
"type": "string",
"minLength": 1,
"title": "Icon"
}
},
"required": []
},
"model_authentik_sources_saml.groupsamlsourceconnection_permissions": {
"type": "array",
"items": {
"type": "object",
"required": [
"permission"
],
"properties": {
"permission": {
"type": "string",
"enum": [
"add_groupsamlsourceconnection",
"change_groupsamlsourceconnection",
"delete_groupsamlsourceconnection",
"view_groupsamlsourceconnection"
]
},
"user": {
"type": "integer"
},
"role": {
"type": "string"
}
}
}
},
"model_authentik_sources_scim.scimsource": {
"type": "object",
"properties": {
@ -10969,7 +11329,6 @@
"properties": {
"name": {
"type": "string",
"maxLength": 80,
"minLength": 1,
"title": "Name"
},

1382
schema.yml

File diff suppressed because it is too large Load Diff

View File

@ -25,16 +25,6 @@ class OAuth1Callback(OAuthCallback):
def get_user_id(self, info: dict[str, str]) -> str:
return info.get("id")
def get_user_enroll_context(
self,
info: dict[str, Any],
) -> dict[str, Any]:
return {
"username": info.get("screen_name"),
"email": info.get("email"),
"name": info.get("name"),
}
@registry.register()
class OAUth1Type(SourceType):
@ -50,6 +40,13 @@ class OAUth1Type(SourceType):
profile_url = "http://localhost:5001/api/me"
urls_customizable = False
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
return {
"username": info.get("screen_name"),
"email": info.get("email"),
"name": info.get("name"),
}
class TestSourceOAuth1(SeleniumTestCase):
"""Test OAuth1 Source"""

View File

@ -57,7 +57,9 @@ export class PropertyMappingLDAPSourceForm extends BasePropertyMappingForm<LDAPS
<a
target="_blank"
rel="noopener noreferrer"
href="${docLink("/docs/property-mappings/expression?utm_source=authentik")}"
href="${docLink(
"/docs/sources/property-mappings/expression?utm_source=authentik",
)}"
>
${msg("See documentation for a list of all variables.")}
</a>

View File

@ -2,9 +2,11 @@ import "@goauthentik/admin/property-mappings/PropertyMappingGoogleWorkspaceForm"
import "@goauthentik/admin/property-mappings/PropertyMappingLDAPSourceForm";
import "@goauthentik/admin/property-mappings/PropertyMappingMicrosoftEntraForm";
import "@goauthentik/admin/property-mappings/PropertyMappingNotification";
import "@goauthentik/admin/property-mappings/PropertyMappingOAuthSourceForm";
import "@goauthentik/admin/property-mappings/PropertyMappingRACForm";
import "@goauthentik/admin/property-mappings/PropertyMappingRadiusForm";
import "@goauthentik/admin/property-mappings/PropertyMappingSAMLForm";
import "@goauthentik/admin/property-mappings/PropertyMappingSAMLSourceForm";
import "@goauthentik/admin/property-mappings/PropertyMappingSCIMForm";
import "@goauthentik/admin/property-mappings/PropertyMappingSCIMSourceForm";
import "@goauthentik/admin/property-mappings/PropertyMappingScopeForm";

View File

@ -0,0 +1,75 @@
import { BasePropertyMappingForm } from "@goauthentik/admin/property-mappings/BasePropertyMappingForm";
import { DEFAULT_CONFIG } from "@goauthentik/common/api/config";
import { docLink } from "@goauthentik/common/global";
import "@goauthentik/elements/CodeMirror";
import { CodeMirrorMode } from "@goauthentik/elements/CodeMirror";
import "@goauthentik/elements/forms/HorizontalFormElement";
import { msg } from "@lit/localize";
import { TemplateResult, html } from "lit";
import { customElement } from "lit/decorators.js";
import { ifDefined } from "lit/directives/if-defined.js";
import { OAuthSourcePropertyMapping, PropertymappingsApi } from "@goauthentik/api";
@customElement("ak-property-mapping-oauth-source-form")
export class PropertyMappingOAuthSourceForm extends BasePropertyMappingForm<OAuthSourcePropertyMapping> {
loadInstance(pk: string): Promise<OAuthSourcePropertyMapping> {
return new PropertymappingsApi(DEFAULT_CONFIG).propertymappingsSourceOauthRetrieve({
pmUuid: pk,
});
}
async send(data: OAuthSourcePropertyMapping): Promise<OAuthSourcePropertyMapping> {
if (this.instance) {
return new PropertymappingsApi(DEFAULT_CONFIG).propertymappingsSourceOauthUpdate({
pmUuid: this.instance.pk,
oAuthSourcePropertyMappingRequest: data,
});
} else {
return new PropertymappingsApi(DEFAULT_CONFIG).propertymappingsSourceOauthCreate({
oAuthSourcePropertyMappingRequest: data,
});
}
}
renderForm(): TemplateResult {
return html` <ak-form-element-horizontal label=${msg("Name")} ?required=${true} name="name">
<input
type="text"
value="${ifDefined(this.instance?.name)}"
class="pf-c-form-control"
required
/>
</ak-form-element-horizontal>
<ak-form-element-horizontal
label=${msg("Expression")}
?required=${true}
name="expression"
>
<ak-codemirror
mode=${CodeMirrorMode.Python}
value="${ifDefined(this.instance?.expression)}"
>
</ak-codemirror>
<p class="pf-c-form__helper-text">
${msg("Expression using Python.")}
<a
target="_blank"
rel="noopener noreferrer"
href="${docLink(
"/docs/sources/property-mappings/expression?utm_source=authentik",
)}"
>
${msg("See documentation for a list of all variables.")}
</a>
</p>
</ak-form-element-horizontal>`;
}
}
declare global {
interface HTMLElementTagNameMap {
"ak-property-mapping-oauth-source-form": PropertyMappingOAuthSourceForm;
}
}

View File

@ -0,0 +1,75 @@
import { BasePropertyMappingForm } from "@goauthentik/admin/property-mappings/BasePropertyMappingForm";
import { DEFAULT_CONFIG } from "@goauthentik/common/api/config";
import { docLink } from "@goauthentik/common/global";
import "@goauthentik/elements/CodeMirror";
import { CodeMirrorMode } from "@goauthentik/elements/CodeMirror";
import "@goauthentik/elements/forms/HorizontalFormElement";
import { msg } from "@lit/localize";
import { TemplateResult, html } from "lit";
import { customElement } from "lit/decorators.js";
import { ifDefined } from "lit/directives/if-defined.js";
import { PropertymappingsApi, SAMLSourcePropertyMapping } from "@goauthentik/api";
@customElement("ak-property-mapping-saml-source-form")
export class PropertyMappingSAMLSourceForm extends BasePropertyMappingForm<SAMLSourcePropertyMapping> {
loadInstance(pk: string): Promise<SAMLSourcePropertyMapping> {
return new PropertymappingsApi(DEFAULT_CONFIG).propertymappingsSourceSamlRetrieve({
pmUuid: pk,
});
}
async send(data: SAMLSourcePropertyMapping): Promise<SAMLSourcePropertyMapping> {
if (this.instance) {
return new PropertymappingsApi(DEFAULT_CONFIG).propertymappingsSourceSamlUpdate({
pmUuid: this.instance.pk,
sAMLSourcePropertyMappingRequest: data,
});
} else {
return new PropertymappingsApi(DEFAULT_CONFIG).propertymappingsSourceSamlCreate({
sAMLSourcePropertyMappingRequest: data,
});
}
}
renderForm(): TemplateResult {
return html` <ak-form-element-horizontal label=${msg("Name")} ?required=${true} name="name">
<input
type="text"
value="${ifDefined(this.instance?.name)}"
class="pf-c-form-control"
required
/>
</ak-form-element-horizontal>
<ak-form-element-horizontal
label=${msg("Expression")}
?required=${true}
name="expression"
>
<ak-codemirror
mode=${CodeMirrorMode.Python}
value="${ifDefined(this.instance?.expression)}"
>
</ak-codemirror>
<p class="pf-c-form__helper-text">
${msg("Expression using Python.")}
<a
target="_blank"
rel="noopener noreferrer"
href="${docLink(
"/docs/sources/property-mappings/expression?utm_source=authentik",
)}"
>
${msg("See documentation for a list of all variables.")}
</a>
</p>
</ak-form-element-horizontal>`;
}
}
declare global {
interface HTMLElementTagNameMap {
"ak-property-mapping-saml-source-form": PropertyMappingSAMLSourceForm;
}
}

View File

@ -1,7 +1,9 @@
import "@goauthentik/admin/property-mappings/PropertyMappingLDAPSourceForm";
import "@goauthentik/admin/property-mappings/PropertyMappingNotification";
import "@goauthentik/admin/property-mappings/PropertyMappingOAuthSourceForm";
import "@goauthentik/admin/property-mappings/PropertyMappingRACForm";
import "@goauthentik/admin/property-mappings/PropertyMappingSAMLForm";
import "@goauthentik/admin/property-mappings/PropertyMappingSAMLSourceForm";
import "@goauthentik/admin/property-mappings/PropertyMappingSCIMSourceForm";
import "@goauthentik/admin/property-mappings/PropertyMappingScopeForm";
import "@goauthentik/admin/property-mappings/PropertyMappingTestForm";

View File

@ -1,7 +1,10 @@
import "@goauthentik/admin/common/ak-flow-search/ak-source-flow-search";
import { iconHelperText, placeholderHelperText } from "@goauthentik/admin/helperText";
import { BaseSourceForm } from "@goauthentik/admin/sources/BaseSourceForm";
import { UserMatchingModeToLabel } from "@goauthentik/admin/sources/oauth/utils";
import {
GroupMatchingModeToLabel,
UserMatchingModeToLabel,
} from "@goauthentik/admin/sources/oauth/utils";
import { DEFAULT_CONFIG, config } from "@goauthentik/common/api/config";
import { first } from "@goauthentik/common/utils";
import "@goauthentik/elements/CodeMirror";
@ -10,6 +13,8 @@ import {
CapabilitiesEnum,
WithCapabilitiesConfig,
} from "@goauthentik/elements/Interface/capabilitiesProvider";
import "@goauthentik/elements/ak-dual-select/ak-dual-select-dynamic-selected-provider.js";
import { DualSelectPair } from "@goauthentik/elements/ak-dual-select/types.js";
import "@goauthentik/elements/forms/FormGroup";
import "@goauthentik/elements/forms/HorizontalFormElement";
import "@goauthentik/elements/forms/SearchSelect";
@ -21,14 +26,39 @@ import { ifDefined } from "lit/directives/if-defined.js";
import {
FlowsInstancesListDesignationEnum,
GroupMatchingModeEnum,
OAuthSource,
OAuthSourcePropertyMapping,
OAuthSourceRequest,
PropertymappingsApi,
ProviderTypeEnum,
SourceType,
SourcesApi,
UserMatchingModeEnum,
} from "@goauthentik/api";
async function propertyMappingsProvider(page = 1, search = "") {
const propertyMappings = await new PropertymappingsApi(
DEFAULT_CONFIG,
).propertymappingsSourceOauthList({
ordering: "managed",
pageSize: 20,
search: search.trim(),
page,
});
return {
pagination: propertyMappings.pagination,
options: propertyMappings.results.map((m) => [m.pk, m.name, m.name, m]),
};
}
function makePropertyMappingsSelector(instanceMappings?: string[]) {
const localMappings = instanceMappings ? new Set(instanceMappings) : undefined;
return localMappings
? ([pk, _]: DualSelectPair) => localMappings.has(pk)
: ([_0, _1, _2, _]: DualSelectPair<OAuthSourcePropertyMapping>) => false;
}
@customElement("ak-source-oauth-form")
export class OAuthSourceForm extends WithCapabilitiesConfig(BaseSourceForm<OAuthSource>) {
async loadInstance(pk: string): Promise<OAuthSource> {
@ -40,6 +70,8 @@ export class OAuthSourceForm extends WithCapabilitiesConfig(BaseSourceForm<OAuth
return source;
}
_modelName?: string;
@property()
modelName?: string;
@ -299,6 +331,35 @@ export class OAuthSourceForm extends WithCapabilitiesConfig(BaseSourceForm<OAuth
</option>
</select>
</ak-form-element-horizontal>
<ak-form-element-horizontal
label=${msg("Group matching mode")}
?required=${true}
name="groupMatchingMode"
>
<select class="pf-c-form-control">
<option
value=${GroupMatchingModeEnum.Identifier}
?selected=${this.instance?.groupMatchingMode ===
GroupMatchingModeEnum.Identifier}
>
${UserMatchingModeToLabel(UserMatchingModeEnum.Identifier)}
</option>
<option
value=${GroupMatchingModeEnum.NameLink}
?selected=${this.instance?.groupMatchingMode ===
GroupMatchingModeEnum.NameLink}
>
${GroupMatchingModeToLabel(GroupMatchingModeEnum.NameLink)}
</option>
<option
value=${GroupMatchingModeEnum.NameDeny}
?selected=${this.instance?.groupMatchingMode ===
GroupMatchingModeEnum.NameDeny}
>
${GroupMatchingModeToLabel(GroupMatchingModeEnum.NameDeny)}
</option>
</select>
</ak-form-element-horizontal>
<ak-form-element-horizontal label=${msg("User path")} name="userPathTemplate">
<input
type="text"
@ -397,6 +458,43 @@ export class OAuthSourceForm extends WithCapabilitiesConfig(BaseSourceForm<OAuth
</div>
</ak-form-group>
${this.renderUrlOptions()}
<ak-form-group ?expanded=${true}>
<span slot="header"> ${msg("OAuth Attribute mapping")} </span>
<div slot="body" class="pf-c-form">
<ak-form-element-horizontal
label=${msg("User Property Mappings")}
name="userPropertyMappings"
>
<ak-dual-select-dynamic-selected
.provider=${propertyMappingsProvider}
.selector=${makePropertyMappingsSelector(
this.instance?.userPropertyMappings,
)}
available-label="${msg("Available User Property Mappings")}"
selected-label="${msg("Selected User Property Mappings")}"
></ak-dual-select-dynamic-selected>
<p class="pf-c-form__helper-text">
${msg("Property mappings for user creation.")}
</p>
</ak-form-element-horizontal>
<ak-form-element-horizontal
label=${msg("Group Property Mappings")}
name="groupPropertyMappings"
>
<ak-dual-select-dynamic-selected
.provider=${propertyMappingsProvider}
.selector=${makePropertyMappingsSelector(
this.instance?.groupPropertyMappings,
)}
available-label="${msg("Available Group Property Mappings")}"
selected-label="${msg("Selected Group Property Mappings")}"
></ak-dual-select-dynamic-selected>
<p class="pf-c-form__helper-text">
${msg("Property mappings for group creation.")}
</p>
</ak-form-element-horizontal>
</div>
</ak-form-group>
<ak-form-group>
<span slot="header"> ${msg("Flow settings")} </span>
<div slot="body" class="pf-c-form">

View File

@ -1,6 +1,6 @@
import { msg } from "@lit/localize";
import { UserMatchingModeEnum } from "@goauthentik/api";
import { GroupMatchingModeEnum, UserMatchingModeEnum } from "@goauthentik/api";
export function UserMatchingModeToLabel(mode?: UserMatchingModeEnum): string {
if (!mode) return "";
@ -27,3 +27,19 @@ export function UserMatchingModeToLabel(mode?: UserMatchingModeEnum): string {
return msg("Unknown user matching mode");
}
}
export function GroupMatchingModeToLabel(mode?: GroupMatchingModeEnum): string {
if (!mode) return "";
switch (mode) {
case GroupMatchingModeEnum.Identifier:
return msg("Link users on unique identifier");
case GroupMatchingModeEnum.NameLink:
return msg(
"Link to a group with identical name. Can have security implications when a group is used with another source",
);
case GroupMatchingModeEnum.NameDeny:
return msg("Use the group's name, but deny enrollment when the name already exists");
case UserMatchingModeEnum.UnknownDefaultOpenApi:
return msg("Unknown user matching mode");
}
}

View File

@ -2,13 +2,18 @@ import "@goauthentik/admin/common/ak-crypto-certificate-search";
import "@goauthentik/admin/common/ak-flow-search/ak-source-flow-search";
import { iconHelperText, placeholderHelperText } from "@goauthentik/admin/helperText";
import { BaseSourceForm } from "@goauthentik/admin/sources/BaseSourceForm";
import { UserMatchingModeToLabel } from "@goauthentik/admin/sources/oauth/utils";
import {
GroupMatchingModeToLabel,
UserMatchingModeToLabel,
} from "@goauthentik/admin/sources/oauth/utils";
import { DEFAULT_CONFIG, config } from "@goauthentik/common/api/config";
import { first } from "@goauthentik/common/utils";
import {
CapabilitiesEnum,
WithCapabilitiesConfig,
} from "@goauthentik/elements/Interface/capabilitiesProvider";
import "@goauthentik/elements/ak-dual-select/ak-dual-select-dynamic-selected-provider.js";
import { DualSelectPair } from "@goauthentik/elements/ak-dual-select/types.js";
import "@goauthentik/elements/forms/FormGroup";
import "@goauthentik/elements/forms/HorizontalFormElement";
import "@goauthentik/elements/forms/Radio";
@ -23,13 +28,38 @@ import {
BindingTypeEnum,
DigestAlgorithmEnum,
FlowsInstancesListDesignationEnum,
GroupMatchingModeEnum,
NameIdPolicyEnum,
PropertymappingsApi,
SAMLSource,
SAMLSourcePropertyMapping,
SignatureAlgorithmEnum,
SourcesApi,
UserMatchingModeEnum,
} from "@goauthentik/api";
async function propertyMappingsProvider(page = 1, search = "") {
const propertyMappings = await new PropertymappingsApi(
DEFAULT_CONFIG,
).propertymappingsSourceSamlList({
ordering: "managed",
pageSize: 20,
search: search.trim(),
page,
});
return {
pagination: propertyMappings.pagination,
options: propertyMappings.results.map((m) => [m.pk, m.name, m.name, m]),
};
}
function makePropertyMappingsSelector(instanceMappings?: string[]) {
const localMappings = instanceMappings ? new Set(instanceMappings) : undefined;
return localMappings
? ([pk, _]: DualSelectPair) => localMappings.has(pk)
: ([_0, _1, _2, _]: DualSelectPair<SAMLSourcePropertyMapping>) => false;
}
@customElement("ak-source-saml-form")
export class SAMLSourceForm extends WithCapabilitiesConfig(BaseSourceForm<SAMLSource>) {
@state()
@ -151,6 +181,35 @@ export class SAMLSourceForm extends WithCapabilitiesConfig(BaseSourceForm<SAMLSo
</option>
</select>
</ak-form-element-horizontal>
<ak-form-element-horizontal
label=${msg("Group matching mode")}
?required=${true}
name="groupMatchingMode"
>
<select class="pf-c-form-control">
<option
value=${GroupMatchingModeEnum.Identifier}
?selected=${this.instance?.groupMatchingMode ===
GroupMatchingModeEnum.Identifier}
>
${UserMatchingModeToLabel(UserMatchingModeEnum.Identifier)}
</option>
<option
value=${GroupMatchingModeEnum.NameLink}
?selected=${this.instance?.groupMatchingMode ===
GroupMatchingModeEnum.NameLink}
>
${GroupMatchingModeToLabel(GroupMatchingModeEnum.NameLink)}
</option>
<option
value=${GroupMatchingModeEnum.NameDeny}
?selected=${this.instance?.groupMatchingMode ===
GroupMatchingModeEnum.NameDeny}
>
${GroupMatchingModeToLabel(GroupMatchingModeEnum.NameDeny)}
</option>
</select>
</ak-form-element-horizontal>
${this.can(CapabilitiesEnum.CanSaveMedia)
? html`<ak-form-element-horizontal label=${msg("Icon")} name="icon">
<input type="file" value="" class="pf-c-form-control" />
@ -451,6 +510,43 @@ export class SAMLSourceForm extends WithCapabilitiesConfig(BaseSourceForm<SAMLSo
</ak-form-element-horizontal>
</div>
</ak-form-group>
<ak-form-group ?expanded=${true}>
<span slot="header"> ${msg("SAML Attribute mapping")} </span>
<div slot="body" class="pf-c-form">
<ak-form-element-horizontal
label=${msg("User Property Mappings")}
name="userPropertyMappings"
>
<ak-dual-select-dynamic-selected
.provider=${propertyMappingsProvider}
.selector=${makePropertyMappingsSelector(
this.instance?.userPropertyMappings,
)}
available-label="${msg("Available User Property Mappings")}"
selected-label="${msg("Selected User Property Mappings")}"
></ak-dual-select-dynamic-selected>
<p class="pf-c-form__helper-text">
${msg("Property mappings for user creation.")}
</p>
</ak-form-element-horizontal>
<ak-form-element-horizontal
label=${msg("Group Property Mappings")}
name="groupPropertyMappings"
>
<ak-dual-select-dynamic-selected
.provider=${propertyMappingsProvider}
.selector=${makePropertyMappingsSelector(
this.instance?.groupPropertyMappings,
)}
available-label="${msg("Available Group Property Mappings")}"
selected-label="${msg("Selected Group Property Mappings")}"
></ak-dual-select-dynamic-selected>
<p class="pf-c-form__helper-text">
${msg("Property mappings for group creation.")}
</p>
</ak-form-element-horizontal>
</div>
</ak-form-group>
<ak-form-group>
<span slot="header"> ${msg("Flow settings")} </span>
<div slot="body" class="pf-c-form">