Compare commits
92 Commits
benchmarks
...
version/20
Author | SHA1 | Date | |
---|---|---|---|
5afceaa55f | |||
72dc27f1c9 | |||
b5ffd16861 | |||
8af754e88c | |||
ade1f08c89 | |||
9240fa1037 | |||
1f5953b5b7 | |||
5befccc1fd | |||
ff193d809a | |||
23bbb6e5ef | |||
225d02d02d | |||
90fe1eda66 | |||
35ba88a203 | |||
8414a9dcad | |||
1d626f5b57 | |||
508dd0ac64 | |||
f4b82a8b09 | |||
2900f01976 | |||
0f6ece5eb7 | |||
b9936fe532 | |||
d0b3cc5916 | |||
e034f5e5dc | |||
9d6816bbc8 | |||
82d4ea9e8a | |||
c8a804f2a7 | |||
ca70c963e5 | |||
4c89d4a4a4 | |||
8a47acac3a | |||
4a3b22491c | |||
f991d656c7 | |||
e86aa11131 | |||
03725ae086 | |||
f2a37e8c7c | |||
e935690b1b | |||
02709e4ede | |||
f78adab9d1 | |||
61f3a72fd9 | |||
541becfe30 | |||
11ff7955f7 | |||
afa4234036 | |||
ca22a4deaf | |||
7b7a3d34ec | |||
b1ca579397 | |||
c8072579c8 | |||
378a701fb9 | |||
bba793d94c | |||
53f8699deb | |||
6f3dc2eafd | |||
567ed07fe8 | |||
2999e9d006 | |||
b32a228e3a | |||
5a2dfb23c6 | |||
8ebce479bd | |||
81589e835e | |||
22b1f39b91 | |||
c25e982f1f | |||
d5c09fae8a | |||
bf15e04053 | |||
0932622567 | |||
0a5b8bea5d | |||
64d4a19ccf | |||
82875cfc0e | |||
83776b9f08 | |||
a742331484 | |||
2e9df96a62 | |||
9f5d7089c3 | |||
ddc78cc297 | |||
cb9b3407d8 | |||
d7b872c1e0 | |||
c35217f581 | |||
3b73a2eb9d | |||
3b94ffa705 | |||
936102f6d9 | |||
8c687d81aa | |||
01d7263484 | |||
49ac0eb662 | |||
8935ca65a7 | |||
58a374d1f1 | |||
f409831921 | |||
951acb26dd | |||
2df0c95806 | |||
f8d1b7b9b7 | |||
e092aabb21 | |||
48c59a815d | |||
9f40716a87 | |||
39da241298 | |||
a71a87fa3e | |||
176fe2f6fc | |||
4544f475c9 | |||
5bbf59b2bd | |||
1b2f1db711 | |||
14fab991b4 |
@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 2024.2.2
|
||||
current_version = 2024.4.3
|
||||
tag = True
|
||||
commit = True
|
||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))?
|
||||
@ -21,6 +21,8 @@ optional_value = final
|
||||
|
||||
[bumpversion:file:schema.yml]
|
||||
|
||||
[bumpversion:file:blueprints/schema.json]
|
||||
|
||||
[bumpversion:file:authentik/__init__.py]
|
||||
|
||||
[bumpversion:file:internal/constants/constants.go]
|
||||
|
@ -12,7 +12,7 @@ should_build = str(os.environ.get("DOCKER_USERNAME", None) is not None).lower()
|
||||
branch_name = os.environ["GITHUB_REF"]
|
||||
if os.environ.get("GITHUB_HEAD_REF", "") != "":
|
||||
branch_name = os.environ["GITHUB_HEAD_REF"]
|
||||
safe_branch_name = branch_name.replace("refs/heads/", "").replace("/", "-")
|
||||
safe_branch_name = branch_name.replace("refs/heads/", "").replace("/", "-").replace("'", "-")
|
||||
|
||||
image_names = os.getenv("IMAGE_NAME").split(",")
|
||||
image_arch = os.getenv("IMAGE_ARCH") or None
|
||||
|
8
.github/workflows/ci-web.yml
vendored
8
.github/workflows/ci-web.yml
vendored
@ -34,6 +34,13 @@ jobs:
|
||||
- name: Eslint
|
||||
working-directory: ${{ matrix.project }}/
|
||||
run: npm run lint
|
||||
lint-lockfile:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- working-directory: web/
|
||||
run: |
|
||||
[ -z "$(jq -r '.packages | to_entries[] | select((.key | startswith("node_modules")) and (.value | has("resolved") | not)) | .key' < package-lock.json)" ]
|
||||
lint-build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
@ -95,6 +102,7 @@ jobs:
|
||||
run: npm run lit-analyse
|
||||
ci-web-mark:
|
||||
needs:
|
||||
- lint-lockfile
|
||||
- lint-eslint
|
||||
- lint-prettier
|
||||
- lint-lit-analyse
|
||||
|
8
.github/workflows/ci-website.yml
vendored
8
.github/workflows/ci-website.yml
vendored
@ -12,6 +12,13 @@ on:
|
||||
- version-*
|
||||
|
||||
jobs:
|
||||
lint-lockfile:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- working-directory: website/
|
||||
run: |
|
||||
[ -z "$(jq -r '.packages | to_entries[] | select((.key | startswith("node_modules")) and (.value | has("resolved") | not)) | .key' < package-lock.json)" ]
|
||||
lint-prettier:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
@ -62,6 +69,7 @@ jobs:
|
||||
run: npm run ${{ matrix.job }}
|
||||
ci-website-mark:
|
||||
needs:
|
||||
- lint-lockfile
|
||||
- lint-prettier
|
||||
- test
|
||||
- build
|
||||
|
20
SECURITY.md
20
SECURITY.md
@ -18,10 +18,10 @@ Even if the issue is not a CVE, we still greatly appreciate your help in hardeni
|
||||
|
||||
(.x being the latest patch release for each version)
|
||||
|
||||
| Version | Supported |
|
||||
| --- | --- |
|
||||
| 2023.6.x | ✅ |
|
||||
| 2023.8.x | ✅ |
|
||||
| Version | Supported |
|
||||
| --------- | --------- |
|
||||
| 2023.10.x | ✅ |
|
||||
| 2024.2.x | ✅ |
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
@ -31,12 +31,12 @@ To report a vulnerability, send an email to [security@goauthentik.io](mailto:se
|
||||
|
||||
authentik reserves the right to reclassify CVSS as necessary. To determine severity, we will use the CVSS calculator from NVD (https://nvd.nist.gov/vuln-metrics/cvss/v3-calculator). The calculated CVSS score will then be translated into one of the following categories:
|
||||
|
||||
| Score | Severity |
|
||||
| --- | --- |
|
||||
| 0.0 | None |
|
||||
| 0.1 – 3.9 | Low |
|
||||
| 4.0 – 6.9 | Medium |
|
||||
| 7.0 – 8.9 | High |
|
||||
| Score | Severity |
|
||||
| ---------- | -------- |
|
||||
| 0.0 | None |
|
||||
| 0.1 – 3.9 | Low |
|
||||
| 4.0 – 6.9 | Medium |
|
||||
| 7.0 – 8.9 | High |
|
||||
| 9.0 – 10.0 | Critical |
|
||||
|
||||
## Disclosure process
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
from os import environ
|
||||
|
||||
__version__ = "2024.2.2"
|
||||
__version__ = "2024.4.3"
|
||||
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
|
||||
|
||||
|
||||
|
@ -0,0 +1,21 @@
|
||||
# Generated by Django 5.0.4 on 2024-04-18 18:56
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_brands", "0005_tenantuuid_to_branduuid"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddIndex(
|
||||
model_name="brand",
|
||||
index=models.Index(fields=["domain"], name="authentik_b_domain_b9b24a_idx"),
|
||||
),
|
||||
migrations.AddIndex(
|
||||
model_name="brand",
|
||||
index=models.Index(fields=["default"], name="authentik_b_default_3ccf12_idx"),
|
||||
),
|
||||
]
|
@ -84,3 +84,7 @@ class Brand(SerializerModel):
|
||||
class Meta:
|
||||
verbose_name = _("Brand")
|
||||
verbose_name_plural = _("Brands")
|
||||
indexes = [
|
||||
models.Index(fields=["domain"]),
|
||||
models.Index(fields=["default"]),
|
||||
]
|
||||
|
@ -154,12 +154,18 @@ class GroupViewSet(UsedByMixin, ModelViewSet):
|
||||
|
||||
pk = IntegerField(required=True)
|
||||
|
||||
queryset = Group.objects.all().select_related("parent").prefetch_related("users")
|
||||
queryset = Group.objects.none()
|
||||
serializer_class = GroupSerializer
|
||||
search_fields = ["name", "is_superuser"]
|
||||
filterset_class = GroupFilter
|
||||
ordering = ["name"]
|
||||
|
||||
def get_queryset(self):
|
||||
base_qs = Group.objects.all().select_related("parent").prefetch_related("roles")
|
||||
if self.serializer_class(context={"request": self.request})._should_include_users:
|
||||
base_qs = base_qs.prefetch_related("users")
|
||||
return base_qs
|
||||
|
||||
@extend_schema(
|
||||
parameters=[
|
||||
OpenApiParameter("include_users", bool, default=True),
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from django.utils.timezone import now
|
||||
from django_filters.rest_framework import DjangoFilterBackend
|
||||
from drf_spectacular.utils import OpenApiResponse, extend_schema, inline_serializer
|
||||
from guardian.shortcuts import assign_perm, get_anonymous_user
|
||||
@ -27,7 +28,6 @@ from authentik.core.models import (
|
||||
TokenIntents,
|
||||
User,
|
||||
default_token_duration,
|
||||
token_expires_from_timedelta,
|
||||
)
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.events.utils import model_to_dict
|
||||
@ -45,6 +45,13 @@ class TokenSerializer(ManagedSerializer, ModelSerializer):
|
||||
if SERIALIZER_CONTEXT_BLUEPRINT in self.context:
|
||||
self.fields["key"] = CharField(required=False)
|
||||
|
||||
def validate_user(self, user: User):
|
||||
"""Ensure user of token cannot be changed"""
|
||||
if self.instance and self.instance.user_id:
|
||||
if user.pk != self.instance.user_id:
|
||||
raise ValidationError("User cannot be changed")
|
||||
return user
|
||||
|
||||
def validate(self, attrs: dict[Any, str]) -> dict[Any, str]:
|
||||
"""Ensure only API or App password tokens are created."""
|
||||
request: Request = self.context.get("request")
|
||||
@ -68,15 +75,17 @@ class TokenSerializer(ManagedSerializer, ModelSerializer):
|
||||
max_token_lifetime_dt = default_token_duration()
|
||||
if max_token_lifetime is not None:
|
||||
try:
|
||||
max_token_lifetime_dt = timedelta_from_string(max_token_lifetime)
|
||||
max_token_lifetime_dt = now() + timedelta_from_string(max_token_lifetime)
|
||||
except ValueError:
|
||||
max_token_lifetime_dt = default_token_duration()
|
||||
pass
|
||||
|
||||
if "expires" in attrs and attrs.get("expires") > token_expires_from_timedelta(
|
||||
max_token_lifetime_dt
|
||||
):
|
||||
if "expires" in attrs and attrs.get("expires") > max_token_lifetime_dt:
|
||||
raise ValidationError(
|
||||
{"expires": f"Token expires exceeds maximum lifetime ({max_token_lifetime})."}
|
||||
{
|
||||
"expires": (
|
||||
f"Token expires exceeds maximum lifetime ({max_token_lifetime_dt} UTC)."
|
||||
)
|
||||
}
|
||||
)
|
||||
elif attrs.get("intent") == TokenIntents.INTENT_API:
|
||||
# For API tokens, expires cannot be overridden
|
||||
|
@ -407,8 +407,11 @@ class UserViewSet(UsedByMixin, ModelViewSet):
|
||||
search_fields = ["username", "name", "is_active", "email", "uuid"]
|
||||
filterset_class = UsersFilter
|
||||
|
||||
def get_queryset(self): # pragma: no cover
|
||||
return User.objects.all().exclude_anonymous().prefetch_related("ak_groups")
|
||||
def get_queryset(self):
|
||||
base_qs = User.objects.all().exclude_anonymous()
|
||||
if self.serializer_class(context={"request": self.request})._should_include_groups:
|
||||
base_qs = base_qs.prefetch_related("ak_groups")
|
||||
return base_qs
|
||||
|
||||
@extend_schema(
|
||||
parameters=[
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""authentik core models"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
from hashlib import sha256
|
||||
from typing import Any, Optional, Self
|
||||
from uuid import uuid4
|
||||
@ -54,9 +54,6 @@ options.DEFAULT_NAMES = options.DEFAULT_NAMES + (
|
||||
# used_by API that allows models to specify if they shadow an object
|
||||
# for example the proxy provider which is built on top of an oauth provider
|
||||
"authentik_used_by_shadows",
|
||||
# List fields for which changes are not logged (due to them having dedicated objects)
|
||||
# for example user's password and last_login
|
||||
"authentik_signals_ignored_fields",
|
||||
)
|
||||
|
||||
|
||||
@ -71,11 +68,6 @@ def default_token_duration() -> datetime:
|
||||
return now() + timedelta_from_string(token_duration)
|
||||
|
||||
|
||||
def token_expires_from_timedelta(dt: timedelta) -> datetime:
|
||||
"""Return a `datetime.datetime` object with the duration of the Token"""
|
||||
return now() + dt
|
||||
|
||||
|
||||
def default_token_key() -> str:
|
||||
"""Default token key"""
|
||||
current_tenant = get_current_tenant()
|
||||
@ -335,14 +327,6 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser):
|
||||
models.Index(fields=["path"]),
|
||||
models.Index(fields=["type"]),
|
||||
]
|
||||
authentik_signals_ignored_fields = [
|
||||
# Logged by the events `password_set`
|
||||
# the `password_set` action/signal doesn't currently convey which user
|
||||
# initiated the password change, so for now we'll log two actions
|
||||
# ("password", "password_change_date"),
|
||||
# Logged by `login`
|
||||
("last_login",),
|
||||
]
|
||||
|
||||
|
||||
class Provider(SerializerModel):
|
||||
@ -648,7 +632,7 @@ class UserSourceConnection(SerializerModel, CreatedUpdatedModel):
|
||||
raise NotImplementedError
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"User-source connection (user={self.user.username}, source={self.source.slug})"
|
||||
return f"User-source connection (user={self.user_id}, source={self.source_id})"
|
||||
|
||||
class Meta:
|
||||
unique_together = (("user", "source"),)
|
||||
|
@ -13,7 +13,7 @@ from django.utils.translation import gettext as _
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.models import Source, SourceUserMatchingModes, User, UserSourceConnection
|
||||
from authentik.core.sources.stage import PLAN_CONTEXT_SOURCES_CONNECTION, PostUserEnrollmentStage
|
||||
from authentik.core.sources.stage import PLAN_CONTEXT_SOURCES_CONNECTION, PostSourceStage
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.flows.exceptions import FlowNonApplicableException
|
||||
from authentik.flows.models import Flow, FlowToken, Stage, in_memory_stage
|
||||
@ -100,8 +100,6 @@ class SourceFlowManager:
|
||||
if self.request.user.is_authenticated:
|
||||
new_connection.user = self.request.user
|
||||
new_connection = self.update_connection(new_connection, **kwargs)
|
||||
|
||||
new_connection.save()
|
||||
return Action.LINK, new_connection
|
||||
|
||||
existing_connections = self.connection_type.objects.filter(
|
||||
@ -148,7 +146,6 @@ class SourceFlowManager:
|
||||
]:
|
||||
new_connection.user = user
|
||||
new_connection = self.update_connection(new_connection, **kwargs)
|
||||
new_connection.save()
|
||||
return Action.LINK, new_connection
|
||||
if self.source.user_matching_mode in [
|
||||
SourceUserMatchingModes.EMAIL_DENY,
|
||||
@ -209,13 +206,9 @@ class SourceFlowManager:
|
||||
|
||||
def get_stages_to_append(self, flow: Flow) -> list[Stage]:
|
||||
"""Hook to override stages which are appended to the flow"""
|
||||
if not self.source.enrollment_flow:
|
||||
return []
|
||||
if flow.slug == self.source.enrollment_flow.slug:
|
||||
return [
|
||||
in_memory_stage(PostUserEnrollmentStage),
|
||||
]
|
||||
return []
|
||||
return [
|
||||
in_memory_stage(PostSourceStage),
|
||||
]
|
||||
|
||||
def _prepare_flow(
|
||||
self,
|
||||
@ -269,6 +262,9 @@ class SourceFlowManager:
|
||||
)
|
||||
# We run the Flow planner here so we can pass the Pending user in the context
|
||||
planner = FlowPlanner(flow)
|
||||
# We append some stages so the initial flow we get might be empty
|
||||
planner.allow_empty_flows = True
|
||||
planner.use_cache = False
|
||||
plan = planner.plan(self.request, kwargs)
|
||||
for stage in self.get_stages_to_append(flow):
|
||||
plan.append_stage(stage)
|
||||
@ -327,7 +323,7 @@ class SourceFlowManager:
|
||||
reverse(
|
||||
"authentik_core:if-user",
|
||||
)
|
||||
+ f"#/settings;page-{self.source.slug}"
|
||||
+ "#/settings;page-sources"
|
||||
)
|
||||
|
||||
def handle_enroll(
|
||||
|
@ -10,7 +10,7 @@ from authentik.flows.stage import StageView
|
||||
PLAN_CONTEXT_SOURCES_CONNECTION = "goauthentik.io/sources/connection"
|
||||
|
||||
|
||||
class PostUserEnrollmentStage(StageView):
|
||||
class PostSourceStage(StageView):
|
||||
"""Dynamically injected stage which saves the Connection after
|
||||
the user has been enrolled."""
|
||||
|
||||
@ -21,10 +21,12 @@ class PostUserEnrollmentStage(StageView):
|
||||
]
|
||||
user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER]
|
||||
connection.user = user
|
||||
linked = connection.pk is None
|
||||
connection.save()
|
||||
Event.new(
|
||||
EventAction.SOURCE_LINKED,
|
||||
message="Linked Source",
|
||||
source=connection.source,
|
||||
).from_http(self.request)
|
||||
if linked:
|
||||
Event.new(
|
||||
EventAction.SOURCE_LINKED,
|
||||
message="Linked Source",
|
||||
source=connection.source,
|
||||
).from_http(self.request)
|
||||
return self.executor.stage_ok()
|
||||
|
@ -2,7 +2,9 @@
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from django.conf import ImproperlyConfigured
|
||||
from django.contrib.sessions.backends.cache import KEY_PREFIX
|
||||
from django.contrib.sessions.backends.db import SessionStore as DBSessionStore
|
||||
from django.core.cache import cache
|
||||
from django.utils.timezone import now
|
||||
from structlog.stdlib import get_logger
|
||||
@ -15,6 +17,7 @@ from authentik.core.models import (
|
||||
User,
|
||||
)
|
||||
from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.root.celery import CELERY_APP
|
||||
|
||||
LOGGER = get_logger()
|
||||
@ -39,16 +42,31 @@ def clean_expired_models(self: SystemTask):
|
||||
amount = 0
|
||||
|
||||
for session in AuthenticatedSession.objects.all():
|
||||
cache_key = f"{KEY_PREFIX}{session.session_key}"
|
||||
value = None
|
||||
try:
|
||||
value = cache.get(cache_key)
|
||||
match CONFIG.get("session_storage", "cache"):
|
||||
case "cache":
|
||||
cache_key = f"{KEY_PREFIX}{session.session_key}"
|
||||
value = None
|
||||
try:
|
||||
value = cache.get(cache_key)
|
||||
|
||||
except Exception as exc:
|
||||
LOGGER.debug("Failed to get session from cache", exc=exc)
|
||||
if not value:
|
||||
session.delete()
|
||||
amount += 1
|
||||
except Exception as exc:
|
||||
LOGGER.debug("Failed to get session from cache", exc=exc)
|
||||
if not value:
|
||||
session.delete()
|
||||
amount += 1
|
||||
case "db":
|
||||
if not (
|
||||
DBSessionStore.get_model_class()
|
||||
.objects.filter(session_key=session.session_key, expire_date__gt=now())
|
||||
.exists()
|
||||
):
|
||||
session.delete()
|
||||
amount += 1
|
||||
case _:
|
||||
# Should never happen, as we check for other values in authentik/root/settings.py
|
||||
raise ImproperlyConfigured(
|
||||
"Invalid session_storage setting, allowed values are db and cache"
|
||||
)
|
||||
LOGGER.debug("Expired sessions", model=AuthenticatedSession, amount=amount)
|
||||
|
||||
messages.append(f"Expired {amount} {AuthenticatedSession._meta.verbose_name_plural}")
|
||||
|
@ -5,7 +5,7 @@ from guardian.shortcuts import assign_perm
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.core.tests.utils import create_test_user
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_user
|
||||
from authentik.lib.generators import generate_id
|
||||
|
||||
|
||||
@ -16,6 +16,13 @@ class TestGroupsAPI(APITestCase):
|
||||
self.login_user = create_test_user()
|
||||
self.user = User.objects.create(username="test-user")
|
||||
|
||||
def test_list_with_users(self):
|
||||
"""Test listing with users"""
|
||||
admin = create_test_admin_user()
|
||||
self.client.force_login(admin)
|
||||
response = self.client.get(reverse("authentik_api:group-list"), {"include_users": "true"})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_add_user(self):
|
||||
"""Test add_user"""
|
||||
group = Group.objects.create(name=generate_id())
|
||||
|
@ -2,11 +2,15 @@
|
||||
|
||||
from django.contrib.auth.models import AnonymousUser
|
||||
from django.test import TestCase
|
||||
from django.urls import reverse
|
||||
from guardian.utils import get_anonymous_user
|
||||
|
||||
from authentik.core.models import SourceUserMatchingModes, User
|
||||
from authentik.core.sources.flow_manager import Action
|
||||
from authentik.core.sources.stage import PostSourceStage
|
||||
from authentik.core.tests.utils import create_test_flow
|
||||
from authentik.flows.planner import FlowPlan
|
||||
from authentik.flows.views.executor import SESSION_KEY_PLAN
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.lib.tests.utils import get_request
|
||||
from authentik.policies.denied import AccessDeniedResponse
|
||||
@ -21,42 +25,62 @@ class TestSourceFlowManager(TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.source: OAuthSource = OAuthSource.objects.create(name="test")
|
||||
self.authentication_flow = create_test_flow()
|
||||
self.enrollment_flow = create_test_flow()
|
||||
self.source: OAuthSource = OAuthSource.objects.create(
|
||||
name=generate_id(),
|
||||
slug=generate_id(),
|
||||
authentication_flow=self.authentication_flow,
|
||||
enrollment_flow=self.enrollment_flow,
|
||||
)
|
||||
self.identifier = generate_id()
|
||||
|
||||
def test_unauthenticated_enroll(self):
|
||||
"""Test un-authenticated user enrolling"""
|
||||
flow_manager = OAuthSourceFlowManager(
|
||||
self.source, get_request("/", user=AnonymousUser()), self.identifier, {}
|
||||
)
|
||||
request = get_request("/", user=AnonymousUser())
|
||||
flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {})
|
||||
action, _ = flow_manager.get_action()
|
||||
self.assertEqual(action, Action.ENROLL)
|
||||
flow_manager.get_flow()
|
||||
response = flow_manager.get_flow()
|
||||
self.assertEqual(response.status_code, 302)
|
||||
flow_plan: FlowPlan = request.session[SESSION_KEY_PLAN]
|
||||
self.assertEqual(flow_plan.bindings[0].stage.view, PostSourceStage)
|
||||
|
||||
def test_unauthenticated_auth(self):
|
||||
"""Test un-authenticated user authenticating"""
|
||||
UserOAuthSourceConnection.objects.create(
|
||||
user=get_anonymous_user(), source=self.source, identifier=self.identifier
|
||||
)
|
||||
|
||||
flow_manager = OAuthSourceFlowManager(
|
||||
self.source, get_request("/", user=AnonymousUser()), self.identifier, {}
|
||||
)
|
||||
request = get_request("/", user=AnonymousUser())
|
||||
flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {})
|
||||
action, _ = flow_manager.get_action()
|
||||
self.assertEqual(action, Action.AUTH)
|
||||
flow_manager.get_flow()
|
||||
response = flow_manager.get_flow()
|
||||
self.assertEqual(response.status_code, 302)
|
||||
flow_plan: FlowPlan = request.session[SESSION_KEY_PLAN]
|
||||
self.assertEqual(flow_plan.bindings[0].stage.view, PostSourceStage)
|
||||
|
||||
def test_authenticated_link(self):
|
||||
"""Test authenticated user linking"""
|
||||
UserOAuthSourceConnection.objects.create(
|
||||
user=get_anonymous_user(), source=self.source, identifier=self.identifier
|
||||
)
|
||||
user = User.objects.create(username="foo", email="foo@bar.baz")
|
||||
flow_manager = OAuthSourceFlowManager(
|
||||
self.source, get_request("/", user=user), self.identifier, {}
|
||||
)
|
||||
action, _ = flow_manager.get_action()
|
||||
request = get_request("/", user=user)
|
||||
flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {})
|
||||
action, connection = flow_manager.get_action()
|
||||
self.assertEqual(action, Action.LINK)
|
||||
self.assertIsNone(connection.pk)
|
||||
response = flow_manager.get_flow()
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(
|
||||
response.url,
|
||||
reverse("authentik_core:if-user") + "#/settings;page-sources",
|
||||
)
|
||||
|
||||
def test_unauthenticated_link(self):
|
||||
"""Test un-authenticated user linking"""
|
||||
flow_manager = OAuthSourceFlowManager(self.source, get_request("/"), self.identifier, {})
|
||||
action, connection = flow_manager.get_action()
|
||||
self.assertEqual(action, Action.LINK)
|
||||
self.assertIsNone(connection.pk)
|
||||
flow_manager.get_flow()
|
||||
|
||||
def test_unauthenticated_enroll_email(self):
|
||||
|
@ -13,9 +13,8 @@ from authentik.core.models import (
|
||||
USER_ATTRIBUTE_TOKEN_MAXIMUM_LIFETIME,
|
||||
Token,
|
||||
TokenIntents,
|
||||
User,
|
||||
)
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_user
|
||||
from authentik.lib.generators import generate_id
|
||||
|
||||
|
||||
@ -24,7 +23,7 @@ class TestTokenAPI(APITestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.user = User.objects.create(username="testuser")
|
||||
self.user = create_test_user()
|
||||
self.admin = create_test_admin_user()
|
||||
self.client.force_login(self.user)
|
||||
|
||||
@ -154,6 +153,24 @@ class TestTokenAPI(APITestCase):
|
||||
self.assertEqual(token.expiring, True)
|
||||
self.assertNotEqual(token.expires.timestamp(), expires.timestamp())
|
||||
|
||||
def test_token_change_user(self):
|
||||
"""Test creating a token and then changing the user"""
|
||||
ident = generate_id()
|
||||
response = self.client.post(reverse("authentik_api:token-list"), {"identifier": ident})
|
||||
self.assertEqual(response.status_code, 201)
|
||||
token = Token.objects.get(identifier=ident)
|
||||
self.assertEqual(token.user, self.user)
|
||||
self.assertEqual(token.intent, TokenIntents.INTENT_API)
|
||||
self.assertEqual(token.expiring, True)
|
||||
self.assertTrue(self.user.has_perm("authentik_core.view_token_key", token))
|
||||
response = self.client.put(
|
||||
reverse("authentik_api:token-detail", kwargs={"identifier": ident}),
|
||||
data={"identifier": "user_token_poc_v3", "intent": "api", "user": self.admin.pk},
|
||||
)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
token.refresh_from_db()
|
||||
self.assertEqual(token.user, self.user)
|
||||
|
||||
def test_list(self):
|
||||
"""Test Token List (Test normal authentication)"""
|
||||
Token.objects.all().delete()
|
||||
|
@ -41,6 +41,12 @@ class TestUsersAPI(APITestCase):
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_list_with_groups(self):
|
||||
"""Test listing with groups"""
|
||||
self.client.force_login(self.admin)
|
||||
response = self.client.get(reverse("authentik_api:user-list"), {"include_groups": "true"})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_metrics(self):
|
||||
"""Test user's metrics"""
|
||||
self.client.force_login(self.admin)
|
||||
|
@ -8,7 +8,6 @@ from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.models import User
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.tenants.utils import get_current_tenant
|
||||
|
||||
|
||||
@ -25,7 +24,6 @@ class TestUsersAvatars(APITestCase):
|
||||
tenant.avatars = mode
|
||||
tenant.save()
|
||||
|
||||
@CONFIG.patch("avatars", "none")
|
||||
def test_avatars_none(self):
|
||||
"""Test avatars none"""
|
||||
self.set_avatar_mode("none")
|
||||
|
@ -4,7 +4,7 @@ from django.utils.text import slugify
|
||||
|
||||
from authentik.brands.models import Brand
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.crypto.builder import CertificateBuilder
|
||||
from authentik.crypto.builder import CertificateBuilder, PrivateKeyAlg
|
||||
from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.flows.models import Flow, FlowDesignation
|
||||
from authentik.lib.generators import generate_id
|
||||
@ -50,12 +50,10 @@ def create_test_brand(**kwargs) -> Brand:
|
||||
return Brand.objects.create(domain=uid, default=True, **kwargs)
|
||||
|
||||
|
||||
def create_test_cert(use_ec_private_key=False) -> CertificateKeyPair:
|
||||
def create_test_cert(alg=PrivateKeyAlg.RSA) -> CertificateKeyPair:
|
||||
"""Generate a certificate for testing"""
|
||||
builder = CertificateBuilder(
|
||||
name=f"{generate_id()}.self-signed.goauthentik.io",
|
||||
use_ec_private_key=use_ec_private_key,
|
||||
)
|
||||
builder = CertificateBuilder(f"{generate_id()}.self-signed.goauthentik.io")
|
||||
builder.alg = alg
|
||||
builder.build(
|
||||
subject_alt_names=[f"{generate_id()}.self-signed.goauthentik.io"],
|
||||
validity_days=360,
|
||||
|
@ -14,7 +14,13 @@ from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.fields import CharField, DateTimeField, IntegerField, SerializerMethodField
|
||||
from rest_framework.fields import (
|
||||
CharField,
|
||||
ChoiceField,
|
||||
DateTimeField,
|
||||
IntegerField,
|
||||
SerializerMethodField,
|
||||
)
|
||||
from rest_framework.filters import OrderingFilter, SearchFilter
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
@ -26,7 +32,7 @@ from authentik.api.authorization import SecretKeyFilter
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.crypto.apps import MANAGED_KEY
|
||||
from authentik.crypto.builder import CertificateBuilder
|
||||
from authentik.crypto.builder import CertificateBuilder, PrivateKeyAlg
|
||||
from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.rbac.decorators import permission_required
|
||||
@ -178,6 +184,7 @@ class CertificateGenerationSerializer(PassiveSerializer):
|
||||
common_name = CharField()
|
||||
subject_alt_name = CharField(required=False, allow_blank=True, label=_("Subject-alt name"))
|
||||
validity_days = IntegerField(initial=365)
|
||||
alg = ChoiceField(default=PrivateKeyAlg.RSA, choices=PrivateKeyAlg.choices)
|
||||
|
||||
|
||||
class CertificateKeyPairFilter(FilterSet):
|
||||
@ -240,6 +247,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
|
||||
raw_san = data.validated_data.get("subject_alt_name", "")
|
||||
sans = raw_san.split(",") if raw_san != "" else []
|
||||
builder = CertificateBuilder(data.validated_data["common_name"])
|
||||
builder.alg = data.validated_data["alg"]
|
||||
builder.build(
|
||||
subject_alt_names=sans,
|
||||
validity_days=int(data.validated_data["validity_days"]),
|
||||
|
@ -9,20 +9,28 @@ from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ec, rsa
|
||||
from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
|
||||
from cryptography.x509.oid import NameOID
|
||||
from django.db import models
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from authentik import __version__
|
||||
from authentik.crypto.models import CertificateKeyPair
|
||||
|
||||
|
||||
class PrivateKeyAlg(models.TextChoices):
|
||||
"""Algorithm to create private key with"""
|
||||
|
||||
RSA = "rsa", _("rsa")
|
||||
ECDSA = "ecdsa", _("ecdsa")
|
||||
|
||||
|
||||
class CertificateBuilder:
|
||||
"""Build self-signed certificates"""
|
||||
|
||||
common_name: str
|
||||
alg: PrivateKeyAlg
|
||||
|
||||
_use_ec_private_key: bool
|
||||
|
||||
def __init__(self, name: str, use_ec_private_key=False):
|
||||
self._use_ec_private_key = use_ec_private_key
|
||||
def __init__(self, name: str):
|
||||
self.alg = PrivateKeyAlg.RSA
|
||||
self.__public_key = None
|
||||
self.__private_key = None
|
||||
self.__builder = None
|
||||
@ -42,11 +50,13 @@ class CertificateBuilder:
|
||||
|
||||
def generate_private_key(self) -> PrivateKeyTypes:
|
||||
"""Generate private key"""
|
||||
if self._use_ec_private_key:
|
||||
if self.alg == PrivateKeyAlg.ECDSA:
|
||||
return ec.generate_private_key(curve=ec.SECP256R1())
|
||||
return rsa.generate_private_key(
|
||||
public_exponent=65537, key_size=4096, backend=default_backend()
|
||||
)
|
||||
if self.alg == PrivateKeyAlg.RSA:
|
||||
return rsa.generate_private_key(
|
||||
public_exponent=65537, key_size=4096, backend=default_backend()
|
||||
)
|
||||
raise ValueError(f"Invalid alg: {self.alg}")
|
||||
|
||||
def build(
|
||||
self,
|
||||
|
@ -2,11 +2,12 @@
|
||||
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
from django.apps.registry import apps
|
||||
from django.core.files import File
|
||||
from django.db import connection
|
||||
from django.db.models import Model
|
||||
from django.db.models import ManyToManyRel, Model
|
||||
from django.db.models.expressions import BaseExpression, Combinable
|
||||
from django.db.models.signals import post_init
|
||||
from django.http import HttpRequest
|
||||
@ -44,7 +45,7 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
|
||||
post_init.disconnect(dispatch_uid=request.request_id)
|
||||
|
||||
def serialize_simple(self, model: Model) -> dict:
|
||||
"""Serialize a model in a very simple way. No ForeginKeys or other relationships are
|
||||
"""Serialize a model in a very simple way. No ForeignKeys or other relationships are
|
||||
resolved"""
|
||||
data = {}
|
||||
deferred_fields = model.get_deferred_fields()
|
||||
@ -70,6 +71,9 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
|
||||
for key, value in before.items():
|
||||
if after.get(key) != value:
|
||||
diff[key] = {"previous_value": value, "new_value": after.get(key)}
|
||||
for key, value in after.items():
|
||||
if key not in before and key not in diff and before.get(key) != value:
|
||||
diff[key] = {"previous_value": before.get(key), "new_value": value}
|
||||
return sanitize_item(diff)
|
||||
|
||||
def post_init_handler(self, request: HttpRequest, sender, instance: Model, **_):
|
||||
@ -98,13 +102,37 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
|
||||
thread_kwargs = {}
|
||||
if hasattr(instance, "_previous_state") or created:
|
||||
prev_state = getattr(instance, "_previous_state", {})
|
||||
if created:
|
||||
prev_state = {}
|
||||
# Get current state
|
||||
new_state = self.serialize_simple(instance)
|
||||
diff = self.diff(prev_state, new_state)
|
||||
thread_kwargs["diff"] = diff
|
||||
if not created:
|
||||
ignored_field_sets = getattr(instance._meta, "authentik_signals_ignored_fields", [])
|
||||
for field_set in ignored_field_sets:
|
||||
if set(diff.keys()) == set(field_set):
|
||||
return None
|
||||
return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_)
|
||||
|
||||
def m2m_changed_handler( # noqa: PLR0913
|
||||
self,
|
||||
request: HttpRequest,
|
||||
sender,
|
||||
instance: Model,
|
||||
action: str,
|
||||
pk_set: set[Any],
|
||||
thread_kwargs: dict | None = None,
|
||||
**_,
|
||||
):
|
||||
thread_kwargs = {}
|
||||
m2m_field = None
|
||||
# For the audit log we don't care about `pre_` or `post_` so we trim that part off
|
||||
_, _, action_direction = action.partition("_")
|
||||
# resolve the "through" model to an actual field
|
||||
for field in instance._meta.get_fields():
|
||||
if not isinstance(field, ManyToManyRel):
|
||||
continue
|
||||
if field.through == sender:
|
||||
m2m_field = field
|
||||
if m2m_field:
|
||||
# If we're clearing we just set the "flag" to True
|
||||
if action_direction == "clear":
|
||||
pk_set = True
|
||||
thread_kwargs["diff"] = {m2m_field.related_name: {action_direction: pk_set}}
|
||||
return super().m2m_changed_handler(request, sender, instance, action, thread_kwargs)
|
||||
|
@ -1,9 +1,22 @@
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
from django.apps import apps
|
||||
from django.conf import settings
|
||||
from django.test import TestCase
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.events.utils import sanitize_item
|
||||
from authentik.lib.generators import generate_id
|
||||
|
||||
|
||||
class TestEnterpriseAudit(TestCase):
|
||||
class TestEnterpriseAudit(APITestCase):
|
||||
"""Test audit middleware"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.user = create_test_admin_user()
|
||||
|
||||
def test_import(self):
|
||||
"""Ensure middleware is imported when app.ready is called"""
|
||||
@ -16,3 +29,182 @@ class TestEnterpriseAudit(TestCase):
|
||||
self.assertIn(
|
||||
"authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware", settings.MIDDLEWARE
|
||||
)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
|
||||
PropertyMock(return_value=True),
|
||||
)
|
||||
def test_create(self):
|
||||
"""Test create audit log"""
|
||||
self.client.force_login(self.user)
|
||||
username = generate_id()
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:user-list"),
|
||||
data={"name": generate_id(), "username": username, "groups": [], "path": "foo"},
|
||||
)
|
||||
user = User.objects.get(username=username)
|
||||
self.assertEqual(response.status_code, 201)
|
||||
events = Event.objects.filter(
|
||||
action=EventAction.MODEL_CREATED,
|
||||
context__model__model_name="user",
|
||||
context__model__app="authentik_core",
|
||||
context__model__pk=user.pk,
|
||||
)
|
||||
event = events.first()
|
||||
self.assertIsNotNone(event)
|
||||
self.assertIsNotNone(event.context["diff"])
|
||||
diff = event.context["diff"]
|
||||
self.assertEqual(
|
||||
diff,
|
||||
{
|
||||
"name": {
|
||||
"new_value": user.name,
|
||||
"previous_value": None,
|
||||
},
|
||||
"path": {"new_value": "foo", "previous_value": None},
|
||||
"type": {"new_value": "internal", "previous_value": None},
|
||||
"uuid": {
|
||||
"new_value": user.uuid.hex,
|
||||
"previous_value": None,
|
||||
},
|
||||
"email": {"new_value": "", "previous_value": None},
|
||||
"username": {
|
||||
"new_value": user.username,
|
||||
"previous_value": None,
|
||||
},
|
||||
"is_active": {"new_value": True, "previous_value": None},
|
||||
"attributes": {"new_value": {}, "previous_value": None},
|
||||
"date_joined": {
|
||||
"new_value": sanitize_item(user.date_joined),
|
||||
"previous_value": None,
|
||||
},
|
||||
"first_name": {"new_value": "", "previous_value": None},
|
||||
"id": {"new_value": user.pk, "previous_value": None},
|
||||
"last_name": {"new_value": "", "previous_value": None},
|
||||
"password": {"new_value": "********************", "previous_value": None},
|
||||
"password_change_date": {
|
||||
"new_value": sanitize_item(user.password_change_date),
|
||||
"previous_value": None,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
|
||||
PropertyMock(return_value=True),
|
||||
)
|
||||
def test_update(self):
|
||||
"""Test update audit log"""
|
||||
self.client.force_login(self.user)
|
||||
user = create_test_admin_user()
|
||||
current_name = user.name
|
||||
new_name = generate_id()
|
||||
response = self.client.patch(
|
||||
reverse("authentik_api:user-detail", kwargs={"pk": user.id}),
|
||||
data={"name": new_name},
|
||||
)
|
||||
user.refresh_from_db()
|
||||
self.assertEqual(response.status_code, 200)
|
||||
events = Event.objects.filter(
|
||||
action=EventAction.MODEL_UPDATED,
|
||||
context__model__model_name="user",
|
||||
context__model__app="authentik_core",
|
||||
context__model__pk=user.pk,
|
||||
)
|
||||
event = events.first()
|
||||
self.assertIsNotNone(event)
|
||||
self.assertIsNotNone(event.context["diff"])
|
||||
diff = event.context["diff"]
|
||||
self.assertEqual(
|
||||
diff,
|
||||
{
|
||||
"name": {
|
||||
"new_value": new_name,
|
||||
"previous_value": current_name,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
|
||||
PropertyMock(return_value=True),
|
||||
)
|
||||
def test_delete(self):
|
||||
"""Test delete audit log"""
|
||||
self.client.force_login(self.user)
|
||||
user = create_test_admin_user()
|
||||
response = self.client.delete(
|
||||
reverse("authentik_api:user-detail", kwargs={"pk": user.id}),
|
||||
)
|
||||
self.assertEqual(response.status_code, 204)
|
||||
events = Event.objects.filter(
|
||||
action=EventAction.MODEL_DELETED,
|
||||
context__model__model_name="user",
|
||||
context__model__app="authentik_core",
|
||||
context__model__pk=user.pk,
|
||||
)
|
||||
event = events.first()
|
||||
self.assertIsNotNone(event)
|
||||
self.assertNotIn("diff", event.context)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
|
||||
PropertyMock(return_value=True),
|
||||
)
|
||||
def test_m2m_add(self):
|
||||
"""Test m2m add audit log"""
|
||||
self.client.force_login(self.user)
|
||||
user = create_test_admin_user()
|
||||
group = Group.objects.create(name=generate_id())
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:group-add-user", kwargs={"pk": group.group_uuid}),
|
||||
data={
|
||||
"pk": user.pk,
|
||||
},
|
||||
)
|
||||
self.assertEqual(response.status_code, 204)
|
||||
events = Event.objects.filter(
|
||||
action=EventAction.MODEL_UPDATED,
|
||||
context__model__model_name="group",
|
||||
context__model__app="authentik_core",
|
||||
context__model__pk=group.pk.hex,
|
||||
)
|
||||
event = events.first()
|
||||
self.assertIsNotNone(event)
|
||||
self.assertIsNotNone(event.context["diff"])
|
||||
diff = event.context["diff"]
|
||||
self.assertEqual(
|
||||
diff,
|
||||
{"users": {"add": [user.pk]}},
|
||||
)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
|
||||
PropertyMock(return_value=True),
|
||||
)
|
||||
def test_m2m_remove(self):
|
||||
"""Test m2m remove audit log"""
|
||||
self.client.force_login(self.user)
|
||||
user = create_test_admin_user()
|
||||
group = Group.objects.create(name=generate_id())
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:group-remove-user", kwargs={"pk": group.group_uuid}),
|
||||
data={
|
||||
"pk": user.pk,
|
||||
},
|
||||
)
|
||||
self.assertEqual(response.status_code, 204)
|
||||
events = Event.objects.filter(
|
||||
action=EventAction.MODEL_UPDATED,
|
||||
context__model__model_name="group",
|
||||
context__model__app="authentik_core",
|
||||
context__model__pk=group.pk.hex,
|
||||
)
|
||||
event = events.first()
|
||||
self.assertIsNotNone(event)
|
||||
self.assertIsNotNone(event.context["diff"])
|
||||
diff = event.context["diff"]
|
||||
self.assertEqual(
|
||||
diff,
|
||||
{"users": {"remove": [user.pk]}},
|
||||
)
|
||||
|
@ -201,10 +201,7 @@ class ConnectionToken(ExpiringModel):
|
||||
return settings
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"RAC Connection token {self.session.user} to "
|
||||
f"{self.endpoint.provider.name}/{self.endpoint.name}"
|
||||
)
|
||||
return f"RAC Connection token {self.session_id} to {self.provider_id}/{self.endpoint_id}"
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("RAC Connection token")
|
||||
|
@ -116,12 +116,12 @@ class AuditMiddleware:
|
||||
return user
|
||||
user = getattr(request, "user", self.anonymous_user)
|
||||
if not user.is_authenticated:
|
||||
self._ensure_fallback_user()
|
||||
return self.anonymous_user
|
||||
return user
|
||||
|
||||
def connect(self, request: HttpRequest):
|
||||
"""Connect signal for automatic logging"""
|
||||
self._ensure_fallback_user()
|
||||
if not hasattr(request, "request_id"):
|
||||
return
|
||||
post_save.connect(
|
||||
@ -214,7 +214,15 @@ class AuditMiddleware:
|
||||
model=model_to_dict(instance),
|
||||
).run()
|
||||
|
||||
def m2m_changed_handler(self, request: HttpRequest, sender, instance: Model, action: str, **_):
|
||||
def m2m_changed_handler(
|
||||
self,
|
||||
request: HttpRequest,
|
||||
sender,
|
||||
instance: Model,
|
||||
action: str,
|
||||
thread_kwargs: dict | None = None,
|
||||
**_,
|
||||
):
|
||||
"""Signal handler for all object's m2m_changed"""
|
||||
if action not in ["pre_add", "pre_remove", "post_clear"]:
|
||||
return
|
||||
@ -229,4 +237,5 @@ class AuditMiddleware:
|
||||
request,
|
||||
user=user,
|
||||
model=model_to_dict(instance),
|
||||
**thread_kwargs,
|
||||
).run()
|
||||
|
@ -556,7 +556,7 @@ class Notification(SerializerModel):
|
||||
if len(self.body) > NOTIFICATION_SUMMARY_LENGTH
|
||||
else self.body
|
||||
)
|
||||
return f"Notification for user {self.user}: {body_trunc}"
|
||||
return f"Notification for user {self.user_id}: {body_trunc}"
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("Notification")
|
||||
|
@ -119,7 +119,7 @@ class SystemTask(TenantTask):
|
||||
"task_call_kwargs": sanitize_item(kwargs),
|
||||
"status": self._status,
|
||||
"messages": sanitize_item(self._messages),
|
||||
"expires": now() + timedelta(hours=self.result_timeout_hours),
|
||||
"expires": now() + timedelta(hours=self.result_timeout_hours + 3),
|
||||
"expiring": True,
|
||||
},
|
||||
)
|
||||
|
35
authentik/events/tests/test_models.py
Normal file
35
authentik/events/tests/test_models.py
Normal file
@ -0,0 +1,35 @@
|
||||
"""authentik event models tests"""
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from django.db.models import Model
|
||||
from django.test import TestCase
|
||||
|
||||
from authentik.core.models import default_token_key
|
||||
from authentik.lib.utils.reflection import get_apps
|
||||
|
||||
|
||||
class TestModels(TestCase):
|
||||
"""Test Models"""
|
||||
|
||||
|
||||
def model_tester_factory(test_model: type[Model]) -> Callable:
|
||||
"""Test models' __str__ and __repr__"""
|
||||
|
||||
def tester(self: TestModels):
|
||||
allowed = 0
|
||||
# Token-like objects need to lookup the current tenant to get the default token length
|
||||
for field in test_model._meta.fields:
|
||||
if field.default == default_token_key:
|
||||
allowed += 1
|
||||
with self.assertNumQueries(allowed):
|
||||
str(test_model())
|
||||
with self.assertNumQueries(allowed):
|
||||
repr(test_model())
|
||||
|
||||
return tester
|
||||
|
||||
|
||||
for app in get_apps():
|
||||
for model in app.get_models():
|
||||
setattr(TestModels, f"test_{app.label}_{model.__name__}", model_tester_factory(model))
|
@ -278,7 +278,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet):
|
||||
},
|
||||
)
|
||||
@action(detail=True, pagination_class=None, filter_backends=[])
|
||||
def execute(self, request: Request, _slug: str):
|
||||
def execute(self, request: Request, slug: str):
|
||||
"""Execute flow for current user"""
|
||||
# Because we pre-plan the flow here, and not in the planner, we need to manually clear
|
||||
# the history of the inspector
|
||||
|
@ -203,7 +203,8 @@ class FlowPlanner:
|
||||
"f(plan): building plan",
|
||||
)
|
||||
plan = self._build_plan(user, request, default_context)
|
||||
cache.set(cache_key(self.flow, user), plan, CACHE_TIMEOUT)
|
||||
if self.use_cache:
|
||||
cache.set(cache_key(self.flow, user), plan, CACHE_TIMEOUT)
|
||||
if not plan.bindings and not self.allow_empty_flows:
|
||||
raise EmptyFlowException()
|
||||
return plan
|
||||
|
@ -6,6 +6,7 @@ from rest_framework.test import APITestCase
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.flows.api.stages import StageSerializer, StageViewSet
|
||||
from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding, Stage
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.policies.dummy.models import DummyPolicy
|
||||
from authentik.policies.models import PolicyBinding
|
||||
from authentik.stages.dummy.models import DummyStage
|
||||
@ -101,3 +102,21 @@ class TestFlowsAPI(APITestCase):
|
||||
reverse("authentik_api:stage-types"),
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_execute(self):
|
||||
"""Test execute endpoint"""
|
||||
user = create_test_admin_user()
|
||||
self.client.force_login(user)
|
||||
|
||||
flow = Flow.objects.create(
|
||||
name=generate_id(),
|
||||
slug=generate_id(),
|
||||
designation=FlowDesignation.AUTHENTICATION,
|
||||
)
|
||||
FlowStageBinding.objects.create(
|
||||
target=flow, stage=DummyStage.objects.create(name=generate_id()), order=0
|
||||
)
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:flow-execute", kwargs={"slug": flow.slug})
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
@ -14,7 +14,7 @@ from pathlib import Path
|
||||
from sys import argv, stderr
|
||||
from time import time
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
from urllib.parse import quote_plus, urlparse
|
||||
|
||||
import yaml
|
||||
from django.conf import ImproperlyConfigured
|
||||
@ -331,6 +331,26 @@ class ConfigLoader:
|
||||
CONFIG = ConfigLoader()
|
||||
|
||||
|
||||
def redis_url(db: int) -> str:
|
||||
"""Helper to create a Redis URL for a specific database"""
|
||||
_redis_protocol_prefix = "redis://"
|
||||
_redis_tls_requirements = ""
|
||||
if CONFIG.get_bool("redis.tls", False):
|
||||
_redis_protocol_prefix = "rediss://"
|
||||
_redis_tls_requirements = f"?ssl_cert_reqs={CONFIG.get('redis.tls_reqs')}"
|
||||
if _redis_ca := CONFIG.get("redis.tls_ca_cert", None):
|
||||
_redis_tls_requirements += f"&ssl_ca_certs={_redis_ca}"
|
||||
_redis_url = (
|
||||
f"{_redis_protocol_prefix}"
|
||||
f"{quote_plus(CONFIG.get('redis.username'))}:"
|
||||
f"{quote_plus(CONFIG.get('redis.password'))}@"
|
||||
f"{quote_plus(CONFIG.get('redis.host'))}:"
|
||||
f"{CONFIG.get_int('redis.port')}"
|
||||
f"/{db}{_redis_tls_requirements}"
|
||||
)
|
||||
return _redis_url
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(argv) < 2: # noqa: PLR2004
|
||||
print(dumps(CONFIG.raw, indent=4, cls=AttrEncoder))
|
||||
|
@ -35,6 +35,7 @@ redis:
|
||||
password: ""
|
||||
tls: false
|
||||
tls_reqs: "none"
|
||||
tls_ca_cert: null
|
||||
|
||||
# broker:
|
||||
# url: ""
|
||||
@ -52,12 +53,15 @@ cache:
|
||||
|
||||
# result_backend:
|
||||
# url: ""
|
||||
# transport_options: ""
|
||||
|
||||
debug: false
|
||||
remote_debug: false
|
||||
|
||||
log_level: info
|
||||
|
||||
session_storage: cache
|
||||
|
||||
error_reporting:
|
||||
enabled: false
|
||||
sentry_dsn: https://151ba72610234c4c97c5bcff4e1cffd8@authentik.error-reporting.a7k.io/4504163677503489
|
||||
|
@ -326,7 +326,7 @@ class AuthorizationCode(SerializerModel, ExpiringModel, BaseGrantModel):
|
||||
verbose_name_plural = _("Authorization Codes")
|
||||
|
||||
def __str__(self):
|
||||
return f"Authorization code for {self.provider} for user {self.user}"
|
||||
return f"Authorization code for {self.provider_id} for user {self.user_id}"
|
||||
|
||||
@property
|
||||
def serializer(self) -> Serializer:
|
||||
@ -356,7 +356,7 @@ class AccessToken(SerializerModel, ExpiringModel, BaseGrantModel):
|
||||
verbose_name_plural = _("OAuth2 Access Tokens")
|
||||
|
||||
def __str__(self):
|
||||
return f"Access Token for {self.provider} for user {self.user}"
|
||||
return f"Access Token for {self.provider_id} for user {self.user_id}"
|
||||
|
||||
@property
|
||||
def id_token(self) -> IDToken:
|
||||
@ -399,7 +399,7 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel):
|
||||
verbose_name_plural = _("OAuth2 Refresh Tokens")
|
||||
|
||||
def __str__(self):
|
||||
return f"Refresh Token for {self.provider} for user {self.user}"
|
||||
return f"Refresh Token for {self.provider_id} for user {self.user_id}"
|
||||
|
||||
@property
|
||||
def id_token(self) -> IDToken:
|
||||
@ -443,4 +443,4 @@ class DeviceToken(ExpiringModel):
|
||||
verbose_name_plural = _("Device Tokens")
|
||||
|
||||
def __str__(self):
|
||||
return f"Device Token for {self.provider}"
|
||||
return f"Device Token for {self.provider_id}"
|
||||
|
@ -4,9 +4,10 @@ from urllib.parse import urlencode
|
||||
|
||||
from django.urls import reverse
|
||||
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.models import Application, Group
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_brand, create_test_flow
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.policies.models import PolicyBinding
|
||||
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider
|
||||
from authentik.providers.oauth2.tests.utils import OAuthTestCase
|
||||
from authentik.providers.oauth2.views.device_init import QS_KEY_CODE
|
||||
@ -77,3 +78,23 @@ class TesOAuth2DeviceInit(OAuthTestCase):
|
||||
+ "?"
|
||||
+ urlencode({QS_KEY_CODE: token.user_code}),
|
||||
)
|
||||
|
||||
def test_device_init_denied(self):
|
||||
"""Test device init"""
|
||||
group = Group.objects.create(name="foo")
|
||||
PolicyBinding.objects.create(
|
||||
group=group,
|
||||
target=self.application,
|
||||
order=0,
|
||||
)
|
||||
token = DeviceToken.objects.create(
|
||||
user_code="foo",
|
||||
provider=self.provider,
|
||||
)
|
||||
res = self.client.get(
|
||||
reverse("authentik_providers_oauth2_root:device-login")
|
||||
+ "?"
|
||||
+ urlencode({QS_KEY_CODE: token.user_code})
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertIn(b"Permission denied", res.content)
|
||||
|
@ -10,6 +10,7 @@ from jwt import PyJWKSet
|
||||
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_cert, create_test_flow
|
||||
from authentik.crypto.builder import PrivateKeyAlg
|
||||
from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.providers.oauth2.models import OAuth2Provider
|
||||
@ -82,7 +83,7 @@ class TestJWKS(OAuthTestCase):
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris="http://local.invalid",
|
||||
signing_key=create_test_cert(use_ec_private_key=True),
|
||||
signing_key=create_test_cert(PrivateKeyAlg.ECDSA),
|
||||
)
|
||||
app = Application.objects.create(name="test", slug="test", provider=provider)
|
||||
response = self.client.get(
|
||||
|
@ -11,10 +11,11 @@ from django.views.decorators.csrf import csrf_exempt
|
||||
from rest_framework.throttling import AnonRateThrottle
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.models import Application
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider
|
||||
from authentik.providers.oauth2.views.device_init import QS_KEY_CODE, get_application
|
||||
from authentik.providers.oauth2.views.device_init import QS_KEY_CODE
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
@ -37,7 +38,9 @@ class DeviceView(View):
|
||||
).first()
|
||||
if not provider:
|
||||
return HttpResponseBadRequest()
|
||||
if not get_application(provider):
|
||||
try:
|
||||
_ = provider.application
|
||||
except Application.DoesNotExist:
|
||||
return HttpResponseBadRequest()
|
||||
self.provider = provider
|
||||
self.client_id = client_id
|
||||
|
@ -1,8 +1,9 @@
|
||||
"""Device flow views"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from django.utils.translation import gettext as _
|
||||
from django.views import View
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.fields import CharField, IntegerField
|
||||
from structlog.stdlib import get_logger
|
||||
@ -16,7 +17,8 @@ from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_SSO,
|
||||
from authentik.flows.stage import ChallengeStageView
|
||||
from authentik.flows.views.executor import SESSION_KEY_PLAN
|
||||
from authentik.lib.utils.urls import redirect_with_qs
|
||||
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider
|
||||
from authentik.policies.views import PolicyAccessView
|
||||
from authentik.providers.oauth2.models import DeviceToken
|
||||
from authentik.providers.oauth2.views.device_finish import (
|
||||
PLAN_CONTEXT_DEVICE,
|
||||
OAuthDeviceCodeFinishStage,
|
||||
@ -31,60 +33,52 @@ LOGGER = get_logger()
|
||||
QS_KEY_CODE = "code" # nosec
|
||||
|
||||
|
||||
def get_application(provider: OAuth2Provider) -> Application | None:
|
||||
"""Get application from provider"""
|
||||
try:
|
||||
app = provider.application
|
||||
if not app:
|
||||
class CodeValidatorView(PolicyAccessView):
|
||||
"""Helper to validate frontside token"""
|
||||
|
||||
def __init__(self, code: str, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.code = code
|
||||
|
||||
def resolve_provider_application(self):
|
||||
self.token = DeviceToken.objects.filter(user_code=self.code).first()
|
||||
if not self.token:
|
||||
raise Application.DoesNotExist
|
||||
self.provider = self.token.provider
|
||||
self.application = self.token.provider.application
|
||||
|
||||
def get(self, request: HttpRequest, *args, **kwargs):
|
||||
scope_descriptions = UserInfoView().get_scope_descriptions(self.token.scope, self.provider)
|
||||
planner = FlowPlanner(self.provider.authorization_flow)
|
||||
planner.allow_empty_flows = True
|
||||
planner.use_cache = False
|
||||
try:
|
||||
plan = planner.plan(
|
||||
request,
|
||||
{
|
||||
PLAN_CONTEXT_SSO: True,
|
||||
PLAN_CONTEXT_APPLICATION: self.application,
|
||||
# OAuth2 related params
|
||||
PLAN_CONTEXT_DEVICE: self.token,
|
||||
# Consent related params
|
||||
PLAN_CONTEXT_CONSENT_HEADER: _("You're about to sign into %(application)s.")
|
||||
% {"application": self.application.name},
|
||||
PLAN_CONTEXT_CONSENT_PERMISSIONS: scope_descriptions,
|
||||
},
|
||||
)
|
||||
except FlowNonApplicableException:
|
||||
LOGGER.warning("Flow not applicable to user")
|
||||
return None
|
||||
return app
|
||||
except Application.DoesNotExist:
|
||||
return None
|
||||
|
||||
|
||||
def validate_code(code: int, request: HttpRequest) -> HttpResponse | None:
|
||||
"""Validate user token"""
|
||||
token = DeviceToken.objects.filter(
|
||||
user_code=code,
|
||||
).first()
|
||||
if not token:
|
||||
return None
|
||||
|
||||
app = get_application(token.provider)
|
||||
if not app:
|
||||
return None
|
||||
|
||||
scope_descriptions = UserInfoView().get_scope_descriptions(token.scope, token.provider)
|
||||
planner = FlowPlanner(token.provider.authorization_flow)
|
||||
planner.allow_empty_flows = True
|
||||
planner.use_cache = False
|
||||
try:
|
||||
plan = planner.plan(
|
||||
request,
|
||||
{
|
||||
PLAN_CONTEXT_SSO: True,
|
||||
PLAN_CONTEXT_APPLICATION: app,
|
||||
# OAuth2 related params
|
||||
PLAN_CONTEXT_DEVICE: token,
|
||||
# Consent related params
|
||||
PLAN_CONTEXT_CONSENT_HEADER: _("You're about to sign into %(application)s.")
|
||||
% {"application": app.name},
|
||||
PLAN_CONTEXT_CONSENT_PERMISSIONS: scope_descriptions,
|
||||
},
|
||||
plan.insert_stage(in_memory_stage(OAuthDeviceCodeFinishStage))
|
||||
request.session[SESSION_KEY_PLAN] = plan
|
||||
return redirect_with_qs(
|
||||
"authentik_core:if-flow",
|
||||
request.GET,
|
||||
flow_slug=self.token.provider.authorization_flow.slug,
|
||||
)
|
||||
except FlowNonApplicableException:
|
||||
LOGGER.warning("Flow not applicable to user")
|
||||
return None
|
||||
plan.insert_stage(in_memory_stage(OAuthDeviceCodeFinishStage))
|
||||
request.session[SESSION_KEY_PLAN] = plan
|
||||
return redirect_with_qs(
|
||||
"authentik_core:if-flow",
|
||||
request.GET,
|
||||
flow_slug=token.provider.authorization_flow.slug,
|
||||
)
|
||||
|
||||
|
||||
class DeviceEntryView(View):
|
||||
class DeviceEntryView(PolicyAccessView):
|
||||
"""View used to initiate the device-code flow, url entered by endusers"""
|
||||
|
||||
def dispatch(self, request: HttpRequest) -> HttpResponse:
|
||||
@ -94,7 +88,9 @@ class DeviceEntryView(View):
|
||||
LOGGER.info("Brand has no device code flow configured", brand=brand)
|
||||
return HttpResponse(status=404)
|
||||
if QS_KEY_CODE in request.GET:
|
||||
validation = validate_code(request.GET[QS_KEY_CODE], request)
|
||||
validation = CodeValidatorView(request.GET[QS_KEY_CODE], request=request).dispatch(
|
||||
request
|
||||
)
|
||||
if validation:
|
||||
return validation
|
||||
LOGGER.info("Got code from query parameter but no matching token found")
|
||||
@ -131,7 +127,7 @@ class OAuthDeviceCodeChallengeResponse(ChallengeResponse):
|
||||
|
||||
def validate_code(self, code: int) -> HttpResponse | None:
|
||||
"""Validate code and save the returned http response"""
|
||||
response = validate_code(code, self.stage.request)
|
||||
response = CodeValidatorView(code, request=self.stage.request).dispatch(self.stage.request)
|
||||
if not response:
|
||||
raise ValidationError(_("Invalid code"), "invalid")
|
||||
return response
|
||||
|
@ -0,0 +1,44 @@
|
||||
# Generated by Django 5.0.4 on 2024-05-01 15:32
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_providers_saml", "0013_samlprovider_default_relay_state"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="samlprovider",
|
||||
name="digest_algorithm",
|
||||
field=models.TextField(
|
||||
choices=[
|
||||
("http://www.w3.org/2000/09/xmldsig#sha1", "SHA1"),
|
||||
("http://www.w3.org/2001/04/xmlenc#sha256", "SHA256"),
|
||||
("http://www.w3.org/2001/04/xmldsig-more#sha384", "SHA384"),
|
||||
("http://www.w3.org/2001/04/xmlenc#sha512", "SHA512"),
|
||||
],
|
||||
default="http://www.w3.org/2001/04/xmlenc#sha256",
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="samlprovider",
|
||||
name="signature_algorithm",
|
||||
field=models.TextField(
|
||||
choices=[
|
||||
("http://www.w3.org/2000/09/xmldsig#rsa-sha1", "RSA-SHA1"),
|
||||
("http://www.w3.org/2001/04/xmldsig-more#rsa-sha256", "RSA-SHA256"),
|
||||
("http://www.w3.org/2001/04/xmldsig-more#rsa-sha384", "RSA-SHA384"),
|
||||
("http://www.w3.org/2001/04/xmldsig-more#rsa-sha512", "RSA-SHA512"),
|
||||
("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha1", "ECDSA-SHA1"),
|
||||
("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256", "ECDSA-SHA256"),
|
||||
("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha384", "ECDSA-SHA384"),
|
||||
("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha512", "ECDSA-SHA512"),
|
||||
("http://www.w3.org/2000/09/xmldsig#dsa-sha1", "DSA-SHA1"),
|
||||
],
|
||||
default="http://www.w3.org/2001/04/xmldsig-more#rsa-sha256",
|
||||
),
|
||||
),
|
||||
]
|
@ -11,6 +11,10 @@ from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.lib.utils.time import timedelta_string_validator
|
||||
from authentik.sources.saml.processors.constants import (
|
||||
DSA_SHA1,
|
||||
ECDSA_SHA1,
|
||||
ECDSA_SHA256,
|
||||
ECDSA_SHA384,
|
||||
ECDSA_SHA512,
|
||||
RSA_SHA1,
|
||||
RSA_SHA256,
|
||||
RSA_SHA384,
|
||||
@ -92,8 +96,7 @@ class SAMLProvider(Provider):
|
||||
),
|
||||
)
|
||||
|
||||
digest_algorithm = models.CharField(
|
||||
max_length=50,
|
||||
digest_algorithm = models.TextField(
|
||||
choices=(
|
||||
(SHA1, _("SHA1")),
|
||||
(SHA256, _("SHA256")),
|
||||
@ -102,13 +105,16 @@ class SAMLProvider(Provider):
|
||||
),
|
||||
default=SHA256,
|
||||
)
|
||||
signature_algorithm = models.CharField(
|
||||
max_length=50,
|
||||
signature_algorithm = models.TextField(
|
||||
choices=(
|
||||
(RSA_SHA1, _("RSA-SHA1")),
|
||||
(RSA_SHA256, _("RSA-SHA256")),
|
||||
(RSA_SHA384, _("RSA-SHA384")),
|
||||
(RSA_SHA512, _("RSA-SHA512")),
|
||||
(ECDSA_SHA1, _("ECDSA-SHA1")),
|
||||
(ECDSA_SHA256, _("ECDSA-SHA256")),
|
||||
(ECDSA_SHA384, _("ECDSA-SHA384")),
|
||||
(ECDSA_SHA512, _("ECDSA-SHA512")),
|
||||
(DSA_SHA1, _("DSA-SHA1")),
|
||||
),
|
||||
default=RSA_SHA256,
|
||||
|
@ -7,13 +7,14 @@ from lxml import etree # nosec
|
||||
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_cert, create_test_flow
|
||||
from authentik.crypto.builder import PrivateKeyAlg
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.lib.tests.utils import load_fixture
|
||||
from authentik.lib.xml import lxml_from_string
|
||||
from authentik.providers.saml.models import SAMLBindings, SAMLPropertyMapping, SAMLProvider
|
||||
from authentik.providers.saml.processors.metadata import MetadataProcessor
|
||||
from authentik.providers.saml.processors.metadata_parser import ServiceProviderMetadataParser
|
||||
from authentik.sources.saml.processors.constants import NS_MAP, NS_SAML_METADATA
|
||||
from authentik.sources.saml.processors.constants import ECDSA_SHA256, NS_MAP, NS_SAML_METADATA
|
||||
|
||||
|
||||
class TestServiceProviderMetadataParser(TestCase):
|
||||
@ -107,12 +108,41 @@ class TestServiceProviderMetadataParser(TestCase):
|
||||
load_fixture("fixtures/cert.xml").replace("/apps/user_saml", "")
|
||||
)
|
||||
|
||||
def test_signature(self):
|
||||
"""Test signature validation"""
|
||||
def test_signature_rsa(self):
|
||||
"""Test signature validation (RSA)"""
|
||||
provider = SAMLProvider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=self.flow,
|
||||
signing_kp=create_test_cert(),
|
||||
signing_kp=create_test_cert(PrivateKeyAlg.RSA),
|
||||
)
|
||||
Application.objects.create(
|
||||
name=generate_id(),
|
||||
slug=generate_id(),
|
||||
provider=provider,
|
||||
)
|
||||
request = self.factory.get("/")
|
||||
metadata = MetadataProcessor(provider, request).build_entity_descriptor()
|
||||
|
||||
root = fromstring(metadata.encode())
|
||||
xmlsec.tree.add_ids(root, ["ID"])
|
||||
signature_nodes = root.xpath("/md:EntityDescriptor/ds:Signature", namespaces=NS_MAP)
|
||||
signature_node = signature_nodes[0]
|
||||
ctx = xmlsec.SignatureContext()
|
||||
key = xmlsec.Key.from_memory(
|
||||
provider.signing_kp.certificate_data,
|
||||
xmlsec.constants.KeyDataFormatCertPem,
|
||||
None,
|
||||
)
|
||||
ctx.key = key
|
||||
ctx.verify(signature_node)
|
||||
|
||||
def test_signature_ecdsa(self):
|
||||
"""Test signature validation (ECDSA)"""
|
||||
provider = SAMLProvider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=self.flow,
|
||||
signing_kp=create_test_cert(PrivateKeyAlg.ECDSA),
|
||||
signature_algorithm=ECDSA_SHA256,
|
||||
)
|
||||
Application.objects.create(
|
||||
name=generate_id(),
|
||||
|
@ -41,7 +41,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
|
||||
if not scim_group:
|
||||
self.logger.debug("Group does not exist in SCIM, skipping")
|
||||
return None
|
||||
response = self._request("DELETE", f"/Groups/{scim_group.id}")
|
||||
response = self._request("DELETE", f"/Groups/{scim_group.scim_id}")
|
||||
scim_group.delete()
|
||||
return response
|
||||
|
||||
@ -89,7 +89,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
|
||||
for user in connections:
|
||||
members.append(
|
||||
GroupMember(
|
||||
value=user.id,
|
||||
value=user.scim_id,
|
||||
)
|
||||
)
|
||||
if members:
|
||||
@ -107,16 +107,19 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
|
||||
exclude_unset=True,
|
||||
),
|
||||
)
|
||||
SCIMGroup.objects.create(provider=self.provider, group=group, id=response["id"])
|
||||
scim_id = response.get("id")
|
||||
if not scim_id or scim_id == "":
|
||||
raise StopSync("SCIM Response with missing or invalid `id`")
|
||||
SCIMGroup.objects.create(provider=self.provider, group=group, scim_id=scim_id)
|
||||
|
||||
def _update(self, group: Group, connection: SCIMGroup):
|
||||
"""Update existing group"""
|
||||
scim_group = self.to_scim(group)
|
||||
scim_group.id = connection.id
|
||||
scim_group.id = connection.scim_id
|
||||
try:
|
||||
return self._request(
|
||||
"PUT",
|
||||
f"/Groups/{scim_group.id}",
|
||||
f"/Groups/{connection.scim_id}",
|
||||
json=scim_group.model_dump(
|
||||
mode="json",
|
||||
exclude_unset=True,
|
||||
@ -185,13 +188,13 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
|
||||
return
|
||||
user_ids = list(
|
||||
SCIMUser.objects.filter(user__pk__in=users_set, provider=self.provider).values_list(
|
||||
"id", flat=True
|
||||
"scim_id", flat=True
|
||||
)
|
||||
)
|
||||
if len(user_ids) < 1:
|
||||
return
|
||||
self._patch(
|
||||
scim_group.id,
|
||||
scim_group.scim_id,
|
||||
PatchOperation(
|
||||
op=PatchOp.add,
|
||||
path="members",
|
||||
@ -211,13 +214,13 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
|
||||
return
|
||||
user_ids = list(
|
||||
SCIMUser.objects.filter(user__pk__in=users_set, provider=self.provider).values_list(
|
||||
"id", flat=True
|
||||
"scim_id", flat=True
|
||||
)
|
||||
)
|
||||
if len(user_ids) < 1:
|
||||
return
|
||||
self._patch(
|
||||
scim_group.id,
|
||||
scim_group.scim_id,
|
||||
PatchOperation(
|
||||
op=PatchOp.remove,
|
||||
path="members",
|
||||
|
@ -9,13 +9,14 @@ from pydanticscim.service_provider import (
|
||||
)
|
||||
from pydanticscim.user import User as BaseUser
|
||||
|
||||
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
|
||||
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
|
||||
|
||||
|
||||
class User(BaseUser):
|
||||
"""Modified User schema with added externalId field"""
|
||||
|
||||
schemas: list[str] = [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
]
|
||||
schemas: list[str] = [SCIM_USER_SCHEMA]
|
||||
externalId: str | None = None
|
||||
meta: dict | None = None
|
||||
|
||||
@ -23,9 +24,7 @@ class User(BaseUser):
|
||||
class Group(BaseGroup):
|
||||
"""Modified Group schema with added externalId field"""
|
||||
|
||||
schemas: list[str] = [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:Group",
|
||||
]
|
||||
schemas: list[str] = [SCIM_GROUP_SCHEMA]
|
||||
externalId: str | None = None
|
||||
meta: dict | None = None
|
||||
|
||||
|
@ -34,7 +34,7 @@ class SCIMUserClient(SCIMClient[User, SCIMUserSchema]):
|
||||
if not scim_user:
|
||||
self.logger.debug("User does not exist in SCIM, skipping")
|
||||
return None
|
||||
response = self._request("DELETE", f"/Users/{scim_user.id}")
|
||||
response = self._request("DELETE", f"/Users/{scim_user.scim_id}")
|
||||
scim_user.delete()
|
||||
return response
|
||||
|
||||
@ -85,15 +85,18 @@ class SCIMUserClient(SCIMClient[User, SCIMUserSchema]):
|
||||
exclude_unset=True,
|
||||
),
|
||||
)
|
||||
SCIMUser.objects.create(provider=self.provider, user=user, id=response["id"])
|
||||
scim_id = response.get("id")
|
||||
if not scim_id or scim_id == "":
|
||||
raise StopSync("SCIM Response with missing or invalid `id`")
|
||||
SCIMUser.objects.create(provider=self.provider, user=user, scim_id=scim_id)
|
||||
|
||||
def _update(self, user: User, connection: SCIMUser):
|
||||
"""Update existing user"""
|
||||
scim_user = self.to_scim(user)
|
||||
scim_user.id = connection.id
|
||||
scim_user.id = connection.scim_id
|
||||
self._request(
|
||||
"PUT",
|
||||
f"/Users/{connection.id}",
|
||||
f"/Users/{connection.scim_id}",
|
||||
json=scim_user.model_dump(
|
||||
mode="json",
|
||||
exclude_unset=True,
|
||||
|
@ -3,7 +3,7 @@
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.providers.scim.models import SCIMProvider
|
||||
from authentik.providers.scim.tasks import scim_sync
|
||||
from authentik.providers.scim.tasks import scim_task_wrapper
|
||||
from authentik.tenants.management import TenantCommand
|
||||
|
||||
LOGGER = get_logger()
|
||||
@ -21,4 +21,4 @@ class Command(TenantCommand):
|
||||
if not provider:
|
||||
LOGGER.warning("Provider does not exist", name=provider_name)
|
||||
continue
|
||||
scim_sync.delay(provider.pk).get()
|
||||
scim_task_wrapper(provider.pk).get()
|
||||
|
@ -0,0 +1,76 @@
|
||||
# Generated by Django 5.0.4 on 2024-05-03 12:38
|
||||
|
||||
import uuid
|
||||
from django.db import migrations, models
|
||||
from django.apps.registry import Apps
|
||||
|
||||
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
||||
|
||||
from authentik.lib.migrations import progress_bar
|
||||
|
||||
|
||||
def fix_scim_user_group_pk(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||
SCIMUser = apps.get_model("authentik_providers_scim", "SCIMUser")
|
||||
SCIMGroup = apps.get_model("authentik_providers_scim", "SCIMGroup")
|
||||
db_alias = schema_editor.connection.alias
|
||||
print("\nFixing primary key for SCIM users, this might take a couple of minutes...")
|
||||
for user in progress_bar(SCIMUser.objects.using(db_alias).all()):
|
||||
SCIMUser.objects.using(db_alias).filter(
|
||||
pk=user.pk, user=user.user_id, provider=user.provider_id
|
||||
).update(scim_id=user.pk, id=uuid.uuid4())
|
||||
|
||||
print("\nFixing primary key for SCIM groups, this might take a couple of minutes...")
|
||||
for group in progress_bar(SCIMGroup.objects.using(db_alias).all()):
|
||||
SCIMGroup.objects.using(db_alias).filter(
|
||||
pk=group.pk, group=group.group_id, provider=group.provider_id
|
||||
).update(scim_id=group.pk, id=uuid.uuid4())
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
(
|
||||
"authentik_providers_scim",
|
||||
"0001_squashed_0006_rename_parent_group_scimprovider_filter_group",
|
||||
),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="scimgroup",
|
||||
name="scim_id",
|
||||
field=models.TextField(default="temp"),
|
||||
preserve_default=False,
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="scimuser",
|
||||
name="scim_id",
|
||||
field=models.TextField(default="temp"),
|
||||
preserve_default=False,
|
||||
),
|
||||
migrations.RunPython(fix_scim_user_group_pk),
|
||||
migrations.AlterField(
|
||||
model_name="scimgroup",
|
||||
name="id",
|
||||
field=models.UUIDField(
|
||||
default=uuid.uuid4, editable=False, primary_key=True, serialize=False
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="scimuser",
|
||||
name="id",
|
||||
field=models.UUIDField(
|
||||
default=uuid.uuid4, editable=False, primary_key=True, serialize=False
|
||||
),
|
||||
),
|
||||
migrations.AlterField(model_name="scimuser", name="scim_id", field=models.TextField()),
|
||||
migrations.AlterField(model_name="scimgroup", name="scim_id", field=models.TextField()),
|
||||
migrations.AlterUniqueTogether(
|
||||
name="scimgroup",
|
||||
unique_together={("scim_id", "group", "provider")},
|
||||
),
|
||||
migrations.AlterUniqueTogether(
|
||||
name="scimuser",
|
||||
unique_together={("scim_id", "user", "provider")},
|
||||
),
|
||||
]
|
@ -1,5 +1,7 @@
|
||||
"""SCIM Provider models"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.db import models
|
||||
from django.db.models import QuerySet
|
||||
@ -97,26 +99,28 @@ class SCIMMapping(PropertyMapping):
|
||||
class SCIMUser(models.Model):
|
||||
"""Mapping of a user and provider to a SCIM user ID"""
|
||||
|
||||
id = models.TextField(primary_key=True)
|
||||
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
||||
scim_id = models.TextField()
|
||||
user = models.ForeignKey(User, on_delete=models.CASCADE)
|
||||
provider = models.ForeignKey(SCIMProvider, on_delete=models.CASCADE)
|
||||
|
||||
class Meta:
|
||||
unique_together = (("id", "user", "provider"),)
|
||||
unique_together = (("scim_id", "user", "provider"),)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"SCIM User {self.user.username} to {self.provider.name}"
|
||||
return f"SCIM User {self.user_id} to {self.provider_id}"
|
||||
|
||||
|
||||
class SCIMGroup(models.Model):
|
||||
"""Mapping of a group and provider to a SCIM user ID"""
|
||||
|
||||
id = models.TextField(primary_key=True)
|
||||
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
||||
scim_id = models.TextField()
|
||||
group = models.ForeignKey(Group, on_delete=models.CASCADE)
|
||||
provider = models.ForeignKey(SCIMProvider, on_delete=models.CASCADE)
|
||||
|
||||
class Meta:
|
||||
unique_together = (("id", "group", "provider"),)
|
||||
unique_together = (("scim_id", "group", "provider"),)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"SCIM Group {self.group.name} to {self.provider.name}"
|
||||
return f"SCIM Group {self.group_id} to {self.provider_id}"
|
||||
|
@ -9,7 +9,7 @@ from structlog.stdlib import get_logger
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.lib.utils.reflection import class_to_path
|
||||
from authentik.providers.scim.models import SCIMProvider
|
||||
from authentik.providers.scim.tasks import scim_signal_direct, scim_signal_m2m, scim_sync
|
||||
from authentik.providers.scim.tasks import scim_signal_direct, scim_signal_m2m, scim_task_wrapper
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
@ -17,7 +17,7 @@ LOGGER = get_logger()
|
||||
@receiver(post_save, sender=SCIMProvider)
|
||||
def post_save_provider(sender: type[Model], instance, created: bool, **_):
|
||||
"""Trigger sync when SCIM provider is saved"""
|
||||
scim_sync.delay(instance.pk)
|
||||
scim_task_wrapper(instance.pk)
|
||||
|
||||
|
||||
@receiver(post_save, sender=User)
|
||||
|
@ -38,7 +38,23 @@ def client_for_model(provider: SCIMProvider, model: Model) -> SCIMClient:
|
||||
def scim_sync_all():
|
||||
"""Run sync for all providers"""
|
||||
for provider in SCIMProvider.objects.filter(backchannel_application__isnull=False):
|
||||
scim_sync.delay(provider.pk)
|
||||
scim_task_wrapper(provider.pk)
|
||||
|
||||
|
||||
def scim_task_wrapper(provider_pk: int):
|
||||
"""Wrap scim_sync to set the correct timeouts"""
|
||||
provider: SCIMProvider = SCIMProvider.objects.filter(
|
||||
pk=provider_pk, backchannel_application__isnull=False
|
||||
).first()
|
||||
if not provider:
|
||||
return
|
||||
users_paginator = Paginator(provider.get_user_qs(), PAGE_SIZE)
|
||||
groups_paginator = Paginator(provider.get_group_qs(), PAGE_SIZE)
|
||||
soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT
|
||||
time_limit = soft_time_limit * 1.5
|
||||
return scim_sync.apply_async(
|
||||
(provider.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit)
|
||||
)
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
@ -60,7 +76,7 @@ def scim_sync(self: SystemTask, provider_pk: int) -> None:
|
||||
users_paginator = Paginator(provider.get_user_qs(), PAGE_SIZE)
|
||||
groups_paginator = Paginator(provider.get_group_qs(), PAGE_SIZE)
|
||||
self.soft_time_limit = self.time_limit = (
|
||||
users_paginator.count + groups_paginator.count
|
||||
users_paginator.num_pages + groups_paginator.num_pages
|
||||
) * PAGE_TIMEOUT
|
||||
with allow_join_result():
|
||||
try:
|
||||
|
@ -8,7 +8,7 @@ from authentik.core.models import Application, Group, User
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.providers.scim.clients.schema import ServiceProviderConfiguration
|
||||
from authentik.providers.scim.models import SCIMMapping, SCIMProvider
|
||||
from authentik.providers.scim.tasks import scim_sync
|
||||
from authentik.providers.scim.tasks import scim_task_wrapper
|
||||
from authentik.tenants.models import Tenant
|
||||
|
||||
|
||||
@ -79,7 +79,7 @@ class SCIMMembershipTests(TestCase):
|
||||
)
|
||||
|
||||
self.configure()
|
||||
scim_sync.delay(self.provider.pk).get()
|
||||
scim_task_wrapper(self.provider.pk).get()
|
||||
|
||||
self.assertEqual(mocker.call_count, 6)
|
||||
self.assertEqual(mocker.request_history[0].method, "GET")
|
||||
@ -169,7 +169,7 @@ class SCIMMembershipTests(TestCase):
|
||||
)
|
||||
|
||||
self.configure()
|
||||
scim_sync.delay(self.provider.pk).get()
|
||||
scim_task_wrapper(self.provider.pk).get()
|
||||
|
||||
self.assertEqual(mocker.call_count, 6)
|
||||
self.assertEqual(mocker.request_history[0].method, "GET")
|
||||
|
@ -10,7 +10,7 @@ from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.core.models import Application, Group, User
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.providers.scim.models import SCIMMapping, SCIMProvider
|
||||
from authentik.providers.scim.tasks import scim_sync
|
||||
from authentik.providers.scim.tasks import scim_task_wrapper
|
||||
from authentik.tenants.models import Tenant
|
||||
|
||||
|
||||
@ -88,6 +88,72 @@ class SCIMUserTests(TestCase):
|
||||
},
|
||||
)
|
||||
|
||||
@Mocker()
|
||||
def test_user_create_different_provider_same_id(self, mock: Mocker):
|
||||
"""Test user creation with multiple providers that happen
|
||||
to return the same object ID"""
|
||||
# Create duplicate provider
|
||||
provider: SCIMProvider = SCIMProvider.objects.create(
|
||||
name=generate_id(),
|
||||
url="https://localhost",
|
||||
token=generate_id(),
|
||||
exclude_users_service_account=True,
|
||||
)
|
||||
app: Application = Application.objects.create(
|
||||
name=generate_id(),
|
||||
slug=generate_id(),
|
||||
)
|
||||
app.backchannel_providers.add(provider)
|
||||
provider.property_mappings.add(
|
||||
SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/user")
|
||||
)
|
||||
provider.property_mappings_group.add(
|
||||
SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/group")
|
||||
)
|
||||
|
||||
scim_id = generate_id()
|
||||
mock.get(
|
||||
"https://localhost/ServiceProviderConfig",
|
||||
json={},
|
||||
)
|
||||
mock.post(
|
||||
"https://localhost/Users",
|
||||
json={
|
||||
"id": scim_id,
|
||||
},
|
||||
)
|
||||
uid = generate_id()
|
||||
user = User.objects.create(
|
||||
username=uid,
|
||||
name=f"{uid} {uid}",
|
||||
email=f"{uid}@goauthentik.io",
|
||||
)
|
||||
self.assertEqual(mock.call_count, 4)
|
||||
self.assertEqual(mock.request_history[0].method, "GET")
|
||||
self.assertEqual(mock.request_history[1].method, "POST")
|
||||
self.assertJSONEqual(
|
||||
mock.request_history[1].body,
|
||||
{
|
||||
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
|
||||
"active": True,
|
||||
"emails": [
|
||||
{
|
||||
"primary": True,
|
||||
"type": "other",
|
||||
"value": f"{uid}@goauthentik.io",
|
||||
}
|
||||
],
|
||||
"externalId": user.uid,
|
||||
"name": {
|
||||
"familyName": uid,
|
||||
"formatted": f"{uid} {uid}",
|
||||
"givenName": uid,
|
||||
},
|
||||
"displayName": f"{uid} {uid}",
|
||||
"userName": uid,
|
||||
},
|
||||
)
|
||||
|
||||
@Mocker()
|
||||
def test_user_create_update(self, mock: Mocker):
|
||||
"""Test user creation and update"""
|
||||
@ -236,7 +302,7 @@ class SCIMUserTests(TestCase):
|
||||
email=f"{uid}@goauthentik.io",
|
||||
)
|
||||
|
||||
scim_sync.delay(self.provider.pk).get()
|
||||
scim_task_wrapper(self.provider.pk).get()
|
||||
|
||||
self.assertEqual(mock.call_count, 5)
|
||||
self.assertEqual(mock.request_history[0].method, "GET")
|
||||
|
@ -5,13 +5,13 @@ import os
|
||||
from collections import OrderedDict
|
||||
from hashlib import sha512
|
||||
from pathlib import Path
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
from celery.schedules import crontab
|
||||
from django.conf import ImproperlyConfigured
|
||||
from sentry_sdk import set_tag
|
||||
|
||||
from authentik import ENV_GIT_HASH_KEY, __version__
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.config import CONFIG, redis_url
|
||||
from authentik.lib.logging import get_logger_config, structlog_configure
|
||||
from authentik.lib.sentry import sentry_init
|
||||
from authentik.lib.utils.reflection import get_env
|
||||
@ -195,25 +195,15 @@ REST_FRAMEWORK = {
|
||||
},
|
||||
}
|
||||
|
||||
_redis_protocol_prefix = "redis://"
|
||||
_redis_celery_tls_requirements = ""
|
||||
if CONFIG.get_bool("redis.tls", False):
|
||||
_redis_protocol_prefix = "rediss://"
|
||||
_redis_celery_tls_requirements = f"?ssl_cert_reqs={CONFIG.get('redis.tls_reqs')}"
|
||||
_redis_url = (
|
||||
f"{_redis_protocol_prefix}"
|
||||
f"{quote_plus(CONFIG.get('redis.username'))}:"
|
||||
f"{quote_plus(CONFIG.get('redis.password'))}@"
|
||||
f"{quote_plus(CONFIG.get('redis.host'))}:"
|
||||
f"{CONFIG.get_int('redis.port')}"
|
||||
)
|
||||
|
||||
CACHES = {
|
||||
"default": {
|
||||
"BACKEND": "django_redis.cache.RedisCache",
|
||||
"LOCATION": CONFIG.get("cache.url") or f"{_redis_url}/{CONFIG.get('redis.db')}",
|
||||
"LOCATION": CONFIG.get("cache.url") or redis_url(CONFIG.get("redis.db")),
|
||||
"TIMEOUT": CONFIG.get_int("cache.timeout", 300),
|
||||
"OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"},
|
||||
"OPTIONS": {
|
||||
"CLIENT_CLASS": "django_redis.client.DefaultClient",
|
||||
},
|
||||
"KEY_PREFIX": "authentik_cache",
|
||||
"KEY_FUNCTION": "django_tenants.cache.make_key",
|
||||
"REVERSE_KEY_FUNCTION": "django_tenants.cache.reverse_key",
|
||||
@ -222,7 +212,15 @@ CACHES = {
|
||||
DJANGO_REDIS_SCAN_ITERSIZE = 1000
|
||||
DJANGO_REDIS_IGNORE_EXCEPTIONS = True
|
||||
DJANGO_REDIS_LOG_IGNORED_EXCEPTIONS = True
|
||||
SESSION_ENGINE = "django.contrib.sessions.backends.cache"
|
||||
match CONFIG.get("session_storage", "cache"):
|
||||
case "cache":
|
||||
SESSION_ENGINE = "django.contrib.sessions.backends.cache"
|
||||
case "db":
|
||||
SESSION_ENGINE = "django.contrib.sessions.backends.db"
|
||||
case _:
|
||||
raise ImproperlyConfigured(
|
||||
"Invalid session_storage setting, allowed values are db and cache"
|
||||
)
|
||||
SESSION_SERIALIZER = "authentik.root.sessions.pickle.PickleSerializer"
|
||||
SESSION_CACHE_ALIAS = "default"
|
||||
# Configured via custom SessionMiddleware
|
||||
@ -276,7 +274,7 @@ CHANNEL_LAYERS = {
|
||||
"default": {
|
||||
"BACKEND": "channels_redis.pubsub.RedisPubSubChannelLayer",
|
||||
"CONFIG": {
|
||||
"hosts": [CONFIG.get("channel.url", f"{_redis_url}/{CONFIG.get('redis.db')}")],
|
||||
"hosts": [CONFIG.get("channel.url") or redis_url(CONFIG.get("redis.db"))],
|
||||
"prefix": "authentik_channels_",
|
||||
},
|
||||
},
|
||||
@ -376,11 +374,15 @@ CELERY = {
|
||||
"beat_scheduler": "authentik.tenants.scheduler:TenantAwarePersistentScheduler",
|
||||
"task_create_missing_queues": True,
|
||||
"task_default_queue": "authentik",
|
||||
"broker_url": CONFIG.get("broker.url")
|
||||
or f"{_redis_url}/{CONFIG.get('redis.db')}{_redis_celery_tls_requirements}",
|
||||
"broker_transport_options": CONFIG.get_dict_from_b64_json("broker.transport_options"),
|
||||
"result_backend": CONFIG.get("result_backend.url")
|
||||
or f"{_redis_url}/{CONFIG.get('redis.db')}{_redis_celery_tls_requirements}",
|
||||
"broker_url": CONFIG.get("broker.url") or redis_url(CONFIG.get("redis.db")),
|
||||
"result_backend": CONFIG.get("result_backend.url") or redis_url(CONFIG.get("redis.db")),
|
||||
"broker_transport_options": CONFIG.get_dict_from_b64_json(
|
||||
"broker.transport_options", {"retry_policy": {"timeout": 5.0}}
|
||||
),
|
||||
"result_backend_transport_options": CONFIG.get_dict_from_b64_json(
|
||||
"result_backend.transport_options", {"retry_policy": {"timeout": 5.0}}
|
||||
),
|
||||
"redis_retry_on_timeout": True,
|
||||
}
|
||||
|
||||
# Sentry integration
|
||||
|
@ -34,7 +34,7 @@ def mock_ad_connection(password: str) -> Connection:
|
||||
"objectSid": "unique-test-group",
|
||||
"objectClass": "group",
|
||||
"distinguishedName": "cn=group1,ou=groups,dc=goauthentik,dc=io",
|
||||
"member": ["cn=user0,ou=users,dc=goauthentik,dc=io"],
|
||||
"member": ["cn=user,ou=users,dc=goauthentik,dc=io"],
|
||||
},
|
||||
)
|
||||
# Group without SID
|
||||
@ -47,7 +47,7 @@ def mock_ad_connection(password: str) -> Connection:
|
||||
},
|
||||
)
|
||||
connection.strategy.add_entry(
|
||||
"cn=user0,ou=users,dc=goauthentik,dc=io",
|
||||
"cn=user0,ou=foo,ou=users,dc=goauthentik,dc=io",
|
||||
{
|
||||
"userPassword": password,
|
||||
"sAMAccountName": "user0_sn",
|
||||
|
@ -55,7 +55,7 @@ class LDAPSyncTests(TestCase):
|
||||
)
|
||||
connection.assert_called_with(
|
||||
connection_kwargs={
|
||||
"user": "cn=user0,ou=users,dc=goauthentik,dc=io",
|
||||
"user": "cn=user0,ou=foo,ou=users,dc=goauthentik,dc=io",
|
||||
"password": LDAP_PASSWORD,
|
||||
}
|
||||
)
|
||||
|
@ -80,7 +80,7 @@ class OAuth2Client(BaseOAuthClient):
|
||||
access_token_url = self.source.source_type.access_token_url or ""
|
||||
if self.source.source_type.urls_customizable and self.source.access_token_url:
|
||||
access_token_url = self.source.access_token_url
|
||||
response = self.session.request(
|
||||
response = self.do_request(
|
||||
"post", access_token_url, data=args, headers=self._default_headers, **request_kwargs
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
37
authentik/sources/oauth/tests/test_type_apple.py
Normal file
37
authentik/sources/oauth/tests/test_type_apple.py
Normal file
@ -0,0 +1,37 @@
|
||||
"""Apple Type tests"""
|
||||
|
||||
from django.test import RequestFactory, TestCase
|
||||
from guardian.shortcuts import get_anonymous_user
|
||||
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.lib.tests.utils import dummy_get_response
|
||||
from authentik.root.middleware import SessionMiddleware
|
||||
from authentik.sources.oauth.models import OAuthSource
|
||||
from authentik.sources.oauth.types.registry import registry
|
||||
|
||||
|
||||
class TestTypeApple(TestCase):
|
||||
"""OAuth Source tests"""
|
||||
|
||||
def setUp(self):
|
||||
self.source = OAuthSource.objects.create(
|
||||
name="test",
|
||||
slug="test",
|
||||
provider_type="apple",
|
||||
authorization_url="",
|
||||
profile_url="",
|
||||
consumer_key=generate_id(),
|
||||
)
|
||||
self.factory = RequestFactory()
|
||||
|
||||
def test_login_challenge(self):
|
||||
"""Test login_challenge"""
|
||||
request = self.factory.get("/")
|
||||
request.user = get_anonymous_user()
|
||||
|
||||
middleware = SessionMiddleware(dummy_get_response)
|
||||
middleware.process_request(request)
|
||||
request.session.save()
|
||||
oauth_type = registry.find_type("apple")
|
||||
challenge = oauth_type().login_challenge(self.source, request)
|
||||
self.assertTrue(challenge.is_valid(raise_exception=True))
|
@ -125,7 +125,7 @@ class AppleType(SourceType):
|
||||
)
|
||||
args = apple_client.get_redirect_args()
|
||||
return AppleLoginChallenge(
|
||||
instance={
|
||||
data={
|
||||
"client_id": apple_client.get_client_id(),
|
||||
"scope": "name email",
|
||||
"redirect_uri": args["redirect_uri"],
|
||||
|
@ -66,7 +66,7 @@ class PlexSource(Source):
|
||||
icon = static("authentik/sources/plex.svg")
|
||||
return UILoginButton(
|
||||
challenge=PlexAuthenticationChallenge(
|
||||
{
|
||||
data={
|
||||
"type": ChallengeTypes.NATIVE.value,
|
||||
"component": "ak-source-plex",
|
||||
"client_id": self.client_id,
|
||||
|
@ -40,6 +40,11 @@ class TestPlexSource(TestCase):
|
||||
slug="test",
|
||||
)
|
||||
|
||||
def test_login_challenge(self):
|
||||
"""Test login_challenge"""
|
||||
ui_login_button = self.source.ui_login_button(None)
|
||||
self.assertTrue(ui_login_button.challenge.is_valid(raise_exception=True))
|
||||
|
||||
def test_get_user_info(self):
|
||||
"""Test get_user_info"""
|
||||
token = generate_key()
|
||||
|
@ -0,0 +1,44 @@
|
||||
# Generated by Django 5.0.4 on 2024-05-01 15:44
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_sources_saml", "0013_samlsource_verification_kp_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="samlsource",
|
||||
name="digest_algorithm",
|
||||
field=models.TextField(
|
||||
choices=[
|
||||
("http://www.w3.org/2000/09/xmldsig#sha1", "SHA1"),
|
||||
("http://www.w3.org/2001/04/xmlenc#sha256", "SHA256"),
|
||||
("http://www.w3.org/2001/04/xmldsig-more#sha384", "SHA384"),
|
||||
("http://www.w3.org/2001/04/xmlenc#sha512", "SHA512"),
|
||||
],
|
||||
default="http://www.w3.org/2001/04/xmlenc#sha256",
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="samlsource",
|
||||
name="signature_algorithm",
|
||||
field=models.TextField(
|
||||
choices=[
|
||||
("http://www.w3.org/2000/09/xmldsig#rsa-sha1", "RSA-SHA1"),
|
||||
("http://www.w3.org/2001/04/xmldsig-more#rsa-sha256", "RSA-SHA256"),
|
||||
("http://www.w3.org/2001/04/xmldsig-more#rsa-sha384", "RSA-SHA384"),
|
||||
("http://www.w3.org/2001/04/xmldsig-more#rsa-sha512", "RSA-SHA512"),
|
||||
("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha1", "ECDSA-SHA1"),
|
||||
("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256", "ECDSA-SHA256"),
|
||||
("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha384", "ECDSA-SHA384"),
|
||||
("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha512", "ECDSA-SHA512"),
|
||||
("http://www.w3.org/2000/09/xmldsig#dsa-sha1", "DSA-SHA1"),
|
||||
],
|
||||
default="http://www.w3.org/2001/04/xmldsig-more#rsa-sha256",
|
||||
),
|
||||
),
|
||||
]
|
@ -15,6 +15,10 @@ from authentik.flows.models import Flow
|
||||
from authentik.lib.utils.time import timedelta_string_validator
|
||||
from authentik.sources.saml.processors.constants import (
|
||||
DSA_SHA1,
|
||||
ECDSA_SHA1,
|
||||
ECDSA_SHA256,
|
||||
ECDSA_SHA384,
|
||||
ECDSA_SHA512,
|
||||
RSA_SHA1,
|
||||
RSA_SHA256,
|
||||
RSA_SHA384,
|
||||
@ -143,8 +147,7 @@ class SAMLSource(Source):
|
||||
verbose_name=_("Signing Keypair"),
|
||||
)
|
||||
|
||||
digest_algorithm = models.CharField(
|
||||
max_length=50,
|
||||
digest_algorithm = models.TextField(
|
||||
choices=(
|
||||
(SHA1, _("SHA1")),
|
||||
(SHA256, _("SHA256")),
|
||||
@ -153,13 +156,16 @@ class SAMLSource(Source):
|
||||
),
|
||||
default=SHA256,
|
||||
)
|
||||
signature_algorithm = models.CharField(
|
||||
max_length=50,
|
||||
signature_algorithm = models.TextField(
|
||||
choices=(
|
||||
(RSA_SHA1, _("RSA-SHA1")),
|
||||
(RSA_SHA256, _("RSA-SHA256")),
|
||||
(RSA_SHA384, _("RSA-SHA384")),
|
||||
(RSA_SHA512, _("RSA-SHA512")),
|
||||
(ECDSA_SHA1, _("ECDSA-SHA1")),
|
||||
(ECDSA_SHA256, _("ECDSA-SHA256")),
|
||||
(ECDSA_SHA384, _("ECDSA-SHA384")),
|
||||
(ECDSA_SHA512, _("ECDSA-SHA512")),
|
||||
(DSA_SHA1, _("DSA-SHA1")),
|
||||
),
|
||||
default=RSA_SHA256,
|
||||
|
@ -26,9 +26,16 @@ SAML_BINDING_REDIRECT = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
|
||||
|
||||
DSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#dsa-sha1"
|
||||
RSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"
|
||||
# https://datatracker.ietf.org/doc/html/rfc4051#section-2.3.2
|
||||
RSA_SHA256 = "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256"
|
||||
RSA_SHA384 = "http://www.w3.org/2001/04/xmldsig-more#rsa-sha384"
|
||||
RSA_SHA512 = "http://www.w3.org/2001/04/xmldsig-more#rsa-sha512"
|
||||
# https://datatracker.ietf.org/doc/html/rfc4051#section-2.3.6
|
||||
ECDSA_SHA1 = "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha1"
|
||||
ECDSA_SHA224 = "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha224"
|
||||
ECDSA_SHA256 = "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256"
|
||||
ECDSA_SHA384 = "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha384"
|
||||
ECDSA_SHA512 = "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha512"
|
||||
|
||||
SHA1 = "http://www.w3.org/2000/09/xmldsig#sha1"
|
||||
SHA256 = "http://www.w3.org/2001/04/xmlenc#sha256"
|
||||
@ -41,6 +48,11 @@ SIGN_ALGORITHM_TRANSFORM_MAP = {
|
||||
RSA_SHA256: xmlsec.constants.TransformRsaSha256,
|
||||
RSA_SHA384: xmlsec.constants.TransformRsaSha384,
|
||||
RSA_SHA512: xmlsec.constants.TransformRsaSha512,
|
||||
ECDSA_SHA1: xmlsec.constants.TransformEcdsaSha1,
|
||||
ECDSA_SHA224: xmlsec.constants.TransformEcdsaSha224,
|
||||
ECDSA_SHA256: xmlsec.constants.TransformEcdsaSha256,
|
||||
ECDSA_SHA384: xmlsec.constants.TransformEcdsaSha384,
|
||||
ECDSA_SHA512: xmlsec.constants.TransformEcdsaSha512,
|
||||
}
|
||||
|
||||
DIGEST_ALGORITHM_TRANSLATION_MAP = {
|
||||
|
@ -7,7 +7,6 @@ from rest_framework.viewsets import ModelViewSet
|
||||
from authentik.core.api.sources import SourceSerializer
|
||||
from authentik.core.api.tokens import TokenSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.models import Token, TokenIntents, User, UserTypes
|
||||
from authentik.sources.scim.models import SCIMSource
|
||||
|
||||
|
||||
@ -27,25 +26,6 @@ class SCIMSourceSerializer(SourceSerializer):
|
||||
return relative_url
|
||||
return self.context["request"].build_absolute_uri(relative_url)
|
||||
|
||||
def create(self, validated_data):
|
||||
instance: SCIMSource = super().create(validated_data)
|
||||
identifier = f"ak-source-scim-{instance.pk}"
|
||||
user = User.objects.create(
|
||||
username=identifier,
|
||||
name=f"SCIM Source {instance.name} Service-Account",
|
||||
type=UserTypes.SERVICE_ACCOUNT,
|
||||
)
|
||||
token = Token.objects.create(
|
||||
user=user,
|
||||
identifier=identifier,
|
||||
intent=TokenIntents.INTENT_API,
|
||||
expiring=False,
|
||||
managed=f"goauthentik.io/sources/scim/{instance.pk}",
|
||||
)
|
||||
instance.token = token
|
||||
instance.save()
|
||||
return instance
|
||||
|
||||
class Meta:
|
||||
|
||||
model = SCIMSource
|
||||
|
@ -1,12 +1,13 @@
|
||||
"""Authentik SCIM app config"""
|
||||
|
||||
from django.apps import AppConfig
|
||||
from authentik.blueprints.apps import ManagedAppConfig
|
||||
|
||||
|
||||
class AuthentikSourceSCIMConfig(AppConfig):
|
||||
class AuthentikSourceSCIMConfig(ManagedAppConfig):
|
||||
"""authentik SCIM Source app config"""
|
||||
|
||||
name = "authentik.sources.scim"
|
||||
label = "authentik_sources_scim"
|
||||
verbose_name = "authentik Sources.SCIM"
|
||||
mountpoint = "source/scim/"
|
||||
default = True
|
||||
|
@ -1,5 +1,7 @@
|
||||
"""SCIM Source"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from django.db import models
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework.serializers import BaseSerializer
|
||||
@ -14,6 +16,12 @@ class SCIMSource(Source):
|
||||
|
||||
token = models.ForeignKey(Token, on_delete=models.CASCADE, null=True, default=None)
|
||||
|
||||
@property
|
||||
def service_account_identifier(self) -> str:
|
||||
if not self.pk:
|
||||
self.pk = uuid4()
|
||||
return f"ak-source-scim-{self.pk}"
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
"""Return component used to edit this object"""
|
||||
@ -52,7 +60,7 @@ class SCIMSourceUser(SerializerModel):
|
||||
unique_together = (("id", "user", "source"),)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"SCIM User {self.user.username} to {self.source.name}"
|
||||
return f"SCIM User {self.user_id} to {self.source_id}"
|
||||
|
||||
|
||||
class SCIMSourceGroup(SerializerModel):
|
||||
@ -73,4 +81,4 @@ class SCIMSourceGroup(SerializerModel):
|
||||
unique_together = (("id", "group", "source"),)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"SCIM Group {self.group.name} to {self.source.name}"
|
||||
return f"SCIM Group {self.group_id} to {self.source_id}"
|
||||
|
41
authentik/sources/scim/signals.py
Normal file
41
authentik/sources/scim/signals.py
Normal file
@ -0,0 +1,41 @@
|
||||
from django.db.models import Model
|
||||
from django.db.models.signals import pre_delete, pre_save
|
||||
from django.dispatch import receiver
|
||||
|
||||
from authentik.core.models import USER_PATH_SYSTEM_PREFIX, Token, TokenIntents, User, UserTypes
|
||||
from authentik.sources.scim.models import SCIMSource
|
||||
|
||||
USER_PATH_SOURCE_SCIM = USER_PATH_SYSTEM_PREFIX + "/sources/scim"
|
||||
|
||||
|
||||
@receiver(pre_save, sender=SCIMSource)
|
||||
def scim_source_pre_save(sender: type[Model], instance: SCIMSource, **_):
|
||||
"""Create service account before source is saved"""
|
||||
# .service_account_identifier will auto-assign a primary key uuid to the source
|
||||
# if none is set yet, just so we can get the identifier before we save
|
||||
identifier = instance.service_account_identifier
|
||||
user = User.objects.create(
|
||||
username=identifier,
|
||||
name=f"SCIM Source {instance.name} Service-Account",
|
||||
type=UserTypes.INTERNAL_SERVICE_ACCOUNT,
|
||||
path=USER_PATH_SOURCE_SCIM,
|
||||
)
|
||||
token = Token.objects.create(
|
||||
user=user,
|
||||
identifier=identifier,
|
||||
intent=TokenIntents.INTENT_API,
|
||||
expiring=False,
|
||||
managed=f"goauthentik.io/sources/scim/{instance.pk}",
|
||||
)
|
||||
instance.token = token
|
||||
|
||||
|
||||
@receiver(pre_delete, sender=SCIMSource)
|
||||
def scim_source_pre_delete(sender: type[Model], instance: SCIMSource, **_):
|
||||
"""Delete SCIM Source service account before deleting source"""
|
||||
Token.objects.filter(
|
||||
identifier=instance.service_account_identifier, intent=TokenIntents.INTENT_API
|
||||
).delete()
|
||||
User.objects.filter(
|
||||
username=instance.service_account_identifier, type=UserTypes.INTERNAL_SERVICE_ACCOUNT
|
||||
).delete()
|
@ -14,27 +14,13 @@ class TestSCIMAuth(APITestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.user = create_test_admin_user()
|
||||
self.token = Token.objects.create(
|
||||
user=self.user,
|
||||
identifier=generate_id(),
|
||||
intent=TokenIntents.INTENT_API,
|
||||
)
|
||||
self.token2 = Token.objects.create(
|
||||
user=self.user,
|
||||
identifier=generate_id(),
|
||||
intent=TokenIntents.INTENT_API,
|
||||
)
|
||||
self.token3 = Token.objects.create(
|
||||
user=self.user,
|
||||
identifier=generate_id(),
|
||||
intent=TokenIntents.INTENT_API,
|
||||
)
|
||||
self.source = SCIMSource.objects.create(
|
||||
name=generate_id(), slug=generate_id(), token=self.token
|
||||
)
|
||||
self.source2 = SCIMSource.objects.create(
|
||||
name=generate_id(), slug=generate_id(), token=self.token2
|
||||
)
|
||||
self.source = SCIMSource.objects.create(name=generate_id(), slug=generate_id())
|
||||
self.source2 = SCIMSource.objects.create(name=generate_id(), slug=generate_id())
|
||||
|
||||
def test_auth_ok(self):
|
||||
"""Test successful auth"""
|
||||
@ -45,7 +31,7 @@ class TestSCIMAuth(APITestCase):
|
||||
"source_slug": self.source.slug,
|
||||
},
|
||||
),
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.token.key}",
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
@ -71,7 +57,7 @@ class TestSCIMAuth(APITestCase):
|
||||
"source_slug": self.source.slug,
|
||||
},
|
||||
),
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.token2.key}",
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source2.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 403)
|
||||
# Token for no source
|
||||
|
@ -3,8 +3,6 @@
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.models import Token, TokenIntents
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.sources.scim.models import SCIMSource
|
||||
|
||||
@ -13,14 +11,9 @@ class TestSCIMResourceTypes(APITestCase):
|
||||
"""Test SCIM ResourceTypes view"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.user = create_test_admin_user()
|
||||
self.token = Token.objects.create(
|
||||
user=self.user,
|
||||
identifier=generate_id(),
|
||||
intent=TokenIntents.INTENT_API,
|
||||
)
|
||||
self.source = SCIMSource.objects.create(
|
||||
name=generate_id(), slug=generate_id(), token=self.token
|
||||
name=generate_id(),
|
||||
slug=generate_id(),
|
||||
)
|
||||
|
||||
def test_resource_type(self):
|
||||
@ -32,7 +25,7 @@ class TestSCIMResourceTypes(APITestCase):
|
||||
"source_slug": self.source.slug,
|
||||
},
|
||||
),
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.token.key}",
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
@ -46,7 +39,7 @@ class TestSCIMResourceTypes(APITestCase):
|
||||
"resource_type": "ServiceProviderConfig",
|
||||
},
|
||||
),
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.token.key}",
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
@ -60,6 +53,6 @@ class TestSCIMResourceTypes(APITestCase):
|
||||
"resource_type": "foo",
|
||||
},
|
||||
),
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.token.key}",
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 404)
|
||||
|
@ -3,8 +3,6 @@
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.models import Token, TokenIntents
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.sources.scim.models import SCIMSource
|
||||
|
||||
@ -13,15 +11,7 @@ class TestSCIMSchemas(APITestCase):
|
||||
"""Test SCIM Schema view"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.user = create_test_admin_user()
|
||||
self.token = Token.objects.create(
|
||||
user=self.user,
|
||||
identifier=generate_id(),
|
||||
intent=TokenIntents.INTENT_API,
|
||||
)
|
||||
self.source = SCIMSource.objects.create(
|
||||
name=generate_id(), slug=generate_id(), token=self.token
|
||||
)
|
||||
self.source = SCIMSource.objects.create(name=generate_id(), slug=generate_id())
|
||||
|
||||
def test_schema(self):
|
||||
"""Test full schema view"""
|
||||
@ -32,7 +22,7 @@ class TestSCIMSchemas(APITestCase):
|
||||
"source_slug": self.source.slug,
|
||||
},
|
||||
),
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.token.key}",
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
@ -46,7 +36,7 @@ class TestSCIMSchemas(APITestCase):
|
||||
"schema_uri": "urn:ietf:params:scim:schemas:core:2.0:Meta",
|
||||
},
|
||||
),
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.token.key}",
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
@ -60,6 +50,6 @@ class TestSCIMSchemas(APITestCase):
|
||||
"schema_uri": "foo",
|
||||
},
|
||||
),
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.token.key}",
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 404)
|
||||
|
@ -3,8 +3,6 @@
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.models import Token, TokenIntents
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.sources.scim.models import SCIMSource
|
||||
|
||||
@ -13,14 +11,9 @@ class TestSCIMServiceProviderConfig(APITestCase):
|
||||
"""Test SCIM ServiceProviderConfig view"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.user = create_test_admin_user()
|
||||
self.token = Token.objects.create(
|
||||
user=self.user,
|
||||
identifier=generate_id(),
|
||||
intent=TokenIntents.INTENT_API,
|
||||
)
|
||||
self.source = SCIMSource.objects.create(
|
||||
name=generate_id(), slug=generate_id(), token=self.token
|
||||
name=generate_id(),
|
||||
slug=generate_id(),
|
||||
)
|
||||
|
||||
def test_config(self):
|
||||
@ -32,6 +25,6 @@ class TestSCIMServiceProviderConfig(APITestCase):
|
||||
"source_slug": self.source.slug,
|
||||
},
|
||||
),
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.token.key}",
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
27
authentik/sources/scim/tests/test_signals.py
Normal file
27
authentik/sources/scim/tests/test_signals.py
Normal file
@ -0,0 +1,27 @@
|
||||
"""Test SCIM Source creation"""
|
||||
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.models import Token, User
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.sources.scim.models import SCIMSource
|
||||
|
||||
|
||||
class TestSCIMSignals(APITestCase):
|
||||
"""Test SCIM Signals view"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.uid = generate_id()
|
||||
|
||||
def test_create(self) -> None:
|
||||
source = SCIMSource.objects.create(name=self.uid, slug=self.uid)
|
||||
self.assertIsNotNone(source.token)
|
||||
self.assertIsNotNone(source.token.user)
|
||||
|
||||
def test_delete(self):
|
||||
self.test_create()
|
||||
source = SCIMSource.objects.filter(slug=self.uid).first()
|
||||
identifier = source.service_account_identifier
|
||||
source.delete()
|
||||
self.assertFalse(User.objects.filter(username=identifier).exists())
|
||||
self.assertFalse(Token.objects.filter(identifier=identifier).exists())
|
@ -6,8 +6,8 @@ from uuid import uuid4
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.models import Token, TokenIntents
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.core.tests.utils import create_test_user
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.providers.scim.clients.schema import User as SCIMUserSchema
|
||||
from authentik.sources.scim.models import SCIMSource, SCIMSourceUser
|
||||
@ -18,15 +18,7 @@ class TestSCIMUsers(APITestCase):
|
||||
"""Test SCIM User view"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.user = create_test_admin_user()
|
||||
self.token = Token.objects.create(
|
||||
user=self.user,
|
||||
identifier=generate_id(),
|
||||
intent=TokenIntents.INTENT_API,
|
||||
)
|
||||
self.source = SCIMSource.objects.create(
|
||||
name=generate_id(), slug=generate_id(), token=self.token
|
||||
)
|
||||
self.source = SCIMSource.objects.create(name=generate_id(), slug=generate_id())
|
||||
|
||||
def test_user_list(self):
|
||||
"""Test full user list"""
|
||||
@ -37,15 +29,16 @@ class TestSCIMUsers(APITestCase):
|
||||
"source_slug": self.source.slug,
|
||||
},
|
||||
),
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.token.key}",
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_user_list_single(self):
|
||||
"""Test full user list (single user)"""
|
||||
user = create_test_user()
|
||||
SCIMSourceUser.objects.create(
|
||||
source=self.source,
|
||||
user=self.user,
|
||||
user=user,
|
||||
id=str(uuid4()),
|
||||
)
|
||||
response = self.client.get(
|
||||
@ -53,16 +46,17 @@ class TestSCIMUsers(APITestCase):
|
||||
"authentik_sources_scim:v2-users",
|
||||
kwargs={
|
||||
"source_slug": self.source.slug,
|
||||
"user_id": str(self.user.uuid),
|
||||
"user_id": str(user.uuid),
|
||||
},
|
||||
),
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.token.key}",
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
SCIMUserSchema.model_validate_json(response.content, strict=True)
|
||||
|
||||
def test_user_create(self):
|
||||
"""Test user create"""
|
||||
user = create_test_user()
|
||||
ext_id = generate_id()
|
||||
response = self.client.post(
|
||||
reverse(
|
||||
@ -78,13 +72,18 @@ class TestSCIMUsers(APITestCase):
|
||||
"emails": [
|
||||
{
|
||||
"primary": True,
|
||||
"value": self.user.email,
|
||||
"value": user.email,
|
||||
}
|
||||
],
|
||||
}
|
||||
),
|
||||
content_type=SCIM_CONTENT_TYPE,
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.token.key}",
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 201)
|
||||
self.assertTrue(SCIMSourceUser.objects.filter(source=self.source, id=ext_id).exists())
|
||||
self.assertTrue(
|
||||
Event.objects.filter(
|
||||
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
|
||||
).exists()
|
||||
)
|
||||
|
@ -13,6 +13,7 @@ from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA
|
||||
from authentik.providers.scim.clients.schema import Group as SCIMGroupModel
|
||||
from authentik.sources.scim.models import SCIMSourceGroup
|
||||
from authentik.sources.scim.views.v2.base import SCIMView
|
||||
@ -26,9 +27,11 @@ class GroupsView(SCIMView):
|
||||
def group_to_scim(self, scim_group: SCIMSourceGroup) -> dict:
|
||||
"""Convert Group to SCIM data"""
|
||||
payload = SCIMGroupModel(
|
||||
schemas=[SCIM_USER_SCHEMA],
|
||||
id=str(scim_group.group.pk),
|
||||
externalId=scim_group.id,
|
||||
displayName=scim_group.group.name,
|
||||
members=[],
|
||||
meta={
|
||||
"resourceType": "Group",
|
||||
"location": self.request.build_absolute_uri(
|
||||
@ -42,28 +45,24 @@ class GroupsView(SCIMView):
|
||||
),
|
||||
},
|
||||
)
|
||||
return payload.model_dump(
|
||||
mode="json",
|
||||
exclude_unset=True,
|
||||
)
|
||||
for member in scim_group.group.users.order_by("pk"):
|
||||
member: User
|
||||
payload.members.append(GroupMember(value=str(member.uuid)))
|
||||
return payload.model_dump(mode="json", exclude_unset=True)
|
||||
|
||||
def get(self, request: Request, group_id: str | None = None, **kwargs) -> Response:
|
||||
"""List Group handler"""
|
||||
base_query = SCIMSourceGroup.objects.select_related("group").prefetch_related(
|
||||
"group__users"
|
||||
)
|
||||
if group_id:
|
||||
connection = (
|
||||
SCIMSourceGroup.objects.filter(source=self.source, group__group_uuid=group_id)
|
||||
.select_related("group")
|
||||
.first()
|
||||
)
|
||||
connection = base_query.filter(source=self.source, group__group_uuid=group_id).first()
|
||||
if not connection:
|
||||
raise Http404
|
||||
return Response(self.group_to_scim(connection))
|
||||
connections = (
|
||||
SCIMSourceGroup.objects.filter(source=self.source)
|
||||
.select_related("group")
|
||||
.order_by("pk")
|
||||
base_query.filter(source=self.source).order_by("pk").filter(self.filter_parse(request))
|
||||
)
|
||||
connections = connections.filter(self.filter_parse(request))
|
||||
page = self.paginate_query(connections)
|
||||
return Response(
|
||||
{
|
||||
@ -79,6 +78,8 @@ class GroupsView(SCIMView):
|
||||
def update_group(self, connection: SCIMSourceGroup | None, data: QueryDict):
|
||||
"""Partial update a group"""
|
||||
group = connection.group if connection else Group()
|
||||
if _group := Group.objects.filter(name=data.get("displayName")).first():
|
||||
group = _group
|
||||
if "displayName" in data:
|
||||
group.name = data.get("displayName")
|
||||
if group.name == "":
|
||||
|
@ -11,6 +11,7 @@ from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
|
||||
from authentik.core.models import User
|
||||
from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA
|
||||
from authentik.providers.scim.clients.schema import User as SCIMUserModel
|
||||
from authentik.sources.scim.models import SCIMSourceUser
|
||||
from authentik.sources.scim.views.v2.base import SCIMView
|
||||
@ -33,6 +34,7 @@ class UsersView(SCIMView):
|
||||
def user_to_scim(self, scim_user: SCIMSourceUser) -> dict:
|
||||
"""Convert User to SCIM data"""
|
||||
payload = SCIMUserModel(
|
||||
schemas=[SCIM_USER_SCHEMA],
|
||||
id=str(scim_user.user.uuid),
|
||||
externalId=scim_user.id,
|
||||
userName=scim_user.user.username,
|
||||
@ -62,10 +64,7 @@ class UsersView(SCIMView):
|
||||
),
|
||||
},
|
||||
)
|
||||
final_payload = payload.model_dump(
|
||||
mode="json",
|
||||
exclude_unset=True,
|
||||
)
|
||||
final_payload = payload.model_dump(mode="json", exclude_unset=True)
|
||||
final_payload.update(scim_user.attributes)
|
||||
return final_payload
|
||||
|
||||
@ -99,6 +98,8 @@ class UsersView(SCIMView):
|
||||
def update_user(self, connection: SCIMSourceUser | None, data: QueryDict):
|
||||
"""Partial update a user"""
|
||||
user = connection.user if connection else User()
|
||||
if _user := User.objects.filter(username=data.get("userName")).first():
|
||||
user = _user
|
||||
user.path = self.source.get_user_path()
|
||||
if "userName" in data:
|
||||
user.username = data.get("userName")
|
||||
|
@ -96,7 +96,7 @@ class DuoDevice(SerializerModel, Device):
|
||||
return DuoDeviceSerializer
|
||||
|
||||
def __str__(self):
|
||||
return str(self.name) or str(self.user)
|
||||
return str(self.name) or str(self.user_id)
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("Duo Device")
|
||||
|
@ -221,7 +221,7 @@ class SMSDevice(SerializerModel, SideChannelDevice):
|
||||
return valid
|
||||
|
||||
def __str__(self):
|
||||
return str(self.name) or str(self.user)
|
||||
return str(self.name) or str(self.user_id)
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("SMS Device")
|
||||
|
@ -20,7 +20,10 @@ class WebAuthnDeviceSerializer(ModelSerializer):
|
||||
|
||||
class Meta:
|
||||
model = WebAuthnDevice
|
||||
fields = ["pk", "name", "created_on", "device_type"]
|
||||
fields = ["pk", "name", "created_on", "device_type", "aaguid"]
|
||||
extra_kwargs = {
|
||||
"aaguid": {"read_only": True},
|
||||
}
|
||||
|
||||
|
||||
class WebAuthnDeviceViewSet(
|
||||
|
@ -0,0 +1,168 @@
|
||||
# Generated by Django 5.0.4 on 2024-04-18 11:29
|
||||
|
||||
import django.db.models.deletion
|
||||
import django.utils.timezone
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
replaces = [
|
||||
("authentik_stages_authenticator_webauthn", "0001_initial"),
|
||||
("authentik_stages_authenticator_webauthn", "0002_default_setup_flow"),
|
||||
("authentik_stages_authenticator_webauthn", "0003_webauthndevice_confirmed"),
|
||||
("authentik_stages_authenticator_webauthn", "0004_auto_20210304_1850"),
|
||||
(
|
||||
"authentik_stages_authenticator_webauthn",
|
||||
"0005_authenticatewebauthnstage_user_verification",
|
||||
),
|
||||
(
|
||||
"authentik_stages_authenticator_webauthn",
|
||||
"0006_authenticatewebauthnstage_authenticator_attachment_and_more",
|
||||
),
|
||||
(
|
||||
"authentik_stages_authenticator_webauthn",
|
||||
"0007_rename_last_used_on_webauthndevice_last_t",
|
||||
),
|
||||
("authentik_stages_authenticator_webauthn", "0008_alter_webauthndevice_credential_id"),
|
||||
("authentik_stages_authenticator_webauthn", "0009_authenticatewebauthnstage_friendly_name"),
|
||||
(
|
||||
"authentik_stages_authenticator_webauthn",
|
||||
"0010_webauthndevicetype_authenticatorwebauthnstage_and_more",
|
||||
),
|
||||
("authentik_stages_authenticator_webauthn", "0011_webauthndevice_aaguid"),
|
||||
]
|
||||
|
||||
initial = True
|
||||
|
||||
dependencies = [
|
||||
("authentik_flows", "0016_auto_20201202_1307"),
|
||||
("authentik_flows", "0027_auto_20231028_1424"),
|
||||
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="WebAuthnDeviceType",
|
||||
fields=[
|
||||
("aaguid", models.UUIDField(primary_key=True, serialize=False, unique=True)),
|
||||
("description", models.TextField()),
|
||||
("icon", models.TextField(null=True)),
|
||||
],
|
||||
options={
|
||||
"verbose_name": "WebAuthn Device type",
|
||||
"verbose_name_plural": "WebAuthn Device types",
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="AuthenticatorWebAuthnStage",
|
||||
fields=[
|
||||
(
|
||||
"stage_ptr",
|
||||
models.OneToOneField(
|
||||
auto_created=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
parent_link=True,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
to="authentik_flows.stage",
|
||||
),
|
||||
),
|
||||
(
|
||||
"configure_flow",
|
||||
models.ForeignKey(
|
||||
blank=True,
|
||||
help_text="Flow used by an authenticated user to configure this Stage. If empty, user will not be able to configure this stage.",
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.SET_NULL,
|
||||
to="authentik_flows.flow",
|
||||
),
|
||||
),
|
||||
(
|
||||
"user_verification",
|
||||
models.TextField(
|
||||
choices=[
|
||||
("required", "Required"),
|
||||
("preferred", "Preferred"),
|
||||
("discouraged", "Discouraged"),
|
||||
],
|
||||
default="preferred",
|
||||
),
|
||||
),
|
||||
(
|
||||
"authenticator_attachment",
|
||||
models.TextField(
|
||||
choices=[("platform", "Platform"), ("cross-platform", "Cross Platform")],
|
||||
default=None,
|
||||
null=True,
|
||||
),
|
||||
),
|
||||
(
|
||||
"resident_key_requirement",
|
||||
models.TextField(
|
||||
choices=[
|
||||
("discouraged", "Discouraged"),
|
||||
("preferred", "Preferred"),
|
||||
("required", "Required"),
|
||||
],
|
||||
default="preferred",
|
||||
),
|
||||
),
|
||||
("friendly_name", models.TextField(null=True)),
|
||||
(
|
||||
"device_type_restrictions",
|
||||
models.ManyToManyField(
|
||||
blank=True, to="authentik_stages_authenticator_webauthn.webauthndevicetype"
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"verbose_name": "WebAuthn Authenticator Setup Stage",
|
||||
"verbose_name_plural": "WebAuthn Authenticator Setup Stages",
|
||||
},
|
||||
bases=("authentik_flows.stage", models.Model),
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="WebAuthnDevice",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.AutoField(
|
||||
auto_created=True, primary_key=True, serialize=False, verbose_name="ID"
|
||||
),
|
||||
),
|
||||
("name", models.TextField(max_length=200)),
|
||||
("credential_id", models.TextField(unique=True)),
|
||||
("public_key", models.TextField()),
|
||||
("sign_count", models.IntegerField(default=0)),
|
||||
("rp_id", models.CharField(max_length=253)),
|
||||
("created_on", models.DateTimeField(auto_now_add=True)),
|
||||
("last_t", models.DateTimeField(default=django.utils.timezone.now)),
|
||||
(
|
||||
"user",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL
|
||||
),
|
||||
),
|
||||
(
|
||||
"confirmed",
|
||||
models.BooleanField(default=True, help_text="Is this device ready for use?"),
|
||||
),
|
||||
(
|
||||
"device_type",
|
||||
models.ForeignKey(
|
||||
default=None,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.SET_DEFAULT,
|
||||
to="authentik_stages_authenticator_webauthn.webauthndevicetype",
|
||||
),
|
||||
),
|
||||
("aaguid", models.TextField(default="00000000-0000-0000-0000-000000000000")),
|
||||
],
|
||||
options={
|
||||
"verbose_name": "WebAuthn Device",
|
||||
"verbose_name_plural": "WebAuthn Devices",
|
||||
},
|
||||
),
|
||||
]
|
@ -0,0 +1,21 @@
|
||||
# Generated by Django 5.0.4 on 2024-04-18 11:27
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
(
|
||||
"authentik_stages_authenticator_webauthn",
|
||||
"0010_webauthndevicetype_authenticatorwebauthnstage_and_more",
|
||||
),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="webauthndevice",
|
||||
name="aaguid",
|
||||
field=models.TextField(default="00000000-0000-0000-0000-000000000000"),
|
||||
),
|
||||
]
|
@ -132,6 +132,7 @@ class WebAuthnDevice(SerializerModel, Device):
|
||||
created_on = models.DateTimeField(auto_now_add=True)
|
||||
last_t = models.DateTimeField(default=now)
|
||||
|
||||
aaguid = models.TextField(default=UNKNOWN_DEVICE_TYPE_AAGUID)
|
||||
device_type = models.ForeignKey(
|
||||
"WebAuthnDeviceType", on_delete=models.SET_DEFAULT, null=True, default=None
|
||||
)
|
||||
@ -154,7 +155,7 @@ class WebAuthnDevice(SerializerModel, Device):
|
||||
return WebAuthnDeviceSerializer
|
||||
|
||||
def __str__(self):
|
||||
return str(self.name) or str(self.user)
|
||||
return str(self.name) or str(self.user_id)
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("WebAuthn Device")
|
||||
|
@ -126,10 +126,6 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView):
|
||||
if authenticator_attachment:
|
||||
authenticator_attachment = AuthenticatorAttachment(str(authenticator_attachment))
|
||||
|
||||
attestation = AttestationConveyancePreference.DIRECT
|
||||
if stage.device_type_restrictions.exists():
|
||||
attestation = AttestationConveyancePreference.ENTERPRISE
|
||||
|
||||
registration_options: PublicKeyCredentialCreationOptions = generate_registration_options(
|
||||
rp_id=get_rp_id(self.request),
|
||||
rp_name=self.request.brand.branding_title,
|
||||
@ -141,7 +137,7 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView):
|
||||
user_verification=UserVerificationRequirement(str(stage.user_verification)),
|
||||
authenticator_attachment=authenticator_attachment,
|
||||
),
|
||||
attestation=attestation,
|
||||
attestation=AttestationConveyancePreference.DIRECT,
|
||||
)
|
||||
|
||||
self.request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = registration_options.challenge
|
||||
@ -180,6 +176,7 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView):
|
||||
sign_count=webauthn_credential.sign_count,
|
||||
rp_id=get_rp_id(self.request),
|
||||
device_type=device_type,
|
||||
aaguid=webauthn_credential.aaguid,
|
||||
)
|
||||
else:
|
||||
return self.executor.stage_invalid("Device with Credential ID already exists.")
|
||||
|
@ -65,7 +65,7 @@ class UserConsent(SerializerModel, ExpiringModel):
|
||||
return UserConsentSerializer
|
||||
|
||||
def __str__(self):
|
||||
return f"User Consent {self.application} by {self.user}"
|
||||
return f"User Consent {self.application_id} by {self.user_id}"
|
||||
|
||||
class Meta:
|
||||
unique_together = (("user", "application", "permissions"),)
|
||||
|
@ -79,7 +79,7 @@ class Invitation(SerializerModel, ExpiringModel):
|
||||
return InvitationSerializer
|
||||
|
||||
def __str__(self):
|
||||
return f"Invitation {str(self.invite_uuid)} created by {self.created_by}"
|
||||
return f"Invitation {str(self.invite_uuid)} created by {self.created_by_id}"
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("Invitation")
|
||||
|
@ -150,22 +150,26 @@ class PromptChallengeResponse(ChallengeResponse):
|
||||
return attrs
|
||||
|
||||
|
||||
def username_field_validator_factory() -> Callable[[PromptChallenge, str], Any]:
|
||||
def username_field_validator_factory() -> Callable[[PromptChallengeResponse, str], Any]:
|
||||
"""Return a `clean_` method for `field`. Clean method checks if username is taken already."""
|
||||
|
||||
def username_field_validator(_: PromptChallenge, value: str) -> Any:
|
||||
def username_field_validator(self: PromptChallengeResponse, value: str) -> Any:
|
||||
"""Check for duplicate usernames"""
|
||||
if User.objects.filter(username=value).exists():
|
||||
pending_user = self.stage.get_pending_user()
|
||||
query = User.objects.all()
|
||||
if pending_user.pk:
|
||||
query = query.exclude(username=pending_user.username)
|
||||
if query.filter(username=value).exists():
|
||||
raise ValidationError("Username is already taken.")
|
||||
return value
|
||||
|
||||
return username_field_validator
|
||||
|
||||
|
||||
def password_single_validator_factory() -> Callable[[PromptChallenge, str], Any]:
|
||||
def password_single_validator_factory() -> Callable[[PromptChallengeResponse, str], Any]:
|
||||
"""Return a `clean_` method for `field`. Clean method checks if username is taken already."""
|
||||
|
||||
def password_single_clean(self: PromptChallenge, value: str) -> Any:
|
||||
def password_single_clean(self: PromptChallengeResponse, value: str) -> Any:
|
||||
"""Send password validation signals for e.g. LDAP Source"""
|
||||
password_validate.send(sender=self, password=value, plan_context=self.plan.context)
|
||||
return value
|
||||
|
@ -9,6 +9,7 @@ from django.utils.translation import gettext as _
|
||||
from rest_framework.fields import BooleanField, CharField
|
||||
|
||||
from authentik.core.models import AuthenticatedSession, User
|
||||
from authentik.events.middleware import audit_ignore
|
||||
from authentik.flows.challenge import ChallengeResponse, ChallengeTypes, WithUserInfoChallenge
|
||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, PLAN_CONTEXT_SOURCE
|
||||
from authentik.flows.stage import ChallengeStageView
|
||||
@ -95,11 +96,14 @@ class UserLoginStageView(ChallengeStageView):
|
||||
self.logger.warning("User is not active, login will not work.")
|
||||
delta = self.set_session_duration(remember)
|
||||
self.set_session_ip()
|
||||
login(
|
||||
self.request,
|
||||
user,
|
||||
backend=backend,
|
||||
)
|
||||
# the `user_logged_in` signal will update the user to write the `last_login` field
|
||||
# which we don't want to log as we already have a dedicated login event
|
||||
with audit_ignore():
|
||||
login(
|
||||
self.request,
|
||||
user,
|
||||
backend=backend,
|
||||
)
|
||||
self.logger.debug(
|
||||
"Logged in",
|
||||
backend=backend,
|
||||
|
@ -0,0 +1,23 @@
|
||||
# Generated by Django 5.0.4 on 2024-05-01 15:32
|
||||
|
||||
import authentik.lib.utils.time
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_tenants", "0002_tenant_default_token_duration_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="tenant",
|
||||
name="default_token_duration",
|
||||
field=models.TextField(
|
||||
default="days=1",
|
||||
help_text="Default token duration",
|
||||
validators=[authentik.lib.utils.time.timedelta_string_validator],
|
||||
),
|
||||
),
|
||||
]
|
@ -23,7 +23,7 @@ LOGGER = get_logger()
|
||||
|
||||
VALID_SCHEMA_NAME = re.compile(r"^t_[a-z0-9]{1,61}$")
|
||||
|
||||
DEFAULT_TOKEN_DURATION = "minutes=30" # nosec
|
||||
DEFAULT_TOKEN_DURATION = "days=1" # nosec
|
||||
DEFAULT_TOKEN_LENGTH = 60
|
||||
|
||||
|
||||
|
@ -3,6 +3,7 @@
|
||||
from tenant_schemas_celery.scheduler import (
|
||||
TenantAwarePersistentScheduler as BaseTenantAwarePersistentScheduler,
|
||||
)
|
||||
from tenant_schemas_celery.scheduler import TenantAwareScheduleEntry
|
||||
|
||||
|
||||
class TenantAwarePersistentScheduler(BaseTenantAwarePersistentScheduler):
|
||||
@ -11,3 +12,11 @@ class TenantAwarePersistentScheduler(BaseTenantAwarePersistentScheduler):
|
||||
@classmethod
|
||||
def get_queryset(cls):
|
||||
return super().get_queryset().filter(ready=True)
|
||||
|
||||
def apply_entry(self, entry: TenantAwareScheduleEntry, producer=None):
|
||||
# https://github.com/maciej-gol/tenant-schemas-celery/blob/master/tenant_schemas_celery/scheduler.py#L85
|
||||
# When (as by default) no tenant schemas are set, the public schema is excluded
|
||||
# so we need to explicitly include it here, otherwise the task is not executed
|
||||
if entry.tenant_schemas is None:
|
||||
entry.tenant_schemas = self.get_queryset().values_list("schema_name", flat=True)
|
||||
return super().apply_entry(entry, producer)
|
||||
|
@ -16,7 +16,7 @@ entries:
|
||||
placeholder: Username
|
||||
placeholder_expression: false
|
||||
required: true
|
||||
type: text
|
||||
type: username
|
||||
field_key: username
|
||||
label: Username
|
||||
identifiers:
|
||||
|
@ -2,7 +2,7 @@
|
||||
"$schema": "http://json-schema.org/draft-07/schema",
|
||||
"$id": "https://goauthentik.io/blueprints/schema.json",
|
||||
"type": "object",
|
||||
"title": "authentik 2024.2.2 Blueprint schema",
|
||||
"title": "authentik 2024.4.3 Blueprint schema",
|
||||
"required": [
|
||||
"version",
|
||||
"entries"
|
||||
@ -4131,6 +4131,10 @@
|
||||
"http://www.w3.org/2001/04/xmldsig-more#rsa-sha256",
|
||||
"http://www.w3.org/2001/04/xmldsig-more#rsa-sha384",
|
||||
"http://www.w3.org/2001/04/xmldsig-more#rsa-sha512",
|
||||
"http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha1",
|
||||
"http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256",
|
||||
"http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha384",
|
||||
"http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha512",
|
||||
"http://www.w3.org/2000/09/xmldsig#dsa-sha1"
|
||||
],
|
||||
"title": "Signature algorithm"
|
||||
@ -4935,6 +4939,10 @@
|
||||
"http://www.w3.org/2001/04/xmldsig-more#rsa-sha256",
|
||||
"http://www.w3.org/2001/04/xmldsig-more#rsa-sha384",
|
||||
"http://www.w3.org/2001/04/xmldsig-more#rsa-sha512",
|
||||
"http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha1",
|
||||
"http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256",
|
||||
"http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha384",
|
||||
"http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha512",
|
||||
"http://www.w3.org/2000/09/xmldsig#dsa-sha1"
|
||||
],
|
||||
"title": "Signature algorithm"
|
||||
|
@ -11,7 +11,6 @@ entries:
|
||||
name: "authentik default LDAP Mapping: DN to User Path"
|
||||
object_field: "path"
|
||||
expression: |
|
||||
dn = ldap.get("distinguishedName")
|
||||
path_elements = []
|
||||
for pair in dn.split(","):
|
||||
attr, _, value = pair.partition("=")
|
||||
|
@ -32,7 +32,7 @@ services:
|
||||
volumes:
|
||||
- redis:/data
|
||||
server:
|
||||
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.2.2}
|
||||
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.4.3}
|
||||
restart: unless-stopped
|
||||
command: server
|
||||
environment:
|
||||
@ -53,7 +53,7 @@ services:
|
||||
- postgresql
|
||||
- redis
|
||||
worker:
|
||||
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.2.2}
|
||||
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.4.3}
|
||||
restart: unless-stopped
|
||||
command: worker
|
||||
environment:
|
||||
|
10
go.mod
10
go.mod
@ -1,15 +1,13 @@
|
||||
module goauthentik.io
|
||||
|
||||
go 1.22
|
||||
|
||||
toolchain go1.22.0
|
||||
go 1.22.2
|
||||
|
||||
require (
|
||||
beryju.io/ldap v0.1.0
|
||||
github.com/coreos/go-oidc v2.2.1+incompatible
|
||||
github.com/getsentry/sentry-go v0.27.0
|
||||
github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1
|
||||
github.com/go-ldap/ldap/v3 v3.4.7
|
||||
github.com/go-ldap/ldap/v3 v3.4.8
|
||||
github.com/go-openapi/runtime v0.28.0
|
||||
github.com/go-openapi/strfmt v0.23.0
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||
@ -30,7 +28,7 @@ require (
|
||||
github.com/spf13/cobra v1.8.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/wwt/guac v1.3.2
|
||||
goauthentik.io/api/v3 v3.2024022.11
|
||||
goauthentik.io/api/v3 v3.2024023.2
|
||||
golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab
|
||||
golang.org/x/oauth2 v0.19.0
|
||||
golang.org/x/sync v0.7.0
|
||||
@ -75,7 +73,7 @@ require (
|
||||
go.opentelemetry.io/otel/metric v1.24.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.24.0 // indirect
|
||||
golang.org/x/crypto v0.21.0 // indirect
|
||||
golang.org/x/net v0.22.0 // indirect
|
||||
golang.org/x/net v0.23.0 // indirect
|
||||
golang.org/x/sys v0.18.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
google.golang.org/protobuf v1.33.0 // indirect
|
||||
|
11
go.sum
11
go.sum
@ -84,8 +84,8 @@ github.com/go-http-utils/fresh v0.0.0-20161124030543-7231e26a4b27 h1:O6yi4xa9b2D
|
||||
github.com/go-http-utils/fresh v0.0.0-20161124030543-7231e26a4b27/go.mod h1:AYvN8omj7nKLmbcXS2dyABYU6JB1Lz1bHmkkq1kf4I4=
|
||||
github.com/go-http-utils/headers v0.0.0-20181008091004-fed159eddc2a h1:v6zMvHuY9yue4+QkG/HQ/W67wvtQmWJ4SDo9aK/GIno=
|
||||
github.com/go-http-utils/headers v0.0.0-20181008091004-fed159eddc2a/go.mod h1:I79BieaU4fxrw4LMXby6q5OS9XnoR9UIKLOzDFjUmuw=
|
||||
github.com/go-ldap/ldap/v3 v3.4.7 h1:3Hbd7mIB1qjd3Ra59fI3JYea/t5kykFu2CVHBca9koE=
|
||||
github.com/go-ldap/ldap/v3 v3.4.7/go.mod h1:qS3Sjlu76eHfHGpUdWkAXQTw4beih+cHsco2jXlIXrk=
|
||||
github.com/go-ldap/ldap/v3 v3.4.8 h1:loKJyspcRezt2Q3ZRMq2p/0v8iOurlmeXDPw6fikSvQ=
|
||||
github.com/go-ldap/ldap/v3 v3.4.8/go.mod h1:qS3Sjlu76eHfHGpUdWkAXQTw4beih+cHsco2jXlIXrk=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ=
|
||||
github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
@ -294,8 +294,8 @@ go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y
|
||||
go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU=
|
||||
go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A=
|
||||
go.uber.org/goleak v1.2.1/go.mod h1:qlT2yGI9QafXHhZZLxlSuNsMw3FFLxBr+tBRlmO1xH4=
|
||||
goauthentik.io/api/v3 v3.2024022.11 h1:MlsaBwyMM9NtDvZcoaWvuNznPHXA0a5olnDLyr24REA=
|
||||
goauthentik.io/api/v3 v3.2024022.11/go.mod h1:zz+mEZg8rY/7eEjkMGWJ2DnGqk+zqxuybGCGrR2O4Kw=
|
||||
goauthentik.io/api/v3 v3.2024023.2 h1:lSVaZAKTpsDhtw11wnkGjPalkDzv9H2VKEJllBi2aXs=
|
||||
goauthentik.io/api/v3 v3.2024023.2/go.mod h1:zz+mEZg8rY/7eEjkMGWJ2DnGqk+zqxuybGCGrR2O4Kw=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
@ -373,8 +373,9 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc=
|
||||
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
|
||||
golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs=
|
||||
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
|
@ -25,13 +25,14 @@ type Config struct {
|
||||
}
|
||||
|
||||
type RedisConfig struct {
|
||||
Host string `yaml:"host" env:"HOST, overwrite"`
|
||||
Port int `yaml:"port" env:"PORT, overwrite"`
|
||||
DB int `yaml:"db" env:"DB, overwrite"`
|
||||
Username string `yaml:"username" env:"USERNAME, overwrite"`
|
||||
Password string `yaml:"password" env:"PASSWORD, overwrite"`
|
||||
TLS bool `yaml:"tls" env:"TLS, overwrite"`
|
||||
TLSReqs string `yaml:"tls_reqs" env:"TLS_REQS, overwrite"`
|
||||
Host string `yaml:"host" env:"HOST, overwrite"`
|
||||
Port int `yaml:"port" env:"PORT, overwrite"`
|
||||
DB int `yaml:"db" env:"DB, overwrite"`
|
||||
Username string `yaml:"username" env:"USERNAME, overwrite"`
|
||||
Password string `yaml:"password" env:"PASSWORD, overwrite"`
|
||||
TLS bool `yaml:"tls" env:"TLS, overwrite"`
|
||||
TLSReqs string `yaml:"tls_reqs" env:"TLS_REQS, overwrite"`
|
||||
TLSCaCert *string `yaml:"tls_ca_certs" env:"TLS_CA_CERT, overwrite"`
|
||||
}
|
||||
|
||||
type ListenConfig struct {
|
||||
|
@ -29,4 +29,4 @@ func UserAgent() string {
|
||||
return fmt.Sprintf("authentik@%s", FullVersion())
|
||||
}
|
||||
|
||||
const VERSION = "2024.2.2"
|
||||
const VERSION = "2024.4.3"
|
||||
|
@ -2,6 +2,8 @@ package application
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
@ -19,6 +21,7 @@ import (
|
||||
"goauthentik.io/internal/outpost/proxyv2/codecs"
|
||||
"goauthentik.io/internal/outpost/proxyv2/constants"
|
||||
"goauthentik.io/internal/outpost/proxyv2/redisstore"
|
||||
"goauthentik.io/internal/utils"
|
||||
)
|
||||
|
||||
const RedisKeyPrefix = "authentik_proxy_session_"
|
||||
@ -31,11 +34,40 @@ func (a *Application) getStore(p api.ProxyOutpostConfig, externalHost *url.URL)
|
||||
maxAge = int(*t) + 1
|
||||
}
|
||||
if a.isEmbedded {
|
||||
var tls *tls.Config
|
||||
if config.Get().Redis.TLS {
|
||||
tls = utils.GetTLSConfig()
|
||||
switch strings.ToLower(config.Get().Redis.TLSReqs) {
|
||||
case "none":
|
||||
case "false":
|
||||
tls.InsecureSkipVerify = true
|
||||
case "required":
|
||||
break
|
||||
}
|
||||
ca := config.Get().Redis.TLSCaCert
|
||||
if ca != nil {
|
||||
// Get the SystemCertPool, continue with an empty pool on error
|
||||
rootCAs, _ := x509.SystemCertPool()
|
||||
if rootCAs == nil {
|
||||
rootCAs = x509.NewCertPool()
|
||||
}
|
||||
certs, err := os.ReadFile(*ca)
|
||||
if err != nil {
|
||||
a.log.WithError(err).Fatalf("Failed to append %s to RootCAs", *ca)
|
||||
}
|
||||
// Append our cert to the system pool
|
||||
if ok := rootCAs.AppendCertsFromPEM(certs); !ok {
|
||||
a.log.Println("No certs appended, using system certs only")
|
||||
}
|
||||
tls.RootCAs = rootCAs
|
||||
}
|
||||
}
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: fmt.Sprintf("%s:%d", config.Get().Redis.Host, config.Get().Redis.Port),
|
||||
Username: config.Get().Redis.Username,
|
||||
Password: config.Get().Redis.Password,
|
||||
DB: config.Get().Redis.DB,
|
||||
Addr: fmt.Sprintf("%s:%d", config.Get().Redis.Host, config.Get().Redis.Port),
|
||||
Username: config.Get().Redis.Username,
|
||||
Password: config.Get().Redis.Password,
|
||||
DB: config.Get().Redis.DB,
|
||||
TLSConfig: tls,
|
||||
})
|
||||
|
||||
// New default RedisStore
|
||||
|
@ -54,7 +54,7 @@ function cleanup {
|
||||
}
|
||||
|
||||
function prepare_debug {
|
||||
poetry install --no-ansi --no-interaction
|
||||
VIRTUAL_ENV=/ak-root/venv poetry install --no-ansi --no-interaction
|
||||
touch /unittest.xml
|
||||
chown authentik:authentik /unittest.xml
|
||||
}
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user