Compare commits

..

6 Commits

Author SHA1 Message Date
ad652bde38 more consistent nesting
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2024-08-24 20:26:40 +02:00
9e813bf404 refactor some more
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2024-08-24 18:09:41 +02:00
11e708a45a inconsistent naming
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2024-08-24 16:00:08 +02:00
1e6e4a0bbc refactor from self.executor.current_stage to make nesting easier
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2024-08-24 15:59:31 +02:00
2149e81d8f base
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2024-08-24 15:43:33 +02:00
98dc794597 unrelated
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2024-08-24 15:42:22 +02:00
181 changed files with 6558 additions and 16396 deletions

View File

@ -1,5 +1,5 @@
[bumpversion]
current_version = 2024.8.4
current_version = 2024.6.4
tag = True
commit = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))?

View File

@ -29,9 +29,9 @@ outputs:
imageTags:
description: "Docker image tags"
value: ${{ steps.ev.outputs.imageTags }}
attestImageNames:
description: "Docker image names used for attestation"
value: ${{ steps.ev.outputs.attestImageNames }}
imageNames:
description: "Docker image names"
value: ${{ steps.ev.outputs.imageNames }}
imageMainTag:
description: "Docker image main tag"
value: ${{ steps.ev.outputs.imageMainTag }}

View File

@ -51,24 +51,15 @@ else:
]
image_main_tag = image_tags[0].split(":")[-1]
def get_attest_image_names(image_with_tags: list[str]):
"""Attestation only for GHCR"""
image_tags = []
for image_name in set(name.split(":")[0] for name in image_with_tags):
if not image_name.startswith("ghcr.io"):
continue
image_tags.append(image_name)
return ",".join(set(image_tags))
image_tags_rendered = ",".join(image_tags)
image_names_rendered = ",".join(set(name.split(":")[0] for name in image_tags))
with open(os.environ["GITHUB_OUTPUT"], "a+", encoding="utf-8") as _output:
print(f"shouldBuild={should_build}", file=_output)
print(f"sha={sha}", file=_output)
print(f"version={version}", file=_output)
print(f"prerelease={prerelease}", file=_output)
print(f"imageTags={','.join(image_tags)}", file=_output)
print(f"attestImageNames={get_attest_image_names(image_tags)}", file=_output)
print(f"imageTags={image_tags_rendered}", file=_output)
print(f"imageNames={image_names_rendered}", file=_output)
print(f"imageMainTag={image_main_tag}", file=_output)
print(f"imageMainName={image_tags[0]}", file=_output)

View File

@ -261,7 +261,7 @@ jobs:
id: attest
if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
with:
subject-name: ${{ steps.ev.outputs.attestImageNames }}
subject-name: ${{ steps.ev.outputs.imageNames }}
subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true
pr-comment:

View File

@ -115,7 +115,7 @@ jobs:
id: attest
if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
with:
subject-name: ${{ steps.ev.outputs.attestImageNames }}
subject-name: ${{ steps.ev.outputs.imageNames }}
subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true
build-binary:

View File

@ -92,4 +92,4 @@ jobs:
run: make gen-client-ts
- name: test
working-directory: web/
run: npm run test || exit 0
run: npm run test

View File

@ -58,7 +58,7 @@ jobs:
- uses: actions/attest-build-provenance@v1
id: attest
with:
subject-name: ${{ steps.ev.outputs.attestImageNames }}
subject-name: ${{ steps.ev.outputs.imageNames }}
subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true
build-outpost:
@ -122,7 +122,7 @@ jobs:
- uses: actions/attest-build-provenance@v1
id: attest
with:
subject-name: ${{ steps.ev.outputs.attestImageNames }}
subject-name: ${{ steps.ev.outputs.imageNames }}
subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true
build-outpost-binary:

View File

@ -205,7 +205,7 @@ gen: gen-build gen-client-ts
web-build: web-install ## Build the Authentik UI
cd web && npm run build
web: web-lint-fix web-lint web-check-compile ## Automatically fix formatting issues in the Authentik UI source code, lint the code, and compile it
web: web-lint-fix web-lint web-check-compile web-test ## Automatically fix formatting issues in the Authentik UI source code, lint the code, and compile it
web-install: ## Install the necessary libraries to build the Authentik UI
cd web && npm ci

View File

@ -2,7 +2,7 @@
from os import environ
__version__ = "2024.8.4"
__version__ = "2024.6.4"
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"

View File

@ -51,11 +51,9 @@ class BlueprintInstanceSerializer(ModelSerializer):
context = self.instance.context if self.instance else {}
valid, logs = Importer.from_string(content, context).validate()
if not valid:
text_logs = "\n".join([x["event"] for x in logs])
raise ValidationError(
[
_("Failed to validate blueprint"),
*[f"- {x.event}" for x in logs],
]
_("Failed to validate blueprint: {logs}".format_map({"logs": text_logs}))
)
return content

View File

@ -78,5 +78,5 @@ class TestBlueprintsV1API(APITestCase):
self.assertEqual(res.status_code, 400)
self.assertJSONEqual(
res.content.decode(),
{"content": ["Failed to validate blueprint", "- Invalid blueprint version"]},
{"content": ["Failed to validate blueprint: Invalid blueprint version"]},
)

View File

@ -429,7 +429,7 @@ class Importer:
orig_import = deepcopy(self._import)
if self._import.version != 1:
self.logger.warning("Invalid blueprint version")
return False, [LogEvent("Invalid blueprint version", log_level="warning", logger=None)]
return False, [{"event": "Invalid blueprint version"}]
with (
transaction_rollback(),
capture_logs() as logs,

View File

@ -30,10 +30,8 @@ from authentik.core.api.utils import (
PassiveSerializer,
)
from authentik.core.expression.evaluator import PropertyMappingEvaluator
from authentik.core.expression.exceptions import PropertyMappingExpressionException
from authentik.core.models import Group, PropertyMapping, User
from authentik.events.utils import sanitize_item
from authentik.lib.utils.errors import exception_to_string
from authentik.policies.api.exec import PolicyTestSerializer
from authentik.rbac.decorators import permission_required
@ -164,15 +162,12 @@ class PropertyMappingViewSet(
response_data = {"successful": True, "result": ""}
try:
result = mapping.evaluate(dry_run=True, **context)
result = mapping.evaluate(**context)
response_data["result"] = dumps(
sanitize_item(result), indent=(4 if format_result else None)
)
except PropertyMappingExpressionException as exc:
response_data["result"] = exception_to_string(exc.exc)
response_data["successful"] = False
except Exception as exc:
response_data["result"] = exception_to_string(exc)
response_data["result"] = str(exc)
response_data["successful"] = False
response = PropertyMappingTestResultSerializer(response_data)
return Response(response.data)

View File

@ -678,13 +678,10 @@ class UserViewSet(UsedByMixin, ModelViewSet):
if not request.tenant.impersonation:
LOGGER.debug("User attempted to impersonate", user=request.user)
return Response(status=401)
user_to_be = self.get_object()
# Check both object-level perms and global perms
if not request.user.has_perm(
"authentik_core.impersonate", user_to_be
) and not request.user.has_perm("authentik_core.impersonate"):
if not request.user.has_perm("impersonate"):
LOGGER.debug("User attempted to impersonate without permissions", user=request.user)
return Response(status=401)
user_to_be = self.get_object()
if user_to_be.pk == self.request.user.pk:
LOGGER.debug("User attempted to impersonate themselves", user=request.user)
return Response(status=401)
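
The impersonation hunk above switches between a bare has_perm("impersonate") call and a two-step check that accepts either an object-level grant or the global model permission. A framework-free sketch of the two-step check, with a toy permission backend standing in for Django's user object (helper and class names are mine):

class DemoUser:
    """Toy permission backend: global perms plus per-object grants."""
    def __init__(self, global_perms: set[str], object_perms: set[tuple[str, object]]):
        self.global_perms = global_perms
        self.object_perms = object_perms

    def has_perm(self, perm: str, obj: object | None = None) -> bool:
        if obj is None:
            return perm in self.global_perms
        return (perm, obj) in self.object_perms

def can_impersonate(user, target) -> bool:
    # Object-level permission (django-guardian style) first, then the global model permission
    return user.has_perm("authentik_core.impersonate", target) or user.has_perm(
        "authentik_core.impersonate"
    )

target = "other-user"
scoped = DemoUser(set(), {("authentik_core.impersonate", target)})
global_only = DemoUser({"authentik_core.impersonate"}, set())
print(can_impersonate(scoped, target), can_impersonate(global_only, target))  # True True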

View File

@ -9,11 +9,10 @@ class Command(TenantCommand):
def add_arguments(self, parser):
parser.add_argument("--type", type=str, required=True)
parser.add_argument("--all", action="store_true", default=False)
parser.add_argument("usernames", nargs="*", type=str)
parser.add_argument("--all", action="store_true")
parser.add_argument("usernames", nargs="+", type=str)
def handle_per_tenant(self, **options):
print(options)
new_type = UserTypes(options["type"])
qs = (
User.objects.exclude_anonymous()
@ -23,9 +22,6 @@ class Command(TenantCommand):
if options["usernames"] and options["all"]:
self.stderr.write("--all and usernames specified, only one can be specified")
return
if not options["usernames"] and not options["all"]:
self.stderr.write("--all or usernames must be specified")
return
if options["usernames"] and not options["all"]:
qs = qs.filter(username__in=options["usernames"])
updated = qs.update(type=new_type)
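
The management-command hunk above toggles between requiring usernames (nargs="+") and accepting either --all or an explicit list (nargs="*" plus validation). A small plain-argparse sketch of the latter contract (a hypothetical standalone script, not the Django command):

import argparse
import sys

parser = argparse.ArgumentParser()
parser.add_argument("--type", type=str, required=True)
parser.add_argument("--all", action="store_true", default=False)
parser.add_argument("usernames", nargs="*", type=str)
opts = parser.parse_args(["--type", "internal", "alice", "bob"])

if opts.usernames and opts.all:
    sys.exit("--all and usernames specified, only one can be specified")
if not opts.usernames and not opts.all:
    sys.exit("--all or usernames must be specified")
print(f"would set type={opts.type} for {opts.usernames or 'all users'}")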

View File

@ -466,6 +466,8 @@ class ApplicationQuerySet(QuerySet):
def with_provider(self) -> "QuerySet[Application]":
qs = self.select_related("provider")
for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider):
if LOOKUP_SEP in subclass:
continue
qs = qs.select_related(f"provider__{subclass}")
return qs
@ -543,24 +545,15 @@ class Application(SerializerModel, PolicyBindingModel):
if not self.provider:
return None
candidates = []
base_class = Provider
for subclass in base_class.objects.get_queryset()._get_subclasses_recurse(base_class):
parent = self.provider
for level in subclass.split(LOOKUP_SEP):
try:
parent = getattr(parent, level)
except AttributeError:
break
if parent in candidates:
for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider):
# We don't care about recursion, skip nested models
if LOOKUP_SEP in subclass:
continue
idx = subclass.count(LOOKUP_SEP)
if type(parent) is not base_class:
idx += 1
candidates.insert(idx, parent)
if not candidates:
return None
return candidates[-1]
try:
return getattr(self.provider, subclass)
except AttributeError:
pass
return None
def __str__(self):
return str(self.name)
@ -908,7 +901,7 @@ class PropertyMapping(SerializerModel, ManagedModel):
except ControlFlowException as exc:
raise exc
except Exception as exc:
raise PropertyMappingExpressionException(exc, self) from exc
raise PropertyMappingExpressionException(self, exc) from exc
def __str__(self):
return f"Property Mapping {self.name}"

View File

@ -69,8 +69,8 @@ class MessageStage(StageView):
def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
"""Show a pre-configured message after the flow is done"""
message = getattr(self.executor.current_stage, "message", "")
level = getattr(self.executor.current_stage, "level", messages.SUCCESS)
message = getattr(self.current_stage, "message", "")
level = getattr(self.current_stage, "level", messages.SUCCESS)
messages.add_message(
self.request,
level,
@ -486,9 +486,7 @@ class GroupUpdateStage(StageView):
def handle_groups(self) -> bool:
self.source: Source = self.executor.plan.context[PLAN_CONTEXT_SOURCE]
self.user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER]
self.group_connection_type: GroupSourceConnection = (
self.executor.current_stage.group_connection_type
)
self.group_connection_type: GroupSourceConnection = self.current_stage.group_connection_type
raw_groups: dict[str, dict[str, Any | dict[str, Any]]] = self.executor.plan.context[
PLAN_CONTEXT_SOURCE_GROUPS

View File

@ -9,12 +9,9 @@ from rest_framework.test import APITestCase
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.lib.generators import generate_id
from authentik.policies.dummy.models import DummyPolicy
from authentik.policies.models import PolicyBinding
from authentik.providers.oauth2.models import OAuth2Provider
from authentik.providers.proxy.models import ProxyProvider
from authentik.providers.saml.models import SAMLProvider
class TestApplicationsAPI(APITestCase):
@ -225,31 +222,3 @@ class TestApplicationsAPI(APITestCase):
],
},
)
def test_get_provider(self):
"""Ensure that proxy providers (at the time of writing that is the only provider
that inherits from another proxy type (OAuth) instead of inheriting from the root
provider class) is correctly looked up and selected from the database"""
slug = generate_id()
provider = ProxyProvider.objects.create(name=generate_id())
Application.objects.create(
name=generate_id(),
slug=slug,
provider=provider,
)
self.assertEqual(Application.objects.get(slug=slug).get_provider(), provider)
self.assertEqual(
Application.objects.with_provider().get(slug=slug).get_provider(), provider
)
slug = generate_id()
provider = SAMLProvider.objects.create(name=generate_id())
Application.objects.create(
name=generate_id(),
slug=slug,
provider=provider,
)
self.assertEqual(Application.objects.get(slug=slug).get_provider(), provider)
self.assertEqual(
Application.objects.with_provider().get(slug=slug).get_provider(), provider
)

View File

@ -3,10 +3,10 @@
from json import loads
from django.urls import reverse
from guardian.shortcuts import assign_perm
from rest_framework.test import APITestCase
from authentik.core.tests.utils import create_test_admin_user, create_test_user
from authentik.core.models import User
from authentik.core.tests.utils import create_test_admin_user
from authentik.tenants.utils import get_current_tenant
@ -15,7 +15,7 @@ class TestImpersonation(APITestCase):
def setUp(self) -> None:
super().setUp()
self.other_user = create_test_user()
self.other_user = User.objects.create(username="to-impersonate")
self.user = create_test_admin_user()
def test_impersonate_simple(self):
@ -44,46 +44,6 @@ class TestImpersonation(APITestCase):
self.assertEqual(response_body["user"]["username"], self.user.username)
self.assertNotIn("original", response_body)
def test_impersonate_global(self):
"""Test impersonation with global permissions"""
new_user = create_test_user()
assign_perm("authentik_core.impersonate", new_user)
assign_perm("authentik_core.view_user", new_user)
self.client.force_login(new_user)
response = self.client.post(
reverse(
"authentik_api:user-impersonate",
kwargs={"pk": self.other_user.pk},
)
)
self.assertEqual(response.status_code, 201)
response = self.client.get(reverse("authentik_api:user-me"))
response_body = loads(response.content.decode())
self.assertEqual(response_body["user"]["username"], self.other_user.username)
self.assertEqual(response_body["original"]["username"], new_user.username)
def test_impersonate_scoped(self):
"""Test impersonation with scoped permissions"""
new_user = create_test_user()
assign_perm("authentik_core.impersonate", new_user, self.other_user)
assign_perm("authentik_core.view_user", new_user, self.other_user)
self.client.force_login(new_user)
response = self.client.post(
reverse(
"authentik_api:user-impersonate",
kwargs={"pk": self.other_user.pk},
)
)
self.assertEqual(response.status_code, 201)
response = self.client.get(reverse("authentik_api:user-me"))
response_body = loads(response.content.decode())
self.assertEqual(response_body["user"]["username"], self.other_user.username)
self.assertEqual(response_body["original"]["username"], new_user.username)
def test_impersonate_denied(self):
"""test impersonation without permissions"""
self.client.force_login(self.other_user)

View File

@ -18,7 +18,7 @@ from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import ModelSerializer, PassiveSerializer
from authentik.core.models import User, UserTypes
from authentik.enterprise.license import LicenseKey, LicenseSummarySerializer
from authentik.enterprise.models import License
from authentik.enterprise.models import License, LicenseUsageStatus
from authentik.rbac.decorators import permission_required
from authentik.tenants.utils import get_unique_identifier
@ -29,7 +29,7 @@ class EnterpriseRequiredMixin:
def validate(self, attrs: dict) -> dict:
"""Check that a valid license exists"""
if not LicenseKey.cached_summary().status.is_valid:
if LicenseKey.cached_summary().status != LicenseUsageStatus.UNLICENSED:
raise ValidationError(_("Enterprise is required to create/update this object."))
return super().validate(attrs)

View File

@ -25,4 +25,4 @@ class AuthentikEnterpriseConfig(EnterpriseConfig):
"""Actual enterprise check, cached"""
from authentik.enterprise.license import LicenseKey
return LicenseKey.cached_summary().status.is_valid
return LicenseKey.cached_summary().status

View File

@ -117,13 +117,10 @@ class LicenseKey:
our_cert.public_key(),
algorithms=["ES512"],
audience=get_license_aud(),
options={"verify_exp": check_expiry, "verify_signature": check_expiry},
options={"verify_exp": check_expiry},
),
)
except PyJWTError:
unverified = decode(jwt, options={"verify_signature": False})
if unverified["aud"] != get_license_aud():
raise ValidationError("Invalid Install ID in license") from None
raise ValidationError("Unable to verify license") from None
return body
@ -137,7 +134,7 @@ class LicenseKey:
exp_ts = int(mktime(lic.expiry.timetuple()))
if total.exp == 0:
total.exp = exp_ts
total.exp = max(total.exp, exp_ts)
total.exp = min(total.exp, exp_ts)
total.license_flags.extend(lic.status.license_flags)
return total

View File

@ -17,7 +17,7 @@ from authentik.flows.challenge import RedirectChallenge
from authentik.flows.exceptions import FlowNonApplicableException
from authentik.flows.models import in_memory_stage
from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, FlowPlanner
from authentik.flows.stage import RedirectStage
from authentik.flows.stage import RedirectStageChallengeView
from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.lib.utils.time import timedelta_from_string
from authentik.lib.utils.urls import redirect_with_qs
@ -83,7 +83,7 @@ class RACInterface(InterfaceView):
return super().get_context_data(**kwargs)
class RACFinalStage(RedirectStage):
class RACFinalStage(RedirectStageChallengeView):
"""RAC Connection final stage, set the connection token in the stage"""
endpoint: Endpoint
@ -91,9 +91,9 @@ class RACFinalStage(RedirectStage):
application: Application
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
self.endpoint = self.executor.current_stage.endpoint
self.provider = self.executor.current_stage.provider
self.application = self.executor.current_stage.application
self.endpoint = self.current_stage.endpoint
self.provider = self.current_stage.provider
self.application = self.current_stage.application
# Check policies bound to endpoint directly
engine = PolicyEngine(self.endpoint, self.request.user, self.request)
engine.use_cache = False
@ -132,7 +132,7 @@ class RACFinalStage(RedirectStage):
flow=self.executor.plan.flow_pk,
endpoint=self.endpoint.name,
).from_http(self.request)
self.executor.current_stage.destination = self.request.build_absolute_uri(
self.current_stage.destination = self.request.build_absolute_uri(
reverse("authentik_providers_rac:if-rac", kwargs={"token": str(token.token)})
)
return super().get_challenge(*args, **kwargs)

View File

@ -3,7 +3,7 @@
from datetime import datetime
from django.core.cache import cache
from django.db.models.signals import post_delete, post_save, pre_save
from django.db.models.signals import post_save, pre_save
from django.dispatch import receiver
from django.utils.timezone import get_current_timezone
@ -27,9 +27,3 @@ def post_save_license(sender: type[License], instance: License, **_):
"""Trigger license usage calculation when license is saved"""
cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)
enterprise_update_usage.delay()
@receiver(post_delete, sender=License)
def post_delete_license(sender: type[License], instance: License, **_):
"""Clear license cache when license is deleted"""
cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)

View File

@ -21,16 +21,15 @@ from authentik.lib.utils.time import timedelta_from_string
PLAN_CONTEXT_RESUME_TOKEN = "resume_token" # nosec
class SourceStageView(ChallengeStageView):
class SourceStageView(ChallengeStageView[SourceStage]):
"""Suspend the current flow execution and send the user to a source,
after which this flow execution is resumed."""
login_button: UILoginButton
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
current_stage: SourceStage = self.executor.current_stage
source: Source = (
Source.objects.filter(pk=current_stage.source_id).select_subclasses().first()
Source.objects.filter(pk=self.current_stage.source_id).select_subclasses().first()
)
if not source:
self.logger.warning("Source does not exist")
@ -56,11 +55,10 @@ class SourceStageView(ChallengeStageView):
pending_user: User = self.get_pending_user()
if pending_user.is_anonymous or not pending_user.pk:
pending_user = get_anonymous_user()
current_stage: SourceStage = self.executor.current_stage
identifier = slugify(f"ak-source-stage-{current_stage.name}-{str(uuid4())}")
identifier = slugify(f"ak-source-stage-{self.current_stage.name}-{str(uuid4())}")
# Don't check for validity here, we only care if the token exists
tokens = FlowToken.objects.filter(identifier=identifier)
valid_delta = timedelta_from_string(current_stage.resume_timeout)
valid_delta = timedelta_from_string(self.current_stage.resume_timeout)
if not tokens.exists():
return FlowToken.objects.create(
expires=now() + valid_delta,

View File

@ -69,5 +69,8 @@ class NotificationViewSet(
@action(detail=False, methods=["post"])
def mark_all_seen(self, request: Request) -> Response:
"""Mark all the user's notifications as seen"""
Notification.objects.filter(user=request.user, seen=False).update(seen=True)
notifications = Notification.objects.filter(user=request.user)
for notification in notifications:
notification.seen = True
Notification.objects.bulk_update(notifications, ["seen"])
return Response({}, status=204)

View File

@ -49,7 +49,6 @@ from authentik.policies.models import PolicyBindingModel
from authentik.root.middleware import ClientIPMiddleware
from authentik.stages.email.utils import TemplateEmailMessage
from authentik.tenants.models import Tenant
from authentik.tenants.utils import get_current_tenant
LOGGER = get_logger()
DISCORD_FIELD_LIMIT = 25
@ -59,11 +58,7 @@ NOTIFICATION_SUMMARY_LENGTH = 75
def default_event_duration():
"""Default duration an Event is saved.
This is used as a fallback when no brand is available"""
try:
tenant = get_current_tenant()
return now() + timedelta_from_string(tenant.event_retention)
except Tenant.DoesNotExist:
return now() + timedelta(days=365)
return now() + timedelta(days=365)
def default_brand():
@ -250,6 +245,12 @@ class Event(SerializerModel, ExpiringModel):
if QS_QUERY in self.context["http_request"]["args"]:
wrapped = self.context["http_request"]["args"][QS_QUERY]
self.context["http_request"]["args"] = cleanse_dict(QueryDict(wrapped))
if hasattr(request, "tenant"):
tenant: Tenant = request.tenant
# Because self.created only gets set on save, we can't use its value here
# hence we set self.created to now and then use it
self.created = now()
self.expires = self.created + timedelta_from_string(tenant.event_retention)
if hasattr(request, "brand"):
brand: Brand = request.brand
self.brand = sanitize_dict(model_to_dict(brand))

View File

@ -6,7 +6,6 @@ from django.db.models import Model
from django.test import TestCase
from authentik.core.models import default_token_key
from authentik.events.models import default_event_duration
from authentik.lib.utils.reflection import get_apps
@ -21,7 +20,7 @@ def model_tester_factory(test_model: type[Model]) -> Callable:
allowed = 0
# Token-like objects need to lookup the current tenant to get the default token length
for field in test_model._meta.fields:
if field.default in [default_token_key, default_event_duration]:
if field.default == default_token_key:
allowed += 1
with self.assertNumQueries(allowed):
str(test_model())

View File

@ -2,8 +2,7 @@
from unittest.mock import MagicMock, patch
from django.urls import reverse
from rest_framework.test import APITestCase
from django.test import TestCase
from authentik.core.models import Group, User
from authentik.events.models import (
@ -11,7 +10,6 @@ from authentik.events.models import (
EventAction,
Notification,
NotificationRule,
NotificationSeverity,
NotificationTransport,
NotificationWebhookMapping,
TransportMode,
@ -22,7 +20,7 @@ from authentik.policies.exceptions import PolicyException
from authentik.policies.models import PolicyBinding
class TestEventsNotifications(APITestCase):
class TestEventsNotifications(TestCase):
"""Test Event Notifications"""
def setUp(self) -> None:
@ -133,15 +131,3 @@ class TestEventsNotifications(APITestCase):
Notification.objects.all().delete()
Event.new(EventAction.CUSTOM_PREFIX).save()
self.assertEqual(Notification.objects.first().body, "foo")
def test_api_mark_all_seen(self):
"""Test mark_all_seen"""
self.client.force_login(self.user)
Notification.objects.create(
severity=NotificationSeverity.NOTICE, body="foo", user=self.user, seen=False
)
response = self.client.post(reverse("authentik_api:notification-mark-all-seen"))
self.assertEqual(response.status_code, 204)
self.assertFalse(Notification.objects.filter(body="foo", seen=False).exists())

View File

@ -74,9 +74,9 @@ class FlowPlan:
def redirect(self, destination: str):
"""Insert a redirect stage as next stage"""
from authentik.flows.stage import RedirectStage
from authentik.flows.stage import RedirectStageChallengeView
self.insert_stage(in_memory_stage(RedirectStage, destination=destination))
self.insert_stage(in_memory_stage(RedirectStageChallengeView, destination=destination))
def next(self, http_request: HttpRequest | None) -> FlowStageBinding | None:
"""Return next pending stage from the bottom of the list"""

View File

@ -30,6 +30,7 @@ from authentik.lib.avatars import DEFAULT_AVATAR, get_avatar
from authentik.lib.utils.reflection import class_to_path
if TYPE_CHECKING:
from authentik.flows.models import Stage
from authentik.flows.views.executor import FlowExecutorView
PLAN_CONTEXT_PENDING_USER_IDENTIFIER = "pending_user_identifier"
@ -40,20 +41,21 @@ HIST_FLOWS_STAGE_TIME = Histogram(
)
class StageView(View):
class StageView[TStage: "Stage"](View):
"""Abstract Stage"""
executor: "FlowExecutorView"
current_stage: TStage
request: HttpRequest = None
logger: BoundLogger
def __init__(self, executor: "FlowExecutorView", **kwargs):
def __init__(self, executor: "FlowExecutorView", current_stage: TStage | None = None, **kwargs):
self.executor = executor
current_stage = getattr(self.executor, "current_stage", None)
self.current_stage = current_stage or executor.current_stage
self.logger = get_logger().bind(
stage=getattr(current_stage, "name", None),
stage=getattr(self.current_stage, "name", None),
stage_view=class_to_path(type(self)),
)
super().__init__(**kwargs)
@ -80,7 +82,7 @@ class StageView(View):
"""Cleanup session"""
class ChallengeStageView(StageView):
class ChallengeStageView[TStage: "Stage"](StageView[TStage]):
"""Stage view which response with a challenge"""
response_class = ChallengeResponse
@ -253,12 +255,12 @@ class AccessDeniedChallengeView(ChallengeStageView):
return self.executor.cancel()
class RedirectStage(ChallengeStageView):
class RedirectStageChallengeView(ChallengeStageView):
"""Redirect to any URL"""
def get_challenge(self, *args, **kwargs) -> RedirectChallenge:
destination = getattr(
self.executor.current_stage, "destination", reverse("authentik_core:root-redirect")
self.current_stage, "destination", reverse("authentik_core:root-redirect")
)
return RedirectChallenge(
data={
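
The StageView hunk above introduces a PEP 695 type parameter so current_stage is typed per stage view. A framework-free sketch of that pattern under the same assumption (stand-in classes, Python 3.12+, not authentik's Stage/FlowExecutorView):

class Stage:
    name = "base"

class MessageStageModel(Stage):
    message = "pre-configured message"

class StageView[TStage: Stage]:
    def __init__(self, current_stage: TStage) -> None:
        self.current_stage = current_stage

class MessageStageView(StageView[MessageStageModel]):
    def render(self) -> str:
        # Type checkers now know current_stage carries `.message`
        return self.current_stage.message

print(MessageStageView(MessageStageModel()).render())  # pre-configured message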

View File

@ -2,6 +2,7 @@
import re
import socket
from collections.abc import Iterable
from ipaddress import ip_address, ip_network
from textwrap import indent
from types import CodeType
@ -27,12 +28,6 @@ from authentik.stages.authenticator import devices_for_user
LOGGER = get_logger()
ARG_SANITIZE = re.compile(r"[:.-]")
def sanitize_arg(arg_name: str) -> str:
return re.sub(ARG_SANITIZE, "_", arg_name)
class BaseEvaluator:
"""Validate and evaluate python-based expressions"""
@ -182,9 +177,9 @@ class BaseEvaluator:
proc = PolicyProcess(PolicyBinding(policy=policy), request=req, connection=None)
return proc.profiling_wrapper()
def wrap_expression(self, expression: str) -> str:
def wrap_expression(self, expression: str, params: Iterable[str]) -> str:
"""Wrap expression in a function, call it, and save the result as `result`"""
handler_signature = ",".join(sanitize_arg(x) for x in self._context.keys())
handler_signature = ",".join(params)
full_expression = ""
full_expression += f"def handler({handler_signature}):\n"
full_expression += indent(expression, " ")
@ -193,8 +188,8 @@ class BaseEvaluator:
def compile(self, expression: str) -> CodeType:
"""Parse expression. Raises SyntaxError or ValueError if the syntax is incorrect."""
expression = self.wrap_expression(expression)
return compile(expression, self._filename, "exec")
param_keys = self._context.keys()
return compile(self.wrap_expression(expression, param_keys), self._filename, "exec")
def evaluate(self, expression_source: str) -> Any:
"""Parse and evaluate expression. If the syntax is incorrect, a SyntaxError is raised.
@ -210,7 +205,7 @@ class BaseEvaluator:
self.handle_error(exc, expression_source)
raise exc
try:
_locals = {sanitize_arg(x): y for x, y in self._context.items()}
_locals = self._context
# Yes this is an exec, yes it is potentially bad. Since we limit what variables are
# available here, and these policies can only be edited by admins, this is a risk
# we're willing to take.
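
The evaluator hunk above moves between passing raw context keys and sanitizing them into valid parameter names before wrapping the expression in a generated handler(). A runnable, simplified sketch of the sanitize-and-wrap approach (not authentik's BaseEvaluator):

import re
from textwrap import indent

ARG_SANITIZE = re.compile(r"[:.-]")

def sanitize_arg(arg_name: str) -> str:
    return ARG_SANITIZE.sub("_", arg_name)

def evaluate(expression: str, context: dict[str, object]) -> object:
    params = ",".join(sanitize_arg(key) for key in context)
    source = f"def handler({params}):\n" + indent(expression, "    ")
    namespace: dict[str, object] = {}
    exec(compile(source, "<expression>", "exec"), namespace)  # nosec - illustration only
    return namespace["handler"](*context.values())

print(evaluate("return request_id.upper()", {"request-id": "abc"}))  # ABC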

View File

@ -30,11 +30,6 @@ class TestHTTP(TestCase):
request = self.factory.get("/", HTTP_X_FORWARDED_FOR="127.0.0.2")
self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.2")
def test_forward_for_invalid(self):
"""Test invalid forward for"""
request = self.factory.get("/", HTTP_X_FORWARDED_FOR="foobar")
self.assertEqual(ClientIPMiddleware.get_client_ip(request), ClientIPMiddleware.default_ip)
def test_fake_outpost(self):
"""Test faked IP which is overridden by an outpost"""
token = Token.objects.create(
@ -58,17 +53,6 @@ class TestHTTP(TestCase):
},
)
self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.1")
# Invalid, not a real IP
self.user.type = UserTypes.INTERNAL_SERVICE_ACCOUNT
self.user.save()
request = self.factory.get(
"/",
**{
ClientIPMiddleware.outpost_remote_ip_header: "foobar",
ClientIPMiddleware.outpost_token_header: token.key,
},
)
self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.1")
# Valid
self.user.type = UserTypes.INTERNAL_SERVICE_ACCOUNT
self.user.save()

View File

@ -21,14 +21,7 @@ class DebugSession(Session):
def send(self, req: PreparedRequest, *args, **kwargs):
request_id = str(uuid4())
LOGGER.debug(
"HTTP request sent",
uid=request_id,
url=req.url,
method=req.method,
headers=req.headers,
body=req.body,
)
LOGGER.debug("HTTP request sent", uid=request_id, path=req.path_url, headers=req.headers)
resp = super().send(req, *args, **kwargs)
LOGGER.debug(
"HTTP response received",

View File

@ -108,7 +108,7 @@ class EventMatcherPolicy(Policy):
result=result,
)
matches.append(result)
passing = all(x.passing for x in matches)
passing = any(x.passing for x in matches)
messages = chain(*[x.messages for x in matches])
result = PolicyResult(passing, *messages)
result.source_results = matches
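
The one-line change above flips how per-attribute match results are aggregated. A tiny illustration of the difference, with arbitrary booleans:

matches = [True, False, True]
print(all(matches))  # False - every configured attribute must match
print(any(matches))  # True  - a single matching attribute is enough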

View File

@ -77,24 +77,11 @@ class TestEventMatcherPolicy(TestCase):
request = PolicyRequest(get_anonymous_user())
request.context["event"] = event
policy: EventMatcherPolicy = EventMatcherPolicy.objects.create(
client_ip="1.2.3.5", app="foo"
client_ip="1.2.3.5", app="bar"
)
response = policy.passes(request)
self.assertFalse(response.passing)
def test_multiple(self):
"""Test multiple"""
event = Event.new(EventAction.LOGIN)
event.app = "foo"
event.client_ip = "1.2.3.4"
request = PolicyRequest(get_anonymous_user())
request.context["event"] = event
policy: EventMatcherPolicy = EventMatcherPolicy.objects.create(
client_ip="1.2.3.4", app="foo"
)
response = policy.passes(request)
self.assertTrue(response.passing)
def test_invalid(self):
"""Test passing event"""
request = PolicyRequest(get_anonymous_user())

View File

@ -4,13 +4,13 @@ from django.apps.registry import Apps
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db import migrations
from django.contrib.auth.management import create_permissions
def migrate_search_group(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
from guardian.shortcuts import assign_perm
from authentik.core.models import User
from django.apps import apps as real_apps
from django.contrib.auth.management import create_permissions
from guardian.shortcuts import UserObjectPermission
db_alias = schema_editor.connection.alias
@ -20,25 +20,14 @@ def migrate_search_group(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
create_permissions(real_apps.get_app_config("authentik_providers_ldap"), using=db_alias)
LDAPProvider = apps.get_model("authentik_providers_ldap", "ldapprovider")
Permission = apps.get_model("auth", "Permission")
UserObjectPermission = apps.get_model("guardian", "UserObjectPermission")
ContentType = apps.get_model("contenttypes", "ContentType")
new_prem = Permission.objects.using(db_alias).get(codename="search_full_directory")
ct = ContentType.objects.using(db_alias).get(
app_label="authentik_providers_ldap",
model="ldapprovider",
)
for provider in LDAPProvider.objects.using(db_alias).all():
if not provider.search_group:
continue
for user in provider.search_group.users.using(db_alias).all():
UserObjectPermission.objects.using(db_alias).create(
user=user,
permission=new_prem,
object_pk=provider.pk,
content_type=ct,
for user_pk in (
provider.search_group.users.using(db_alias).all().values_list("pk", flat=True)
):
# We need the correct user model instance to assign the permission
assign_perm(
"search_full_directory", User.objects.using(db_alias).get(pk=user_pk), provider
)
@ -46,7 +35,6 @@ class Migration(migrations.Migration):
dependencies = [
("authentik_providers_ldap", "0003_ldapprovider_mfa_support_and_more"),
("guardian", "0002_generic_permissions_index"),
]
operations = [

View File

@ -29,6 +29,7 @@ class TesOAuth2Introspection(OAuthTestCase):
self.app = Application.objects.create(
name=generate_id(), slug=generate_id(), provider=self.provider
)
self.app.save()
self.user = create_test_admin_user()
self.auth = b64encode(
f"{self.provider.client_id}:{self.provider.client_secret}".encode()
@ -113,41 +114,6 @@ class TesOAuth2Introspection(OAuthTestCase):
},
)
def test_introspect_invalid_provider(self):
"""Test introspection (mismatched provider and token)"""
provider: OAuth2Provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
redirect_uris="",
signing_key=create_test_cert(),
)
auth = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
token: AccessToken = AccessToken.objects.create(
provider=self.provider,
user=self.user,
token=generate_id(),
auth_time=timezone.now(),
_scope="openid user profile",
_id_token=json.dumps(
asdict(
IDToken("foo", "bar"),
)
),
)
res = self.client.post(
reverse("authentik_providers_oauth2:token-introspection"),
HTTP_AUTHORIZATION=f"Basic {auth}",
data={"token": token.token},
)
self.assertEqual(res.status_code, 200)
self.assertJSONEqual(
res.content.decode(),
{
"active": False,
},
)
def test_introspect_invalid_auth(self):
"""Test introspect (invalid auth)"""
res = self.client.post(

View File

@ -46,10 +46,10 @@ class TokenIntrospectionParams:
if not provider:
raise TokenIntrospectionError
access_token = AccessToken.objects.filter(token=raw_token, provider=provider).first()
access_token = AccessToken.objects.filter(token=raw_token).first()
if access_token:
return TokenIntrospectionParams(access_token, provider)
refresh_token = RefreshToken.objects.filter(token=raw_token, provider=provider).first()
refresh_token = RefreshToken.objects.filter(token=raw_token).first()
if refresh_token:
return TokenIntrospectionParams(refresh_token, provider)
LOGGER.debug("Token does not exist", token=raw_token)

View File

@ -433,20 +433,20 @@ class TokenParams:
app = Application.objects.filter(provider=self.provider).first()
if not app or not app.provider:
raise TokenError("invalid_grant")
with audit_ignore():
self.user, _ = User.objects.update_or_create(
# trim username to ensure the entire username is max 150 chars
# (22 chars being the length of the "template")
username=f"ak-{self.provider.name[:150-22]}-client_credentials",
defaults={
"last_login": timezone.now(),
"name": f"Autogenerated user from application {app.name} (client credentials)",
"path": f"{USER_PATH_SYSTEM_PREFIX}/apps/{app.slug}",
"type": UserTypes.SERVICE_ACCOUNT,
self.user, _ = User.objects.update_or_create(
# trim username to ensure the entire username is max 150 chars
# (22 chars being the length of the "template")
username=f"ak-{self.provider.name[:150-22]}-client_credentials",
defaults={
"attributes": {
USER_ATTRIBUTE_GENERATED: True,
},
)
self.user.attributes[USER_ATTRIBUTE_GENERATED] = True
self.user.save()
"last_login": timezone.now(),
"name": f"Autogenerated user from application {app.name} (client credentials)",
"path": f"{USER_PATH_SYSTEM_PREFIX}/apps/{app.slug}",
"type": UserTypes.SERVICE_ACCOUNT,
},
)
self.__check_policy_access(app, request)
Event.new(
@ -470,6 +470,9 @@ class TokenParams:
self.user, created = User.objects.update_or_create(
username=f"{self.provider.name}-{token.get('sub')}",
defaults={
"attributes": {
USER_ATTRIBUTE_GENERATED: True,
},
"last_login": timezone.now(),
"name": (
f"Autogenerated user from application {app.name} (client credentials JWT)"
@ -478,8 +481,6 @@ class TokenParams:
"type": UserTypes.SERVICE_ACCOUNT,
},
)
self.user.attributes[USER_ATTRIBUTE_GENERATED] = True
self.user.save()
exp = token.get("exp")
if created and exp:
self.user.attributes[USER_ATTRIBUTE_EXPIRES] = exp

View File

@ -28,7 +28,7 @@ class ProxyDockerController(DockerController):
labels = super()._get_labels()
labels["traefik.enable"] = "true"
labels[f"traefik.http.routers.{traefik_name}-router.rule"] = (
f"({' || '.join([f'Host({host})' for host in hosts])})"
f"({' || '.join([f'Host(`{host}`)' for host in hosts])})"
f" && PathPrefix(`/outpost.goauthentik.io`)"
)
labels[f"traefik.http.routers.{traefik_name}-router.tls"] = "true"

View File

@ -50,7 +50,6 @@ class AssertionProcessor:
_issue_instant: str
_assertion_id: str
_response_id: str
_valid_not_before: str
_session_not_on_or_after: str
@ -63,7 +62,6 @@ class AssertionProcessor:
self._issue_instant = get_time_string()
self._assertion_id = get_random_id()
self._response_id = get_random_id()
self._valid_not_before = get_time_string(
timedelta_from_string(self.provider.assertion_valid_not_before)
@ -132,9 +130,7 @@ class AssertionProcessor:
"""Generate AuthnStatement with AuthnContext and ContextClassRef Elements."""
auth_n_statement = Element(f"{{{NS_SAML_ASSERTION}}}AuthnStatement")
auth_n_statement.attrib["AuthnInstant"] = self._valid_not_before
auth_n_statement.attrib["SessionIndex"] = sha256(
self.http_request.session.session_key.encode("ascii")
).hexdigest()
auth_n_statement.attrib["SessionIndex"] = self._assertion_id
auth_n_statement.attrib["SessionNotOnOrAfter"] = self._session_not_on_or_after
auth_n_context = SubElement(auth_n_statement, f"{{{NS_SAML_ASSERTION}}}AuthnContext")
@ -289,7 +285,7 @@ class AssertionProcessor:
response.attrib["Version"] = "2.0"
response.attrib["IssueInstant"] = self._issue_instant
response.attrib["Destination"] = self.provider.acs_url
response.attrib["ID"] = self._response_id
response.attrib["ID"] = get_random_id()
if self.auth_n_request.id:
response.attrib["InResponseTo"] = self.auth_n_request.id
@ -312,7 +308,7 @@ class AssertionProcessor:
ref = xmlsec.template.add_reference(
signature_node,
digest_algorithm_transform,
uri="#" + element.attrib["ID"],
uri="#" + self._assertion_id,
)
xmlsec.template.add_transform(ref, xmlsec.constants.TransformEnveloped)
xmlsec.template.add_transform(ref, xmlsec.constants.TransformExclC14N)

View File

@ -180,10 +180,6 @@ class TestAuthNRequest(TestCase):
# Now create a response and convert it to string (provider)
response_proc = AssertionProcessor(self.provider, http_request, parsed_request)
response = response_proc.build_response()
# Ensure both response and assertion ID are in the response twice (once as ID attribute,
# once as ds:Reference URI)
self.assertEqual(response.count(response_proc._assertion_id), 2)
self.assertEqual(response.count(response_proc._response_id), 2)
# Now parse the response (source)
http_request.POST = QueryDict(mutable=True)

View File

@ -54,11 +54,7 @@ class TestServiceProviderMetadataParser(TestCase):
request = self.factory.get("/")
metadata = lxml_from_string(MetadataProcessor(provider, request).build_entity_descriptor())
schema = etree.XMLSchema(
etree.parse(
source="schemas/saml-schema-metadata-2.0.xsd", parser=etree.XMLParser()
) # nosec
)
schema = etree.XMLSchema(etree.parse("schemas/saml-schema-metadata-2.0.xsd")) # nosec
self.assertTrue(schema.validate(metadata))
def test_schema_want_authn_requests_signed(self):

View File

@ -47,9 +47,7 @@ class TestSchema(TestCase):
metadata = lxml_from_string(request)
schema = etree.XMLSchema(
etree.parse("schemas/saml-schema-protocol-2.0.xsd", parser=etree.XMLParser()) # nosec
)
schema = etree.XMLSchema(etree.parse("schemas/saml-schema-protocol-2.0.xsd")) # nosec
self.assertTrue(schema.validate(metadata))
def test_response_schema(self):
@ -70,7 +68,5 @@ class TestSchema(TestCase):
metadata = lxml_from_string(response)
schema = etree.XMLSchema(
etree.parse("schemas/saml-schema-protocol-2.0.xsd", parser=etree.XMLParser()) # nosec
)
schema = etree.XMLSchema(etree.parse("schemas/saml-schema-protocol-2.0.xsd")) # nosec
self.assertTrue(schema.validate(metadata))

View File

@ -2,10 +2,9 @@
from itertools import batched
from django.db import transaction
from pydantic import ValidationError
from pydanticscim.group import GroupMember
from pydanticscim.responses import PatchOp
from pydanticscim.responses import PatchOp, PatchOperation
from authentik.core.models import Group
from authentik.lib.sync.mapper import PropertyMappingManager
@ -20,7 +19,7 @@ from authentik.providers.scim.clients.base import SCIMClient
from authentik.providers.scim.clients.exceptions import (
SCIMRequestException,
)
from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOperation, PatchRequest
from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchRequest
from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema
from authentik.providers.scim.models import (
SCIMMapping,
@ -105,47 +104,13 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
provider=self.provider, group=group, scim_id=scim_id
)
users = list(group.users.order_by("id").values_list("id", flat=True))
self._patch_add_users(connection, users)
self._patch_add_users(group, users)
return connection
def update(self, group: Group, connection: SCIMProviderGroup):
"""Update existing group"""
scim_group = self.to_schema(group, connection)
scim_group.id = connection.scim_id
try:
if self._config.patch.supported:
return self._update_patch(group, scim_group, connection)
return self._update_put(group, scim_group, connection)
except NotFoundSyncException:
# Resource missing is handled by self.write, which will re-create the group
raise
def _update_patch(
self, group: Group, scim_group: SCIMGroupSchema, connection: SCIMProviderGroup
):
"""Update a group via PATCH request"""
# Patch group's attributes instead of replacing it and re-adding users if we can
self._request(
"PATCH",
f"/Groups/{connection.scim_id}",
json=PatchRequest(
Operations=[
PatchOperation(
op=PatchOp.replace,
path=None,
value=scim_group.model_dump(mode="json", exclude_unset=True),
)
]
).model_dump(
mode="json",
exclude_unset=True,
exclude_none=True,
),
)
return self.patch_compare_users(group)
def _update_put(self, group: Group, scim_group: SCIMGroupSchema, connection: SCIMProviderGroup):
"""Update a group via PUT request"""
try:
self._request(
"PUT",
@ -155,25 +120,33 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
exclude_unset=True,
),
)
return self.patch_compare_users(group)
users = list(group.users.order_by("id").values_list("id", flat=True))
return self._patch_add_users(group, users)
except NotFoundSyncException:
# Resource missing is handled by self.write, which will re-create the group
raise
except (SCIMRequestException, ObjectExistsSyncException):
# Some providers don't support PUT on groups, so this is mainly a fix for the initial
# sync, send patch add requests for all the users the group currently has
return self._update_patch(group, scim_group, connection)
users = list(group.users.order_by("id").values_list("id", flat=True))
self._patch_add_users(group, users)
# Also update the group name
return self._patch(
scim_group.id,
PatchOperation(
op=PatchOp.replace,
path="displayName",
value=scim_group.displayName,
),
)
def update_group(self, group: Group, action: Direction, users_set: set[int]):
"""Update a group, either using PUT to replace it or PATCH if supported"""
scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first()
if not scim_group:
self.logger.warning(
"could not sync group membership, group does not exist", group=group
)
return
if self._config.patch.supported:
if action == Direction.add:
return self._patch_add_users(scim_group, users_set)
return self._patch_add_users(group, users_set)
if action == Direction.remove:
return self._patch_remove_users(scim_group, users_set)
return self._patch_remove_users(group, users_set)
try:
return self.write(group)
except SCIMRequestException as exc:
@ -181,24 +154,19 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
# Assume that provider does not support PUT and also doesn't support
# ServiceProviderConfig, so try PATCH as a fallback
if action == Direction.add:
return self._patch_add_users(scim_group, users_set)
return self._patch_add_users(group, users_set)
if action == Direction.remove:
return self._patch_remove_users(scim_group, users_set)
return self._patch_remove_users(group, users_set)
raise exc
def _patch_chunked(
def _patch(
self,
group_id: str,
*ops: PatchOperation,
):
"""Helper function that chunks patch requests based on the maxOperations attribute.
This is not strictly according to spec, but there's nothing in the schema that allows
us to know what the maximum patch operations per request should be."""
chunk_size = self._config.bulk.maxOperations
if chunk_size < 1:
chunk_size = len(ops)
if len(ops) < 1:
return
for chunk in batched(ops, chunk_size):
req = PatchRequest(Operations=list(chunk))
self._request(
@ -209,70 +177,16 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
),
)
@transaction.atomic
def patch_compare_users(self, group: Group):
"""Compare users with a SCIM group and add/remove any differences"""
# Get scim group first
def _patch_add_users(self, group: Group, users_set: set[int]):
"""Add users in users_set to group"""
if len(users_set) < 1:
return
scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first()
if not scim_group:
self.logger.warning(
"could not sync group membership, group does not exist", group=group
)
return
# Get a list of all users in the authentik group
raw_users_should = list(group.users.order_by("id").values_list("id", flat=True))
# Lookup the SCIM IDs of the users
users_should: list[str] = list(
SCIMProviderUser.objects.filter(
user__pk__in=raw_users_should, provider=self.provider
).values_list("scim_id", flat=True)
)
if len(raw_users_should) != len(users_should):
self.logger.warning(
"User count mismatch, not all users in the group are synced to SCIM yet.",
group=group,
)
# Get current group status
current_group = SCIMGroupSchema.model_validate(
self._request("GET", f"/Groups/{scim_group.scim_id}")
)
users_to_add = []
users_to_remove = []
# Check users currently in group and if they shouldn't be in the group and remove them
for user in current_group.members or []:
if user.value not in users_should:
users_to_remove.append(user.value)
# Check users that should be in the group and add them
for user in users_should:
if len([x for x in current_group.members if x.value == user]) < 1:
users_to_add.append(user)
# Only send request if we need to make changes
if len(users_to_add) < 1 and len(users_to_remove) < 1:
return
return self._patch_chunked(
scim_group.scim_id,
*[
PatchOperation(
op=PatchOp.add,
path="members",
value=[{"value": x}],
)
for x in users_to_add
],
*[
PatchOperation(
op=PatchOp.remove,
path="members",
value=[{"value": x}],
)
for x in users_to_remove
],
)
def _patch_add_users(self, scim_group: SCIMProviderGroup, users_set: set[int]):
"""Add users in users_set to group"""
if len(users_set) < 1:
return
user_ids = list(
SCIMProviderUser.objects.filter(
user__pk__in=users_set, provider=self.provider
@ -280,7 +194,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
)
if len(user_ids) < 1:
return
self._patch_chunked(
self._patch(
scim_group.scim_id,
*[
PatchOperation(
@ -292,10 +206,16 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
],
)
def _patch_remove_users(self, scim_group: SCIMProviderGroup, users_set: set[int]):
def _patch_remove_users(self, group: Group, users_set: set[int]):
"""Remove users in users_set from group"""
if len(users_set) < 1:
return
scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first()
if not scim_group:
self.logger.warning(
"could not sync group membership, group does not exist", group=group
)
return
user_ids = list(
SCIMProviderUser.objects.filter(
user__pk__in=users_set, provider=self.provider
@ -303,7 +223,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
)
if len(user_ids) < 1:
return
self._patch_chunked(
self._patch(
scim_group.scim_id,
*[
PatchOperation(
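
The _patch_chunked helper in the diff above batches PATCH operations by the provider's maxOperations. A standalone sketch of that chunking with made-up operation dicts (requires Python 3.12 for itertools.batched, which this file already imports):

from itertools import batched

def chunk_operations(ops: list[dict], max_operations: int) -> list[list[dict]]:
    if not ops:
        return []
    chunk_size = max_operations if max_operations >= 1 else len(ops)
    return [list(chunk) for chunk in batched(ops, chunk_size)]

ops = [{"op": "add", "path": "members", "value": [{"value": str(i)}]} for i in range(5)]
print([len(chunk) for chunk in chunk_operations(ops, 2)])  # [2, 2, 1]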

View File

@ -2,7 +2,6 @@
from pydantic import Field
from pydanticscim.group import Group as BaseGroup
from pydanticscim.responses import PatchOperation as BasePatchOperation
from pydanticscim.responses import PatchRequest as BasePatchRequest
from pydanticscim.responses import SCIMError as BaseSCIMError
from pydanticscim.service_provider import Bulk as BaseBulk
@ -69,12 +68,6 @@ class PatchRequest(BasePatchRequest):
schemas: tuple[str] = ("urn:ietf:params:scim:api:messages:2.0:PatchOp",)
class PatchOperation(BasePatchOperation):
"""PatchOperation with optional path"""
path: str | None
class SCIMError(BaseSCIMError):
"""SCIM error with optional status code"""

View File

@ -252,118 +252,3 @@ class SCIMMembershipTests(TestCase):
],
},
)
def test_member_add_save(self):
"""Test member add + save"""
config = ServiceProviderConfiguration.default()
config.patch.supported = True
user_scim_id = generate_id()
group_scim_id = generate_id()
uid = generate_id()
group = Group.objects.create(
name=uid,
)
user = User.objects.create(username=generate_id())
# Test initial sync of group creation
with Mocker() as mocker:
mocker.get(
"https://localhost/ServiceProviderConfig",
json=config.model_dump(),
)
mocker.post(
"https://localhost/Users",
json={
"id": user_scim_id,
},
)
mocker.post(
"https://localhost/Groups",
json={
"id": group_scim_id,
},
)
self.configure()
sync_tasks.trigger_single_task(self.provider, scim_sync).get()
self.assertEqual(mocker.call_count, 6)
self.assertEqual(mocker.request_history[0].method, "GET")
self.assertEqual(mocker.request_history[1].method, "GET")
self.assertEqual(mocker.request_history[2].method, "GET")
self.assertEqual(mocker.request_history[3].method, "POST")
self.assertEqual(mocker.request_history[4].method, "GET")
self.assertEqual(mocker.request_history[5].method, "POST")
self.assertJSONEqual(
mocker.request_history[3].body,
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"emails": [],
"active": True,
"externalId": user.uid,
"name": {"familyName": " ", "formatted": " ", "givenName": ""},
"displayName": "",
"userName": user.username,
},
)
self.assertJSONEqual(
mocker.request_history[5].body,
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"externalId": str(group.pk),
"displayName": group.name,
},
)
with Mocker() as mocker:
mocker.get(
"https://localhost/ServiceProviderConfig",
json=config.model_dump(),
)
mocker.get(
f"https://localhost/Groups/{group_scim_id}",
json={},
)
mocker.patch(
f"https://localhost/Groups/{group_scim_id}",
json={},
)
group.users.add(user)
group.save()
self.assertEqual(mocker.call_count, 5)
self.assertEqual(mocker.request_history[0].method, "GET")
self.assertEqual(mocker.request_history[1].method, "PATCH")
self.assertEqual(mocker.request_history[2].method, "GET")
self.assertEqual(mocker.request_history[3].method, "PATCH")
self.assertEqual(mocker.request_history[4].method, "GET")
self.assertJSONEqual(
mocker.request_history[1].body,
{
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
"Operations": [
{
"op": "add",
"path": "members",
"value": [{"value": user_scim_id}],
}
],
},
)
self.assertJSONEqual(
mocker.request_history[3].body,
{
"Operations": [
{
"op": "replace",
"value": {
"id": group_scim_id,
"displayName": group.name,
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"externalId": str(group.pk),
},
}
]
},
)

View File

@ -87,11 +87,7 @@ def task_error_hook(task_id: str, exception: Exception, traceback, *args, **kwar
def _get_startup_tasks_default_tenant() -> list[Callable]:
"""Get all tasks to be run on startup for the default tenant"""
from authentik.outposts.tasks import outpost_connection_discovery
return [
outpost_connection_discovery,
]
return []
def _get_startup_tasks_all_tenants() -> list[Callable]:

View File

@ -2,7 +2,6 @@
from collections.abc import Callable
from hashlib import sha512
from ipaddress import ip_address
from time import perf_counter, time
from typing import Any
@ -175,7 +174,6 @@ class ClientIPMiddleware:
def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]):
self.get_response = get_response
self.logger = get_logger().bind()
def _get_client_ip_from_meta(self, meta: dict[str, Any]) -> str:
"""Attempt to get the client's IP by checking common HTTP Headers.
@ -187,16 +185,11 @@ class ClientIPMiddleware:
"HTTP_X_FORWARDED_FOR",
"REMOTE_ADDR",
)
try:
for _header in headers:
if _header in meta:
ips: list[str] = meta.get(_header).split(",")
# Ensure the IP parses as a valid IP
return str(ip_address(ips[0].strip()))
return self.default_ip
except ValueError as exc:
self.logger.debug("Invalid remote IP", exc=exc)
return self.default_ip
for _header in headers:
if _header in meta:
ips: list[str] = meta.get(_header).split(",")
return ips[0].strip()
return self.default_ip
# FIXME: this should probably not be in `root` but rather in a middleware in `outposts`
# but for now it's fine
@ -233,11 +226,7 @@ class ClientIPMiddleware:
Scope.get_isolation_scope().set_user(user)
# Set the outpost service account on the request
setattr(request, self.request_attr_outpost_user, user)
try:
return str(ip_address(delegated_ip))
except ValueError as exc:
self.logger.debug("Invalid remote IP from Outpost", exc=exc)
return None
return delegated_ip
def _get_client_ip(self, request: HttpRequest | None) -> str:
"""Attempt to get the client's IP by checking common HTTP Headers.

View File

@ -1,7 +1,6 @@
"""authentik storage backends"""
import os
from urllib.parse import parse_qsl, urlsplit
from django.conf import settings
from django.core.exceptions import SuspiciousOperation
@ -111,34 +110,3 @@ class S3Storage(BaseS3Storage):
if self.querystring_auth:
return url
return self._strip_signing_parameters(url)
def _strip_signing_parameters(self, url):
# Boto3 does not currently support generating URLs that are unsigned. Instead
# we take the signed URLs and strip any querystring params related to signing
# and expiration.
# Note that this may end up with URLs that are still invalid, especially if
# params are passed in that only work with signed URLs, e.g. response header
# params.
# The code attempts to strip all query parameters that match names of known
# parameters from v2 and v4 signatures, regardless of the actual signature
# version used.
split_url = urlsplit(url)
qs = parse_qsl(split_url.query, keep_blank_values=True)
blacklist = {
"x-amz-algorithm",
"x-amz-credential",
"x-amz-date",
"x-amz-expires",
"x-amz-signedheaders",
"x-amz-signature",
"x-amz-security-token",
"awsaccesskeyid",
"expires",
"signature",
}
filtered_qs = ((key, val) for key, val in qs if key.lower() not in blacklist)
# Note: Parameters that did not have a value in the original query string will
# have an '=' sign appended to it, e.g ?foo&bar becomes ?foo=&bar=
joined_qs = ("=".join(keyval) for keyval in filtered_qs)
split_url = split_url._replace(query="&".join(joined_qs))
return split_url.geturl()
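
The storage hunk above concerns a helper that strips AWS signing parameters from presigned URLs. A runnable sketch of the same idea using only the standard library:

from urllib.parse import parse_qsl, urlencode, urlsplit

SIGNING_PARAMS = {
    "x-amz-algorithm", "x-amz-credential", "x-amz-date", "x-amz-expires",
    "x-amz-signedheaders", "x-amz-signature", "x-amz-security-token",
    "awsaccesskeyid", "expires", "signature",
}

def strip_signing_parameters(url: str) -> str:
    split_url = urlsplit(url)
    kept = [
        (key, value)
        for key, value in parse_qsl(split_url.query, keep_blank_values=True)
        if key.lower() not in SIGNING_PARAMS
    ]
    return split_url._replace(query=urlencode(kept)).geturl()

print(strip_signing_parameters(
    "https://bucket.s3.amazonaws.com/key?X-Amz-Signature=abc&response-content-type=text%2Fplain"
))
# https://bucket.s3.amazonaws.com/key?response-content-type=text%2Fplain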

View File

@ -3,7 +3,6 @@
from typing import Any
from django.core.cache import cache
from django.utils.translation import gettext_lazy as _
from drf_spectacular.utils import extend_schema, inline_serializer
from guardian.shortcuts import get_objects_for_user
from rest_framework.decorators import action
@ -40,8 +39,9 @@ class LDAPSourceSerializer(SourceSerializer):
"""Get cached source connectivity"""
return cache.get(CACHE_KEY_STATUS + source.slug, None)
def validate_sync_users_password(self, sync_users_password: bool) -> bool:
def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
"""Check that only a single source has password_sync on"""
sync_users_password = attrs.get("sync_users_password", True)
if sync_users_password:
sources = LDAPSource.objects.filter(sync_users_password=True)
if self.instance:
@ -49,31 +49,11 @@ class LDAPSourceSerializer(SourceSerializer):
if sources.exists():
raise ValidationError(
{
"sync_users_password": _(
"sync_users_password": (
"Only a single LDAP Source with password synchronization is allowed"
)
}
)
return sync_users_password
def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
"""Validate property mappings with sync_ flags"""
types = ["user", "group"]
for type in types:
toggle_value = attrs.get(f"sync_{type}s", False)
mappings_field = f"{type}_property_mappings"
mappings_value = attrs.get(mappings_field, [])
if toggle_value and len(mappings_value) == 0:
raise ValidationError(
{
mappings_field: _(
(
"When 'Sync {type}s' is enabled, '{type}s property "
"mappings' cannot be empty."
).format(type=type)
)
}
)
return super().validate(attrs)
class Meta:
@ -186,12 +166,11 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet):
for sync_class in SYNC_CLASSES:
class_name = sync_class.name()
all_objects.setdefault(class_name, [])
for page in sync_class(source).get_objects(size_limit=10):
for obj in page:
obj: dict
obj.pop("raw_attributes", None)
obj.pop("raw_dn", None)
all_objects[class_name].append(obj)
for obj in sync_class(source).get_objects(size_limit=10):
obj: dict
obj.pop("raw_attributes", None)
obj.pop("raw_dn", None)
all_objects[class_name].append(obj)
return Response(data=all_objects)
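
The serializer hunks earlier in this file shuffle logic between DRF's two validation hooks: a field-level validate_<name>() only sees its own value, while object-level validate() sees all attrs and can cross-check a toggle against a related list. A minimal sketch of that pattern; the class and field names are illustrative, not authentik's actual serializer:

import django
from django.conf import settings

if not settings.configured:
    settings.configure()  # minimal standalone config so this sketch can run on its own
    django.setup()

from rest_framework import serializers

class ExampleSourceSerializer(serializers.Serializer):
    sync_users = serializers.BooleanField(default=False)
    user_property_mappings = serializers.ListField(child=serializers.IntegerField(), default=list)

    def validate(self, attrs):
        # Cross-field check: a sync toggle is only valid if mappings are configured
        if attrs.get("sync_users") and not attrs.get("user_property_mappings"):
            raise serializers.ValidationError(
                {"user_property_mappings": "Cannot be empty when 'Sync users' is enabled."}
            )
        return attrs

serializer = ExampleSourceSerializer(data={"sync_users": True, "user_property_mappings": []})
print(serializer.is_valid())  # False
print(serializer.errors)      # {'user_property_mappings': [...]}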

View File

@ -26,16 +26,17 @@ def sync_ldap_source_on_save(sender, instance: LDAPSource, **_):
"""Ensure that source is synced on save (if enabled)"""
if not instance.enabled:
return
ldap_connectivity_check.delay(instance.pk)
# Don't sync sources when they don't have any property mappings. This will only happen if:
# - the user forgets to set them or
# - the source is newly created, this is the first save event
# and the mappings are created with an m2m event
if instance.sync_users and not instance.user_property_mappings.exists():
return
if instance.sync_groups and not instance.group_property_mappings.exists():
if (
not instance.user_property_mappings.exists()
or not instance.group_property_mappings.exists()
):
return
ldap_sync_single.delay(instance.pk)
ldap_connectivity_check.delay(instance.pk)
@receiver(password_validate)

View File

@ -38,11 +38,7 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
search_base=self.base_dn_groups,
search_filter=self._source.group_object_filter,
search_scope=SUBTREE,
attributes=[
ALL_ATTRIBUTES,
ALL_OPERATIONAL_ATTRIBUTES,
self._source.object_uniqueness_field,
],
attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES],
**kwargs,
)
@ -57,9 +53,9 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
continue
attributes = group.get("attributes", {})
group_dn = flatten(flatten(group.get("entryDN", group.get("dn"))))
if not attributes.get(self._source.object_uniqueness_field):
if self._source.object_uniqueness_field not in attributes:
self.message(
f"Uniqueness field not found/not set in attributes: '{group_dn}'",
f"Cannot find uniqueness field in attributes: '{group_dn}'",
attributes=attributes.keys(),
dn=group_dn,
)
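
The two variants of the uniqueness-field check above are not equivalent (the user synchronizer below has the same change): `field not in attributes` only catches a missing key, while `not attributes.get(field)` also treats a present-but-empty value as missing. A short illustration with a hypothetical attribute dict:

attributes = {"objectSid": ""}    # uniqueness field present, but empty
field = "objectSid"
print(field not in attributes)    # False - the key exists, so the membership check lets it through
print(not attributes.get(field))  # True  - the falsiness check also catches the empty value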

View File

@ -40,11 +40,7 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer):
search_base=self.base_dn_users,
search_filter=self._source.user_object_filter,
search_scope=SUBTREE,
attributes=[
ALL_ATTRIBUTES,
ALL_OPERATIONAL_ATTRIBUTES,
self._source.object_uniqueness_field,
],
attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES],
**kwargs,
)
@ -59,9 +55,9 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer):
continue
attributes = user.get("attributes", {})
user_dn = flatten(user.get("entryDN", user.get("dn")))
if not attributes.get(self._source.object_uniqueness_field):
if self._source.object_uniqueness_field not in attributes:
self.message(
f"Uniqueness field not found/not set in attributes: '{user_dn}'",
f"Cannot find uniqueness field in attributes: '{user_dn}'",
attributes=attributes.keys(),
dn=user_dn,
)

View File

@ -78,9 +78,7 @@ class MicrosoftActiveDirectory(BaseLDAPSynchronizer):
# /useraccountcontrol-manipulate-account-properties
uac_bit = attributes.get("userAccountControl", 512)
uac = UserAccountControl(uac_bit)
is_active = (
UserAccountControl.ACCOUNTDISABLE not in uac and UserAccountControl.LOCKOUT not in uac
)
is_active = UserAccountControl.ACCOUNTDISABLE not in uac
if is_active != user.is_active:
user.is_active = is_active
user.save()
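
A standalone sketch of the userAccountControl bit check above, using an IntFlag with Microsoft's documented bit values (only the three flags needed here are shown; authentik's own UserAccountControl enum is not part of this compare):

from enum import IntFlag

class UserAccountControl(IntFlag):
    ACCOUNTDISABLE = 0x0002   # account is disabled
    LOCKOUT = 0x0010          # account is currently locked out
    NORMAL_ACCOUNT = 0x0200   # default account type (the 512 fallback above)

uac = UserAccountControl(514)  # 512 (normal account) | 2 (disabled)
is_active = (
    UserAccountControl.ACCOUNTDISABLE not in uac and UserAccountControl.LOCKOUT not in uac
)
print(is_active)  # False - the ACCOUNTDISABLE bit is set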

View File

@ -50,35 +50,3 @@ class LDAPAPITests(APITestCase):
}
)
self.assertFalse(serializer.is_valid())
def test_sync_users_mapping_empty(self):
"""Check that when sync_users is enabled, property mappings must be set"""
serializer = LDAPSourceSerializer(
data={
"name": "foo",
"slug": " foo",
"server_uri": "ldaps://1.2.3.4",
"bind_cn": "",
"bind_password": LDAP_PASSWORD,
"base_dn": "dc=foo",
"sync_users": True,
"user_property_mappings": [],
}
)
self.assertFalse(serializer.is_valid())
def test_sync_groups_mapping_empty(self):
"""Check that when sync_groups is enabled, property mappings must be set"""
serializer = LDAPSourceSerializer(
data={
"name": "foo",
"slug": " foo",
"server_uri": "ldaps://1.2.3.4",
"bind_cn": "",
"bind_password": LDAP_PASSWORD,
"base_dn": "dc=foo",
"sync_groups": True,
"group_property_mappings": [],
}
)
self.assertFalse(serializer.is_valid())

View File

@ -30,9 +30,7 @@ class TestMetadataProcessor(TestCase):
xml = MetadataProcessor(self.source, request).build_entity_descriptor()
metadata = lxml_from_string(xml)
schema = etree.XMLSchema(
etree.parse("schemas/saml-schema-metadata-2.0.xsd", parser=etree.XMLParser()) # nosec
)
schema = etree.XMLSchema(etree.parse("schemas/saml-schema-metadata-2.0.xsd")) # nosec
self.assertTrue(schema.validate(metadata))
def test_metadata_consistent(self):

View File

@ -32,7 +32,7 @@ class AuthenticatorDuoChallengeResponse(ChallengeResponse):
component = CharField(default="ak-stage-authenticator-duo")
class AuthenticatorDuoStageView(ChallengeStageView):
class AuthenticatorDuoStageView(ChallengeStageView[AuthenticatorDuoStage]):
"""Duo stage"""
response_class = AuthenticatorDuoChallengeResponse
@ -40,9 +40,8 @@ class AuthenticatorDuoStageView(ChallengeStageView):
def duo_enroll(self):
"""Enroll User with Duo API and save results"""
user = self.get_pending_user()
stage: AuthenticatorDuoStage = self.executor.current_stage
try:
enroll = stage.auth_client().enroll(user.username)
enroll = self.current_stage.auth_client().enroll(user.username)
except RuntimeError as exc:
Event.new(
EventAction.CONFIGURATION_ERROR,
@ -54,7 +53,6 @@ class AuthenticatorDuoStageView(ChallengeStageView):
return enroll
def get_challenge(self, *args, **kwargs) -> Challenge:
stage: AuthenticatorDuoStage = self.executor.current_stage
if SESSION_KEY_DUO_ENROLL not in self.request.session:
self.duo_enroll()
enroll = self.request.session[SESSION_KEY_DUO_ENROLL]
@ -62,15 +60,14 @@ class AuthenticatorDuoStageView(ChallengeStageView):
data={
"activation_barcode": enroll["activation_barcode"],
"activation_code": enroll["activation_code"],
"stage_uuid": str(stage.stage_uuid),
"stage_uuid": str(self.current_stage.stage_uuid),
}
)
def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:
# Duo Challenge has already been validated
stage: AuthenticatorDuoStage = self.executor.current_stage
enroll = self.request.session.get(SESSION_KEY_DUO_ENROLL)
enroll_status = stage.auth_client().enroll_status(
enroll_status = self.current_stage.auth_client().enroll_status(
enroll["user_id"], enroll["activation_code"]
)
if enroll_status != "success":
@ -82,7 +79,7 @@ class AuthenticatorDuoStageView(ChallengeStageView):
name="Duo Authenticator",
user=self.get_pending_user(),
duo_user_id=enroll["user_id"],
stage=stage,
stage=self.current_stage,
last_t=now(),
)
else:
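
This file, like the stage files that follow, swaps `self.executor.current_stage` for a typed `self.current_stage` on a now-generic `ChallengeStageView[...]`. The base-class change itself is not shown in this compare; the sketch below is one plausible, simplified shape, included for orientation only, and the optional stage override in the constructor is an assumption inferred from how child views are instantiated later in the diff:

from typing import Generic, TypeVar

class Stage:  # minimal stand-in for authentik's Stage model
    pass

class FlowExecutor:  # minimal stand-in holding the stage currently being executed
    def __init__(self, current_stage: Stage):
        self.current_stage = current_stage

TStage = TypeVar("TStage", bound=Stage)

class StageView(Generic[TStage]):
    def __init__(self, executor: FlowExecutor, current_stage: TStage | None = None):
        self.executor = executor
        self._current_stage = current_stage

    @property
    def current_stage(self) -> TStage:
        # Default to the executor's stage, but allow a caller to nest another stage
        # explicitly (e.g. a captcha stage inside the identification stage).
        return self._current_stage or self.executor.current_stage

With a typed `current_stage`, the repeated `stage: AuthenticatorDuoStage = self.executor.current_stage` annotations seen on one side of these hunks become unnecessary.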

View File

@ -57,21 +57,20 @@ class AuthenticatorSMSChallengeResponse(ChallengeResponse):
return super().validate(attrs)
class AuthenticatorSMSStageView(ChallengeStageView):
class AuthenticatorSMSStageView(ChallengeStageView[AuthenticatorSMSStage]):
"""OTP sms Setup stage"""
response_class = AuthenticatorSMSChallengeResponse
def validate_and_send(self, phone_number: str):
"""Validate phone number and send message"""
stage: AuthenticatorSMSStage = self.executor.current_stage
hashed_number = hash_phone_number(phone_number)
query = Q(phone_number=hashed_number) | Q(phone_number=phone_number)
if SMSDevice.objects.filter(query, stage=stage.pk).exists():
if SMSDevice.objects.filter(query, stage=self.current_stage.pk).exists():
raise ValidationError(_("Invalid phone number"))
# No code yet, but we have a phone number, so send a verification message
device: SMSDevice = self.request.session[SESSION_KEY_SMS_DEVICE]
stage.send(device.token, device)
self.current_stage.send(device.token, device)
def _has_phone_number(self) -> str | None:
context = self.executor.plan.context
@ -101,10 +100,10 @@ class AuthenticatorSMSStageView(ChallengeStageView):
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
user = self.get_pending_user()
stage: AuthenticatorSMSStage = self.executor.current_stage
if SESSION_KEY_SMS_DEVICE not in self.request.session:
device = SMSDevice(user=user, confirmed=False, stage=stage, name="SMS Device")
device = SMSDevice(
user=user, confirmed=False, stage=self.current_stage, name="SMS Device"
)
device.generate_token(commit=False)
self.request.session[SESSION_KEY_SMS_DEVICE] = device
if phone_number := self._has_phone_number():
@ -130,8 +129,7 @@ class AuthenticatorSMSStageView(ChallengeStageView):
device: SMSDevice = self.request.session[SESSION_KEY_SMS_DEVICE]
if not device.confirmed:
return self.challenge_invalid(response)
stage: AuthenticatorSMSStage = self.executor.current_stage
if stage.verify_only:
if self.current_stage.verify_only:
self.logger.debug("Hashing number on device")
device.set_hashed_number()
device.save()

View File

@ -29,7 +29,7 @@ class AuthenticatorStaticChallengeResponse(ChallengeResponse):
component = CharField(default="ak-stage-authenticator-static")
class AuthenticatorStaticStageView(ChallengeStageView):
class AuthenticatorStaticStageView(ChallengeStageView[AuthenticatorStaticStage]):
"""Static OTP Setup stage"""
response_class = AuthenticatorStaticChallengeResponse
@ -48,14 +48,14 @@ class AuthenticatorStaticStageView(ChallengeStageView):
self.logger.debug("No pending user, continuing")
return self.executor.stage_ok()
stage: AuthenticatorStaticStage = self.executor.current_stage
if SESSION_STATIC_DEVICE not in self.request.session:
device = StaticDevice(user=user, confirmed=False, name="Static Token")
tokens = []
for _ in range(0, stage.token_count):
for _ in range(0, self.current_stage.token_count):
tokens.append(
StaticToken(device=device, token=generate_id(length=stage.token_length))
StaticToken(
device=device, token=generate_id(length=self.current_stage.token_length)
)
)
self.request.session[SESSION_STATIC_DEVICE] = device
self.request.session[SESSION_STATIC_TOKENS] = tokens
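
A rough stand-in for the token generation above: `generate_id` is authentik's helper, approximated here with the standard library's secrets module; the counts and alphabet are illustrative defaults, not the stage's real ones:

import secrets
import string

ALPHABET = string.ascii_letters + string.digits  # assumption: generate_id draws from a similar set

def generate_static_tokens(token_count: int = 6, token_length: int = 12) -> list[str]:
    # One one-time code per requested token, each token_length characters long
    return [
        "".join(secrets.choice(ALPHABET) for _ in range(token_length))
        for _ in range(token_count)
    ]

print(generate_static_tokens(3, 8))  # e.g. ['q3ZbT0xK', 'Vw9sLm2a', 'd8RkPzY1']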

View File

@ -45,7 +45,7 @@ class AuthenticatorTOTPChallengeResponse(ChallengeResponse):
return code
class AuthenticatorTOTPStageView(ChallengeStageView):
class AuthenticatorTOTPStageView(ChallengeStageView[AuthenticatorTOTPStage]):
"""OTP totp Setup stage"""
response_class = AuthenticatorTOTPChallengeResponse
@ -71,11 +71,12 @@ class AuthenticatorTOTPStageView(ChallengeStageView):
self.logger.debug("No pending user, continuing")
return self.executor.stage_ok()
stage: AuthenticatorTOTPStage = self.executor.current_stage
if SESSION_TOTP_DEVICE not in self.request.session:
device = TOTPDevice(
user=user, confirmed=False, digits=stage.digits, name="TOTP Authenticator"
user=user,
confirmed=False,
digits=self.current_stage.digits,
name="TOTP Authenticator",
)
self.request.session[SESSION_TOTP_DEVICE] = device

View File

@ -151,7 +151,7 @@ class AuthenticatorValidationChallengeResponse(ChallengeResponse):
return attrs
class AuthenticatorValidateStageView(ChallengeStageView):
class AuthenticatorValidateStageView(ChallengeStageView[AuthenticatorValidateStage]):
"""Authenticator Validation"""
response_class = AuthenticatorValidationChallengeResponse
@ -177,16 +177,14 @@ class AuthenticatorValidateStageView(ChallengeStageView):
# since their challenges are device-independent
seen_classes = []
stage: AuthenticatorValidateStage = self.executor.current_stage
threshold = timedelta_from_string(stage.last_auth_threshold)
threshold = timedelta_from_string(self.current_stage.last_auth_threshold)
allowed_devices = []
has_webauthn_filters_set = stage.webauthn_allowed_device_types.exists()
has_webauthn_filters_set = self.current_stage.webauthn_allowed_device_types.exists()
for device in user_devices:
device_class = device.__class__.__name__.lower().replace("device", "")
if device_class not in stage.device_classes:
if device_class not in self.current_stage.device_classes:
self.logger.debug("device class not allowed", device_class=device_class)
continue
if isinstance(device, SMSDevice) and device.is_hashed:
@ -199,7 +197,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
and device.device_type
and has_webauthn_filters_set
):
if not stage.webauthn_allowed_device_types.filter(
if not self.current_stage.webauthn_allowed_device_types.filter(
pk=device.device_type.pk
).exists():
self.logger.debug(
@ -216,7 +214,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
data={
"device_class": device_class,
"device_uid": device.pk,
"challenge": get_challenge_for_device(self.request, stage, device),
"challenge": get_challenge_for_device(self.request, self.current_stage, device),
}
)
challenge.is_valid()
@ -235,7 +233,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
"device_uid": -1,
"challenge": get_webauthn_challenge_without_user(
self.request,
self.executor.current_stage,
self.current_stage,
),
}
)
@ -246,7 +244,6 @@ class AuthenticatorValidateStageView(ChallengeStageView):
"""Check if a user is set, and check if the user has any devices
if not, we can skip this entire stage"""
user = self.get_pending_user()
stage: AuthenticatorValidateStage = self.executor.current_stage
if user and not user.is_anonymous:
try:
challenges = self.get_device_challenges()
@ -257,7 +254,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
self.logger.debug("Refusing passwordless flow in non-authentication flow")
return self.executor.stage_ok()
# Passwordless auth, with just webauthn
if DeviceClasses.WEBAUTHN in stage.device_classes:
if DeviceClasses.WEBAUTHN in self.current_stage.device_classes:
self.logger.debug("Flow without user, getting generic webauthn challenge")
challenges = self.get_webauthn_challenge_without_user()
else:
@ -267,13 +264,13 @@ class AuthenticatorValidateStageView(ChallengeStageView):
# No allowed devices
if len(challenges) < 1:
if stage.not_configured_action == NotConfiguredAction.SKIP:
if self.current_stage.not_configured_action == NotConfiguredAction.SKIP:
self.logger.debug("Authenticator not configured, skipping stage")
return self.executor.stage_ok()
if stage.not_configured_action == NotConfiguredAction.DENY:
if self.current_stage.not_configured_action == NotConfiguredAction.DENY:
self.logger.debug("Authenticator not configured, denying")
return self.executor.stage_invalid(_("No (allowed) MFA authenticator configured."))
if stage.not_configured_action == NotConfiguredAction.CONFIGURE:
if self.current_stage.not_configured_action == NotConfiguredAction.CONFIGURE:
self.logger.debug("Authenticator not configured, forcing configure")
return self.prepare_stages(user)
return super().get(request, *args, **kwargs)
@ -282,8 +279,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
"""Check how the user can configure themselves. If no stages are set, return an error.
If a single stage is set, insert that stage directly. If multiple are selected, include
them in the challenge."""
stage: AuthenticatorValidateStage = self.executor.current_stage
if not stage.configuration_stages.exists():
if not self.current_stage.configuration_stages.exists():
Event.new(
EventAction.CONFIGURATION_ERROR,
message=(
@ -293,15 +289,19 @@ class AuthenticatorValidateStageView(ChallengeStageView):
stage=self,
).from_http(self.request).set_user(user).save()
return self.executor.stage_invalid()
if stage.configuration_stages.count() == 1:
next_stage = Stage.objects.get_subclass(pk=stage.configuration_stages.first().pk)
if self.current_stage.configuration_stages.count() == 1:
next_stage = Stage.objects.get_subclass(
pk=self.current_stage.configuration_stages.first().pk
)
self.logger.debug("Single stage configured, auto-selecting", stage=next_stage)
self.executor.plan.context[PLAN_CONTEXT_SELECTED_STAGE] = next_stage
# Because that normal execution only happens on post, we directly inject it here and
# return it
self.executor.plan.insert_stage(next_stage)
return self.executor.stage_ok()
stages = Stage.objects.filter(pk__in=stage.configuration_stages.all()).select_subclasses()
stages = Stage.objects.filter(
pk__in=self.current_stage.configuration_stages.all()
).select_subclasses()
self.executor.plan.context[PLAN_CONTEXT_STAGES] = stages
return super().get(self.request, *args, **kwargs)
@ -309,7 +309,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
res = super().post(request, *args, **kwargs)
if (
PLAN_CONTEXT_SELECTED_STAGE in self.executor.plan.context
and self.executor.current_stage.not_configured_action == NotConfiguredAction.CONFIGURE
and self.current_stage.not_configured_action == NotConfiguredAction.CONFIGURE
):
self.logger.debug("Got selected stage in context, running that")
stage_pk = self.executor.plan.context.get(PLAN_CONTEXT_SELECTED_STAGE)
@ -351,7 +351,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
def cookie_jwt_key(self) -> str:
"""Signing key for MFA Cookie for this stage"""
return sha256(
f"{get_unique_identifier()}:{self.executor.current_stage.pk.hex}".encode("ascii")
f"{get_unique_identifier()}:{self.current_stage.pk.hex}".encode("ascii")
).hexdigest()
def check_mfa_cookie(self, allowed_devices: list[Device]):
@ -362,12 +362,11 @@ class AuthenticatorValidateStageView(ChallengeStageView):
correct user and with an allowed class"""
if COOKIE_NAME_MFA not in self.request.COOKIES:
return
stage: AuthenticatorValidateStage = self.executor.current_stage
threshold = timedelta_from_string(stage.last_auth_threshold)
threshold = timedelta_from_string(self.current_stage.last_auth_threshold)
latest_allowed = datetime.now() + threshold
try:
payload = decode(self.request.COOKIES[COOKIE_NAME_MFA], self.cookie_jwt_key, ["HS256"])
if payload["stage"] != stage.pk.hex:
if payload["stage"] != self.current_stage.pk.hex:
self.logger.warning("Invalid stage PK")
return
if datetime.fromtimestamp(payload["exp"]) > latest_allowed:
@ -385,15 +384,14 @@ class AuthenticatorValidateStageView(ChallengeStageView):
"""Set an MFA cookie to allow users to skip MFA validation in this context (browser)
The cookie is JWT which is signed with a hash of the secret key and the UID of the stage"""
stage: AuthenticatorValidateStage = self.executor.current_stage
delta = timedelta_from_string(stage.last_auth_threshold)
delta = timedelta_from_string(self.current_stage.last_auth_threshold)
if delta.total_seconds() < 1:
self.logger.info("Not setting MFA cookie since threshold is not set.")
return self.executor.stage_ok()
expiry = datetime.now() + delta
cookie_payload = {
"device": device.pk,
"stage": stage.pk.hex,
"stage": self.current_stage.pk.hex,
"exp": expiry.timestamp(),
}
response = self.executor.stage_ok()
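
A standalone sketch of the MFA-cookie scheme handled by `cookie_jwt_key`, `check_mfa_cookie` and `set_mfa_cookie` above: the cookie is an HS256 JWT signed with a key derived from an installation identifier plus the stage UUID, carrying the device, the stage and an expiry. PyJWT is used as implied by the `decode(...)` call in the diff; the identifier values below are made up:

from datetime import datetime, timedelta
from hashlib import sha256

import jwt  # PyJWT

install_id = "example-install-id"                  # stand-in for get_unique_identifier()
stage_pk_hex = "0123456789abcdef0123456789abcdef"  # stand-in for the stage UUID's hex

# Signing key: sha256 over "<install id>:<stage pk>", as in cookie_jwt_key above
key = sha256(f"{install_id}:{stage_pk_hex}".encode("ascii")).hexdigest()

cookie = jwt.encode(
    {
        "device": 42,  # primary key of the device that was validated
        "stage": stage_pk_hex,
        "exp": (datetime.now() + timedelta(days=30)).timestamp(),
    },
    key,
    algorithm="HS256",
)

payload = jwt.decode(cookie, key, algorithms=["HS256"])
assert payload["stage"] == stage_pk_hex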

View File

@ -108,7 +108,7 @@ class AuthenticatorWebAuthnChallengeResponse(ChallengeResponse):
return registration
class AuthenticatorWebAuthnStageView(ChallengeStageView):
class AuthenticatorWebAuthnStageView(ChallengeStageView[AuthenticatorWebAuthnStage]):
"""WebAuthn stage"""
response_class = AuthenticatorWebAuthnChallengeResponse
@ -116,12 +116,11 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView):
def get_challenge(self, *args, **kwargs) -> Challenge:
# clear session variables prior to starting a new registration
self.request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None)
stage: AuthenticatorWebAuthnStage = self.executor.current_stage
user = self.get_pending_user()
# library accepts none so we store null in the database, but if there is a value
# set, cast it to string to ensure it's not a django class
authenticator_attachment = stage.authenticator_attachment
authenticator_attachment = self.current_stage.authenticator_attachment
if authenticator_attachment:
authenticator_attachment = AuthenticatorAttachment(str(authenticator_attachment))
@ -132,8 +131,12 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView):
user_name=user.username,
user_display_name=user.name,
authenticator_selection=AuthenticatorSelectionCriteria(
resident_key=ResidentKeyRequirement(str(stage.resident_key_requirement)),
user_verification=UserVerificationRequirement(str(stage.user_verification)),
resident_key=ResidentKeyRequirement(
str(self.current_stage.resident_key_requirement)
),
user_verification=UserVerificationRequirement(
str(self.current_stage.user_verification)
),
authenticator_attachment=authenticator_attachment,
),
attestation=AttestationConveyancePreference.DIRECT,

View File

@ -70,7 +70,7 @@ class CaptchaChallengeResponse(ChallengeResponse):
return data
class CaptchaStageView(ChallengeStageView):
class CaptchaStageView(ChallengeStageView[CaptchaChallenge]):
"""Simple captcha checker, logic is handled in django-captcha module"""
response_class = CaptchaChallengeResponse
@ -78,8 +78,8 @@ class CaptchaStageView(ChallengeStageView):
def get_challenge(self, *args, **kwargs) -> Challenge:
return CaptchaChallenge(
data={
"js_url": self.executor.current_stage.js_url,
"site_key": self.executor.current_stage.public_key,
"js_url": self.current_stage.js_url,
"site_key": self.current_stage.public_key,
}
)
@ -87,6 +87,6 @@ class CaptchaStageView(ChallengeStageView):
response = response.validated_data["token"]
self.executor.plan.context[PLAN_CONTEXT_CAPTCHA] = {
"response": response,
"stage": self.executor.current_stage,
"stage": self.current_stage,
}
return self.executor.stage_ok()

View File

@ -48,7 +48,7 @@ class ConsentChallengeResponse(ChallengeResponse):
token = CharField(required=True)
class ConsentStageView(ChallengeStageView):
class ConsentStageView(ChallengeStageView[ConsentStage]):
"""Simple consent checker."""
response_class = ConsentChallengeResponse
@ -72,14 +72,13 @@ class ConsentStageView(ChallengeStageView):
"""Check if the current request should require a prompt for non consent reasons,
i.e. this stage injected from another stage, mode is always requireed or no application
is set."""
current_stage: ConsentStage = self.executor.current_stage
# Make this StageView work when injected, in which case `current_stage` is an instance
# of the base class, and we don't save any consent, as it is assumed to be a one-time
# prompt
if not isinstance(current_stage, ConsentStage):
if not isinstance(self.current_stage, ConsentStage):
return True
# For always require, we always return the challenge
if current_stage.mode == ConsentMode.ALWAYS_REQUIRE:
if self.current_stage.mode == ConsentMode.ALWAYS_REQUIRE:
return True
# at this point we need to check consent from database
if PLAN_CONTEXT_APPLICATION not in self.executor.plan.context:
@ -125,7 +124,6 @@ class ConsentStageView(ChallengeStageView):
return self.get(self.request)
if self.should_always_prompt():
return self.executor.stage_ok()
current_stage: ConsentStage = self.executor.current_stage
application = self.executor.plan.context[PLAN_CONTEXT_APPLICATION]
permissions = self.executor.plan.context.get(
PLAN_CONTEXT_CONSENT_PERMISSIONS, []
@ -139,9 +137,9 @@ class ConsentStageView(ChallengeStageView):
)
consent: UserConsent = self.executor.plan.context[PLAN_CONTEXT_CONSENT]
consent.permissions = permissions_string
if current_stage.mode == ConsentMode.PERMANENT:
if self.current_stage.mode == ConsentMode.PERMANENT:
consent.expiring = False
if current_stage.mode == ConsentMode.EXPIRING:
consent.expires = now() + timedelta_from_string(current_stage.consent_expire_in)
if self.current_stage.mode == ConsentMode.EXPIRING:
consent.expires = now() + timedelta_from_string(self.current_stage.consent_expire_in)
consent.save()
return self.executor.stage_ok()
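
`timedelta_from_string` (also used above for the MFA threshold and further down for the remember-me offset) is authentik's helper; the sketch below assumes it parses expressions such as "weeks=4" or "days=30;minutes=1" into timedelta keyword arguments, which is how the expiring-consent branch derives its expiry:

from datetime import datetime, timedelta

def timedelta_from_string(expr: str) -> timedelta:
    """Hypothetical re-implementation: 'weeks=4;days=1' -> timedelta(weeks=4, days=1)."""
    kwargs = {}
    for part in expr.split(";"):
        if "=" not in part:
            continue
        key, _, value = part.partition("=")
        kwargs[key.strip()] = float(value)
    return timedelta(**kwargs)

expires = datetime.now() + timedelta_from_string("weeks=4")
print(expires)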

View File

@ -6,11 +6,10 @@ from authentik.flows.stage import StageView
from authentik.stages.deny.models import DenyStage
class DenyStageView(StageView):
class DenyStageView(StageView[DenyStage]):
"""Cancels the current flow"""
def dispatch(self, request: HttpRequest) -> HttpResponse:
"""Cancels the current flow"""
stage: DenyStage = self.executor.current_stage
message = self.executor.plan.context.get("deny_message", stage.deny_message)
message = self.executor.plan.context.get("deny_message", self.current_stage.deny_message)
return self.executor.stage_invalid(message)

View File

@ -30,11 +30,11 @@ class DummyStageView(ChallengeStageView):
return self.executor.stage_ok()
def get_challenge(self, *args, **kwargs) -> Challenge:
if self.executor.current_stage.throw_error:
if self.current_stage.throw_error:
raise SentryIgnoredException("Test error")
return DummyChallenge(
data={
"title": self.executor.current_stage.name,
"name": self.executor.current_stage.name,
"title": self.current_stage.name,
"name": self.current_stage.name,
}
)

View File

@ -46,7 +46,7 @@ class EmailChallengeResponse(ChallengeResponse):
raise ValidationError(detail="email-sent", code="email-sent")
class EmailStageView(ChallengeStageView):
class EmailStageView(ChallengeStageView[EmailStage]):
"""Email stage which sends Email for verification"""
response_class = EmailChallengeResponse
@ -72,11 +72,10 @@ class EmailStageView(ChallengeStageView):
def get_token(self) -> FlowToken:
"""Get token"""
pending_user = self.get_pending_user()
current_stage: EmailStage = self.executor.current_stage
valid_delta = timedelta(
minutes=current_stage.token_expiry + 1
minutes=self.current_stage.token_expiry + 1
) # + 1 because django timesince always rounds down
identifier = slugify(f"ak-email-stage-{current_stage.name}-{str(uuid4())}")
identifier = slugify(f"ak-email-stage-{self.current_stage.name}-{str(uuid4())}")
# Don't check for validity here, we only care if the token exists
tokens = FlowToken.objects.filter(identifier=identifier)
if not tokens.exists():
@ -105,15 +104,14 @@ class EmailStageView(ChallengeStageView):
email = self.executor.plan.context.get(PLAN_CONTEXT_EMAIL_OVERRIDE, None)
if not email:
email = pending_user.email
current_stage: EmailStage = self.executor.current_stage
token = self.get_token()
# Send mail to user
try:
message = TemplateEmailMessage(
subject=_(current_stage.subject),
subject=_(self.current_stage.subject),
to=[(pending_user.name, email)],
language=pending_user.locale(self.request),
template_name=current_stage.template,
template_name=self.current_stage.template,
template_context={
"url": self.get_full_url(**{QS_KEY_TOKEN: token.key}),
"user": pending_user,
@ -121,26 +119,28 @@ class EmailStageView(ChallengeStageView):
"token": token.key,
},
)
send_mails(current_stage, message)
send_mails(self.current_stage, message)
except TemplateSyntaxError as exc:
Event.new(
EventAction.CONFIGURATION_ERROR,
message=_("Exception occurred while rendering E-mail template"),
error=exception_to_string(exc),
template=current_stage.template,
template=self.current_stage.template,
).from_http(self.request)
raise StageInvalidException from exc
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
# Check if the user came back from the email link to verify
restore_token: FlowToken = self.executor.plan.context.get(PLAN_CONTEXT_IS_RESTORED, None)
restore_token: FlowToken | None = self.executor.plan.context.get(
PLAN_CONTEXT_IS_RESTORED, None
)
user = self.get_pending_user()
if restore_token:
if restore_token.user != user:
self.logger.warning("Flow token for non-matching user, denying request")
return self.executor.stage_invalid()
messages.success(request, _("Successfully verified Email."))
if self.executor.current_stage.activate_user_on_success:
if self.current_stage.activate_user_on_success:
user.is_active = True
user.save()
return self.executor.stage_ok()

View File

@ -27,6 +27,7 @@ class IdentificationStageSerializer(StageSerializer):
fields = StageSerializer.Meta.fields + [
"user_fields",
"password_stage",
"captcha_stage",
"case_insensitive_matching",
"show_matched_user",
"enrollment_flow",

View File

@ -0,0 +1,26 @@
# Generated by Django 5.0.8 on 2024-08-24 12:58
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_stages_captcha", "0003_captchastage_error_on_invalid_score_and_more"),
("authentik_stages_identification", "0014_identificationstage_pretend"),
]
operations = [
migrations.AddField(
model_name="identificationstage",
name="captcha_stage",
field=models.ForeignKey(
default=None,
help_text="When set, the captcha element is shown on the identification stage.",
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to="authentik_stages_captcha.captchastage",
),
),
]

View File

@ -8,6 +8,7 @@ from rest_framework.serializers import BaseSerializer
from authentik.core.models import Source
from authentik.flows.models import Flow, Stage
from authentik.stages.captcha.models import CaptchaStage
from authentik.stages.password.models import PasswordStage
@ -42,6 +43,15 @@ class IdentificationStage(Stage):
),
),
)
captcha_stage = models.ForeignKey(
CaptchaStage,
null=True,
default=None,
on_delete=models.SET_NULL,
help_text=_(
("When set, the captcha element is shown on the identification stage."),
),
)
case_insensitive_matching = models.BooleanField(
default=True,

View File

@ -30,9 +30,14 @@ from authentik.lib.utils.urls import reverse_with_qs
from authentik.root.middleware import ClientIPMiddleware
from authentik.sources.oauth.types.apple import AppleLoginChallenge
from authentik.sources.plex.models import PlexAuthenticationChallenge
from authentik.stages.captcha.stage import (
CaptchaChallenge,
CaptchaChallengeResponse,
CaptchaStageView,
)
from authentik.stages.identification.models import IdentificationStage
from authentik.stages.identification.signals import identification_failed
from authentik.stages.password.stage import authenticate
from authentik.stages.password.stage import PasswordChallenge, PasswordStageView, authenticate
@extend_schema_field(
@ -63,8 +68,8 @@ class IdentificationChallenge(Challenge):
"""Identification challenges with all UI elements"""
user_fields = ListField(child=CharField(), allow_empty=True, allow_null=True)
password_fields = BooleanField()
allow_show_password = BooleanField(default=False)
password_stage = PasswordChallenge(required=False)
captcha_stage = CaptchaChallenge(required=False)
application_pre = CharField(required=False)
flow_designation = ChoiceField(FlowDesignation.choices)
@ -84,6 +89,7 @@ class IdentificationChallengeResponse(ChallengeResponse):
uid_field = CharField()
password = CharField(required=False, allow_blank=True, allow_null=True)
component = CharField(default="ak-stage-identification")
captcha = CaptchaChallengeResponse(required=False)
pre_user: User | None = None
@ -128,49 +134,50 @@ class IdentificationChallengeResponse(ChallengeResponse):
return attrs
raise ValidationError("Failed to authenticate.")
self.pre_user = pre_user
if not current_stage.password_stage:
# No password stage select, don't validate the password
return attrs
password = attrs.get("password", None)
if not password:
self.stage.logger.warning("Password not set for ident+auth attempt")
try:
with start_span(
op="authentik.stages.identification.authenticate",
description="User authenticate call (combo stage)",
):
user = authenticate(
self.stage.request,
current_stage.password_stage.backends,
current_stage,
username=self.pre_user.username,
password=password,
)
if not user:
raise ValidationError("Failed to authenticate.")
self.pre_user = user
except PermissionDenied as exc:
raise ValidationError(str(exc)) from exc
if current_stage.password_stage:
password = attrs.get("password", None)
if not password:
self.stage.logger.warning("Password not set for ident+auth attempt")
try:
with start_span(
op="authentik.stages.identification.authenticate",
description="User authenticate call (combo stage)",
):
user = authenticate(
self.stage.request,
current_stage.password_stage.backends,
current_stage,
username=self.pre_user.username,
password=password,
)
if not user:
raise ValidationError("Failed to authenticate.")
self.pre_user = user
except PermissionDenied as exc:
raise ValidationError(str(exc)) from exc
print(attrs)
# if current_stage.captcha_stage:
# captcha = CaptchaStageView(self.stage.executor)
# captcha.stage = current_stage.captcha_stage
# captcha.challenge_valid(attrs.get("captcha"))
return attrs
class IdentificationStageView(ChallengeStageView):
class IdentificationStageView(ChallengeStageView[IdentificationStage]):
"""Form to identify the user"""
response_class = IdentificationChallengeResponse
def get_user(self, uid_value: str) -> User | None:
"""Find user instance. Returns None if no user was found."""
current_stage: IdentificationStage = self.executor.current_stage
query = Q()
for search_field in current_stage.user_fields:
for search_field in self.current_stage.user_fields:
model_field = {
"email": "email",
"username": "username",
"upn": "attributes__upn",
}[search_field]
if current_stage.case_insensitive_matching:
if self.current_stage.case_insensitive_matching:
model_field += "__iexact"
else:
model_field += "__exact"
@ -191,16 +198,12 @@ class IdentificationStageView(ChallengeStageView):
return _("Continue")
def get_challenge(self) -> Challenge:
current_stage: IdentificationStage = self.executor.current_stage
challenge = IdentificationChallenge(
data={
"component": "ak-stage-identification",
"primary_action": self.get_primary_action(),
"user_fields": current_stage.user_fields,
"password_fields": bool(current_stage.password_stage),
"allow_show_password": bool(current_stage.password_stage)
and current_stage.password_stage.allow_show_password,
"show_source_labels": current_stage.show_source_labels,
"user_fields": self.current_stage.user_fields,
"show_source_labels": self.current_stage.show_source_labels,
"flow_designation": self.executor.flow.designation,
}
)
@ -212,29 +215,39 @@ class IdentificationStageView(ChallengeStageView):
).name
get_qs = self.request.session.get(SESSION_KEY_GET, self.request.GET)
# Check for related enrollment and recovery flow, add URL to view
if current_stage.enrollment_flow:
if self.current_stage.enrollment_flow:
challenge.initial_data["enroll_url"] = reverse_with_qs(
"authentik_core:if-flow",
query=get_qs,
kwargs={"flow_slug": current_stage.enrollment_flow.slug},
kwargs={"flow_slug": self.current_stage.enrollment_flow.slug},
)
if current_stage.recovery_flow:
if self.current_stage.recovery_flow:
challenge.initial_data["recovery_url"] = reverse_with_qs(
"authentik_core:if-flow",
query=get_qs,
kwargs={"flow_slug": current_stage.recovery_flow.slug},
kwargs={"flow_slug": self.current_stage.recovery_flow.slug},
)
if current_stage.passwordless_flow:
if self.current_stage.passwordless_flow:
challenge.initial_data["passwordless_url"] = reverse_with_qs(
"authentik_core:if-flow",
query=get_qs,
kwargs={"flow_slug": current_stage.passwordless_flow.slug},
kwargs={"flow_slug": self.current_stage.passwordless_flow.slug},
)
if self.current_stage.password_stage:
password = PasswordStageView(self.executor, self.current_stage.captcha_stage)
password_challenge = password.get_challenge()
password_challenge.is_valid()
challenge.initial_data["password_stage"] = password_challenge.data
if self.current_stage.captcha_stage:
captcha = CaptchaStageView(self.executor, self.current_stage.captcha_stage)
captcha_challenge = captcha.get_challenge()
captcha_challenge.is_valid()
challenge.initial_data["captcha_stage"] = captcha_challenge.data
# Check all enabled source, add them if they have a UI Login button.
ui_sources = []
sources: list[Source] = (
current_stage.sources.filter(enabled=True).order_by("name").select_subclasses()
self.current_stage.sources.filter(enabled=True).order_by("name").select_subclasses()
)
for source in sources:
ui_login_button = source.ui_login_button(self.request)
@ -249,8 +262,7 @@ class IdentificationStageView(ChallengeStageView):
def challenge_valid(self, response: IdentificationChallengeResponse) -> HttpResponse:
self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] = response.pre_user
current_stage: IdentificationStage = self.executor.current_stage
if not current_stage.show_matched_user:
if not self.current_stage.show_matched_user:
self.executor.plan.context[PLAN_CONTEXT_PENDING_USER_IDENTIFIER] = (
response.validated_data.get("uid_field")
)
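
Tying this file together: with the (assumed) optional stage override sketched earlier, nesting the captcha and password challenges inside the identification challenge amounts to instantiating the child views with an explicit stage and embedding their serialized challenge data. A toy, fully self-contained version; none of these names are authentik's real API:

class ToyCaptchaStage:
    js_url = "https://www.recaptcha.net/recaptcha/api.js"  # illustrative values
    public_key = "example-site-key"

class ToyCaptchaStageView:
    def __init__(self, executor, stage):
        self.executor = executor
        self.current_stage = stage  # explicit override, as in CaptchaStageView(self.executor, ...)

    def get_challenge(self) -> dict:
        return {
            "component": "ak-stage-captcha",
            "js_url": self.current_stage.js_url,
            "site_key": self.current_stage.public_key,
        }

def identification_challenge(executor, captcha_stage) -> dict:
    data = {"component": "ak-stage-identification"}
    if captcha_stage:
        # The child challenge is rendered by its own stage view and embedded as plain data
        data["captcha_stage"] = ToyCaptchaStageView(executor, captcha_stage).get_challenge()
    return data

print(identification_challenge(executor=None, captcha_stage=ToyCaptchaStage()))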

View File

@ -17,7 +17,7 @@ INVITATION_IN_EFFECT = "invitation_in_effect"
INVITATION = "invitation"
class InvitationStageView(StageView):
class InvitationStageView(StageView[InvitationStage]):
"""Finalise Authentication flow by logging the user in"""
def get_token(self) -> str | None:
@ -52,11 +52,10 @@ class InvitationStageView(StageView):
def dispatch(self, request: HttpRequest) -> HttpResponse:
"""Apply data to the current flow based on a URL"""
stage: InvitationStage = self.executor.current_stage
invite = self.get_invite()
if not invite:
if stage.continue_flow_without_invitation:
if self.current_stage.continue_flow_without_invitation:
return self.executor.stage_ok()
return self.executor.stage_invalid(_("Invalid invite/invite not found"))

View File

@ -130,7 +130,7 @@ class PasswordChallengeResponse(ChallengeResponse):
return password
class PasswordStageView(ChallengeStageView):
class PasswordStageView(ChallengeStageView[PasswordStage]):
"""Authentication stage which authenticates against django's AuthBackend"""
response_class = PasswordChallengeResponse
@ -138,7 +138,7 @@ class PasswordStageView(ChallengeStageView):
def get_challenge(self) -> Challenge:
challenge = PasswordChallenge(
data={
"allow_show_password": self.executor.current_stage.allow_show_password,
"allow_show_password": self.current_stage.allow_show_password,
}
)
recovery_flow = Flow.objects.filter(designation=FlowDesignation.RECOVERY)
@ -154,10 +154,9 @@ class PasswordStageView(ChallengeStageView):
if SESSION_KEY_INVALID_TRIES not in self.request.session:
self.request.session[SESSION_KEY_INVALID_TRIES] = 0
self.request.session[SESSION_KEY_INVALID_TRIES] += 1
current_stage: PasswordStage = self.executor.current_stage
if (
self.request.session[SESSION_KEY_INVALID_TRIES]
>= current_stage.failed_attempts_before_cancel
>= self.current_stage.failed_attempts_before_cancel
):
self.logger.debug("User has exceeded maximum tries")
del self.request.session[SESSION_KEY_INVALID_TRIES]

View File

@ -222,7 +222,7 @@ class PromptStageView(ChallengeStageView):
return serializers
def get_challenge(self, *args, **kwargs) -> Challenge:
fields: list[Prompt] = list(self.executor.current_stage.fields.all().order_by("order"))
fields: list[Prompt] = list(self.current_stage.fields.all().order_by("order"))
context_prompt = self.executor.plan.context.get(PLAN_CONTEXT_PROMPT, {})
serializers = self.get_prompt_challenge_fields(fields, context_prompt)
challenge = PromptChallenge(
@ -239,7 +239,7 @@ class PromptStageView(ChallengeStageView):
instance=None,
data=data,
request=self.request,
stage_instance=self.executor.current_stage,
stage_instance=self.current_stage,
stage=self,
plan=self.executor.plan,
user=self.get_pending_user(),

View File

@ -7,9 +7,10 @@ from django.utils.translation import gettext as _
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER
from authentik.flows.stage import StageView
from authentik.stages.user_delete.models import UserDeleteStage
class UserDeleteStageView(StageView):
class UserDeleteStageView(StageView[UserDeleteStage]):
"""Finalise unenrollment flow by deleting the user object."""
def dispatch(self, request: HttpRequest) -> HttpResponse:

View File

@ -39,7 +39,7 @@ class UserLoginChallengeResponse(ChallengeResponse):
remember_me = BooleanField(required=True)
class UserLoginStageView(ChallengeStageView):
class UserLoginStageView(ChallengeStageView[UserLoginStage]):
"""Finalise Authentication flow by logging the user in"""
response_class = UserLoginChallengeResponse
@ -49,8 +49,7 @@ class UserLoginStageView(ChallengeStageView):
def dispatch(self, request: HttpRequest) -> HttpResponse:
"""Check for remember_me, and do login"""
stage: UserLoginStage = self.executor.current_stage
if timedelta_from_string(stage.remember_me_offset).total_seconds() > 0:
if timedelta_from_string(self.current_stage.remember_me_offset).total_seconds() > 0:
return super().dispatch(request)
return self.do_login(request)
@ -59,9 +58,9 @@ class UserLoginStageView(ChallengeStageView):
def set_session_duration(self, remember: bool) -> timedelta:
"""Update the sessions' expiry"""
delta = timedelta_from_string(self.executor.current_stage.session_duration)
delta = timedelta_from_string(self.current_stage.session_duration)
if remember:
offset = timedelta_from_string(self.executor.current_stage.remember_me_offset)
offset = timedelta_from_string(self.current_stage.remember_me_offset)
delta = delta + offset
if delta.total_seconds() == 0:
self.request.session.set_expiry(0)
@ -71,11 +70,9 @@ class UserLoginStageView(ChallengeStageView):
def set_session_ip(self):
"""Set the sessions' last IP and session bindings"""
stage: UserLoginStage = self.executor.current_stage
self.request.session[SESSION_KEY_LAST_IP] = ClientIPMiddleware.get_client_ip(self.request)
self.request.session[SESSION_KEY_BINDING_NET] = stage.network_binding
self.request.session[SESSION_KEY_BINDING_GEO] = stage.geoip_binding
self.request.session[SESSION_KEY_BINDING_NET] = self.current_stage.network_binding
self.request.session[SESSION_KEY_BINDING_GEO] = self.current_stage.geoip_binding
def do_login(self, request: HttpRequest, remember: bool = False) -> HttpResponse:
"""Attach the currently pending user to the current session"""
@ -111,7 +108,7 @@ class UserLoginStageView(ChallengeStageView):
# as sources show their own success messages
if not self.executor.plan.context.get(PLAN_CONTEXT_SOURCE, None):
messages.success(self.request, _("Successfully logged in!"))
if self.executor.current_stage.terminate_other_sessions:
if self.current_stage.terminate_other_sessions:
AuthenticatedSession.objects.filter(
user=user,
).exclude(session_key=self.request.session.session_key).delete()

View File

@ -4,9 +4,10 @@ from django.contrib.auth import logout
from django.http import HttpRequest, HttpResponse
from authentik.flows.stage import StageView
from authentik.stages.user_logout.models import UserLogoutStage
class UserLogoutStageView(StageView):
class UserLogoutStageView(StageView[UserLogoutStage]):
"""Finalise Authentication flow by logging the user in"""
def dispatch(self, request: HttpRequest) -> HttpResponse:

View File

@ -55,7 +55,7 @@ class UserWriteStageView(StageView):
"""Ensure a user exists"""
user_created = False
path = self.executor.plan.context.get(
PLAN_CONTEXT_USER_PATH, self.executor.current_stage.user_path_template
PLAN_CONTEXT_USER_PATH, self.current_stage.user_path_template
)
if path == "":
path = User.default_path()
@ -64,11 +64,11 @@ class UserWriteStageView(StageView):
user_type = UserTypes(
self.executor.plan.context.get(
PLAN_CONTEXT_USER_TYPE,
self.executor.current_stage.user_type,
self.current_stage.user_type,
)
)
except ValueError:
user_type = self.executor.current_stage.user_type
user_type = self.current_stage.user_type
if user_type == UserTypes.INTERNAL_SERVICE_ACCOUNT:
user_type = UserTypes.SERVICE_ACCOUNT
@ -76,12 +76,12 @@ class UserWriteStageView(StageView):
self.executor.plan.context.setdefault(PLAN_CONTEXT_PENDING_USER, self.request.user)
if (
PLAN_CONTEXT_PENDING_USER not in self.executor.plan.context
or self.executor.current_stage.user_creation_mode == UserCreationMode.ALWAYS_CREATE
or self.current_stage.user_creation_mode == UserCreationMode.ALWAYS_CREATE
):
if self.executor.current_stage.user_creation_mode == UserCreationMode.NEVER_CREATE:
if self.current_stage.user_creation_mode == UserCreationMode.NEVER_CREATE:
return None, False
self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] = User(
is_active=not self.executor.current_stage.create_users_as_inactive,
is_active=not self.current_stage.create_users_as_inactive,
path=path,
type=user_type,
)
@ -180,8 +180,8 @@ class UserWriteStageView(StageView):
try:
with transaction.atomic():
user.save()
if self.executor.current_stage.create_users_group:
user.ak_groups.add(self.executor.current_stage.create_users_group)
if self.current_stage.create_users_group:
user.ak_groups.add(self.current_stage.create_users_group)
if PLAN_CONTEXT_GROUPS in self.executor.plan.context:
user.ak_groups.add(*self.executor.plan.context[PLAN_CONTEXT_GROUPS])
except (IntegrityError, ValueError, TypeError, InternalError) as exc:

View File

@ -82,5 +82,3 @@ entries:
order: 10
target: !KeyOf default-authentication-flow-password-binding
policy: !KeyOf default-authentication-flow-password-optional
attrs:
failure_result: true

View File

@ -2,7 +2,7 @@
"$schema": "http://json-schema.org/draft-07/schema",
"$id": "https://goauthentik.io/blueprints/schema.json",
"type": "object",
"title": "authentik 2024.8.4 Blueprint schema",
"title": "authentik 2024.6.4 Blueprint schema",
"required": [
"version",
"entries"
@ -10091,6 +10091,11 @@
"title": "Password stage",
"description": "When set, shows a password field, instead of showing the password field as separate step."
},
"captcha_stage": {
"type": "integer",
"title": "Captcha stage",
"description": "When set, the captcha element is shown on the identification stage."
},
"case_insensitive_matching": {
"type": "boolean",
"title": "Case insensitive matching",

View File

@ -31,7 +31,7 @@ services:
volumes:
- redis:/data
server:
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.8.4}
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.6.4}
restart: unless-stopped
command: server
environment:
@ -52,7 +52,7 @@ services:
- postgresql
- redis
worker:
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.8.4}
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.6.4}
restart: unless-stopped
command: worker
environment:

go.mod (6 changes)
View File

@ -18,18 +18,18 @@ require (
github.com/gorilla/securecookie v1.1.2
github.com/gorilla/sessions v1.4.0
github.com/gorilla/websocket v1.5.3
github.com/jellydator/ttlcache/v3 v3.2.1
github.com/jellydator/ttlcache/v3 v3.2.0
github.com/mitchellh/mapstructure v1.5.0
github.com/nmcclain/asn1-ber v0.0.0-20170104154839-2661553a0484
github.com/pires/go-proxyproto v0.7.0
github.com/prometheus/client_golang v1.20.2
github.com/prometheus/client_golang v1.20.1
github.com/redis/go-redis/v9 v9.6.1
github.com/sethvargo/go-envconfig v1.1.0
github.com/sirupsen/logrus v1.9.3
github.com/spf13/cobra v1.8.1
github.com/stretchr/testify v1.9.0
github.com/wwt/guac v1.3.2
goauthentik.io/api/v3 v3.2024064.1
goauthentik.io/api/v3 v3.2024063.13
golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab
golang.org/x/oauth2 v0.22.0
golang.org/x/sync v0.8.0

go.sum (16 changes)
View File

@ -200,8 +200,8 @@ github.com/jcmturner/gokrb5/v8 v8.4.4 h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh6
github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs=
github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY=
github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc=
github.com/jellydator/ttlcache/v3 v3.2.1 h1:eS8ljnYY7BllYGkXw/TfczWZrXUu/CH7SIkC6ugn9Js=
github.com/jellydator/ttlcache/v3 v3.2.1/go.mod h1:bj2/e0l4jRnQdrnSTaGTsh4GSXvMjQcy41i7th0GVGw=
github.com/jellydator/ttlcache/v3 v3.2.0 h1:6lqVJ8X3ZaUwvzENqPAobDsXNExfUJd61u++uW8a3LE=
github.com/jellydator/ttlcache/v3 v3.2.0/go.mod h1:hi7MGFdMAwZna5n2tuvh63DvFLzVKySzCVW6+0gA2n4=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
@ -239,8 +239,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.20.2 h1:5ctymQzZlyOON1666svgwn3s6IKWgfbjsejTMiXIyjg=
github.com/prometheus/client_golang v1.20.2/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE=
github.com/prometheus/client_golang v1.20.1 h1:IMJXHOD6eARkQpxo8KkhgEVFlBNm+nkrFUyGlIu7Na8=
github.com/prometheus/client_golang v1.20.1/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
@ -297,10 +297,10 @@ go.opentelemetry.io/otel/sdk v1.24.0 h1:YMPPDNymmQN3ZgczicBY3B6sf9n62Dlj9pWD3ucg
go.opentelemetry.io/otel/sdk v1.24.0/go.mod h1:KVrIYw6tEubO9E96HQpcmpTKDVn9gdv35HoYiQWGDFg=
go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y1YELI=
go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
goauthentik.io/api/v3 v3.2024064.1 h1:vxquklgDGD+nGFhWRAsQ7ezQKg17MRq6bzEk25fbsb4=
goauthentik.io/api/v3 v3.2024064.1/go.mod h1:zz+mEZg8rY/7eEjkMGWJ2DnGqk+zqxuybGCGrR2O4Kw=
go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A=
go.uber.org/goleak v1.2.1/go.mod h1:qlT2yGI9QafXHhZZLxlSuNsMw3FFLxBr+tBRlmO1xH4=
goauthentik.io/api/v3 v3.2024063.13 h1:zWFlrr+8NOaQOCPSRV1FhbDJ58+BPa9BqjNvl4T//s8=
goauthentik.io/api/v3 v3.2024063.13/go.mod h1:zz+mEZg8rY/7eEjkMGWJ2DnGqk+zqxuybGCGrR2O4Kw=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=

View File

@ -29,4 +29,4 @@ func UserAgent() string {
return fmt.Sprintf("authentik@%s", FullVersion())
}
const VERSION = "2024.8.4"
const VERSION = "2024.6.4"

View File

@ -35,11 +35,10 @@ func Paginator[Tobj any, Treq any, Tres PaginatorResponse[Tobj]](
req PaginatorRequest[Treq, Tres],
opts PaginatorOptions,
) ([]Tobj, error) {
var bfreq, cfreq interface{}
fetchOffset := func(page int32) (Tres, error) {
bfreq = req.Page(page)
cfreq = bfreq.(PaginatorRequest[Treq, Tres]).PageSize(int32(opts.PageSize))
res, _, err := cfreq.(PaginatorRequest[Treq, Tres]).Execute()
req.Page(page)
req.PageSize(int32(opts.PageSize))
res, _, err := req.Execute()
if err != nil {
opts.Logger.WithError(err).WithField("page", page).Warning("failed to fetch page")
}

View File

@ -1,26 +0,0 @@
package ak
// func Test_PaginatorCompile(t *testing.T) {
// req := api.ApiCoreUsersListRequest{}
// Paginator(req, PaginatorOptions{
// PageSize: 100,
// })
// }
// func Test_PaginatorCompileExplicit(t *testing.T) {
// req := api.ApiCoreUsersListRequest{}
// Paginator[
// api.User,
// api.ApiCoreUsersListRequest,
// *api.PaginatedUserList,
// ](req, PaginatorOptions{
// PageSize: 100,
// })
// }
// func Test_PaginatorCompileOther(t *testing.T) {
// req := api.ApiOutpostsProxyListRequest{}
// Paginator(req, PaginatorOptions{
// PageSize: 100,
// })
// }

View File

@ -96,7 +96,7 @@ func (db *DirectBinder) Bind(username string, req *bind.Request) (ldap.LDAPResul
return ldap.LDAPResultOperationsError, nil
}
flags.UserPk = userInfo.User.Pk
flags.CanSearch = access.GetHasSearchPermission()
flags.CanSearch = access.HasSearchPermission != nil
db.si.SetFlags(req.BindDN, &flags)
if flags.CanSearch {
req.Log().Debug("Allowed access to search")

View File

@ -193,17 +193,7 @@ func NewApplication(p api.ProxyOutpostConfig, c *http.Client, server Server) (*A
})
mux.HandleFunc("/outpost.goauthentik.io/start", func(w http.ResponseWriter, r *http.Request) {
fwd := ""
// This should only really be hit for nginx forward_auth
// as for that the auth start redirect URL is generated by the
// reverse proxy, and as such we won't have a request we just
// denied to reference for final URL
rd, ok := a.checkRedirectParam(r)
if ok {
a.log.WithField("rd", rd).Trace("Setting redirect")
fwd = rd
}
a.handleAuthStart(w, r, fwd)
a.handleAuthStart(w, r, "")
})
mux.HandleFunc("/outpost.goauthentik.io/callback", a.handleAuthCallback)
mux.HandleFunc("/outpost.goauthentik.io/sign_out", a.handleSignOut)

View File

@ -15,6 +15,36 @@ const (
LogoutSignature = "X-authentik-logout"
)
func (a *Application) checkRedirectParam(r *http.Request) (string, bool) {
rd := r.URL.Query().Get(redirectParam)
if rd == "" {
return "", false
}
u, err := url.Parse(rd)
if err != nil {
a.log.WithError(err).Warning("Failed to parse redirect URL")
return "", false
}
// Check to make sure we only redirect to allowed places
if a.Mode() == api.PROXYMODE_PROXY || a.Mode() == api.PROXYMODE_FORWARD_SINGLE {
ext, err := url.Parse(a.proxyConfig.ExternalHost)
if err != nil {
return "", false
}
ext.Scheme = ""
if !strings.Contains(u.String(), ext.String()) {
a.log.WithField("url", u.String()).WithField("ext", ext.String()).Warning("redirect URI did not contain external host")
return "", false
}
} else {
if !strings.HasSuffix(u.Host, *a.proxyConfig.CookieDomain) {
a.log.WithField("host", u.Host).WithField("dom", *a.proxyConfig.CookieDomain).Warning("redirect URI Host was not included in cookie domain")
return "", false
}
}
return u.String(), true
}
func (a *Application) handleAuthStart(rw http.ResponseWriter, r *http.Request, fwd string) {
state, err := a.createState(r, fwd)
if err != nil {

View File

@ -5,13 +5,10 @@ import (
"encoding/base64"
"fmt"
"net/http"
"net/url"
"strings"
"github.com/golang-jwt/jwt/v5"
"github.com/gorilla/securecookie"
"github.com/mitchellh/mapstructure"
"goauthentik.io/api/v3"
)
type OAuthState struct {
@ -30,44 +27,6 @@ func (oas *OAuthState) GetAudience() (jwt.ClaimStrings, error) { return ni
var base32RawStdEncoding = base32.StdEncoding.WithPadding(base32.NoPadding)
// Validate that the given redirect parameter (?rd=...) is valid and can be used
// For proxy/forward_single this checks that if the `rd` param has a Hostname (and is a full URL)
// the hostname matches what's configured, or no hostname must be given
// For forward_domain this checks if the domain of the URL in `rd` ends with the configured domain
func (a *Application) checkRedirectParam(r *http.Request) (string, bool) {
rd := r.URL.Query().Get(redirectParam)
if rd == "" {
return "", false
}
u, err := url.Parse(rd)
if err != nil {
a.log.WithError(err).Warning("Failed to parse redirect URL")
return "", false
}
// Check to make sure we only redirect to allowed places
if a.Mode() == api.PROXYMODE_PROXY || a.Mode() == api.PROXYMODE_FORWARD_SINGLE {
ext, err := url.Parse(a.proxyConfig.ExternalHost)
if err != nil {
return "", false
}
// Either hostname needs to match the configured domain, or host name must be empty for just a path
if u.Host == "" {
u.Host = ext.Host
u.Scheme = ext.Scheme
}
if u.Host != ext.Host {
a.log.WithField("url", u.String()).WithField("ext", ext.String()).Warning("redirect URI did not contain external host")
return "", false
}
} else {
if !strings.HasSuffix(u.Host, *a.proxyConfig.CookieDomain) {
a.log.WithField("host", u.Host).WithField("dom", *a.proxyConfig.CookieDomain).Warning("redirect URI Host was not included in cookie domain")
return "", false
}
}
return u.String(), true
}
func (a *Application) createState(r *http.Request, fwd string) (string, error) {
s, _ := a.sessions.Get(r, a.SessionName())
if s.ID == "" {
@ -80,6 +39,17 @@ func (a *Application) createState(r *http.Request, fwd string) (string, error) {
SessionID: s.ID,
Redirect: fwd,
}
if fwd == "" {
// This should only really be hit for nginx forward_auth,
// since there the auth start redirect URL is generated by the
// reverse proxy itself, so we don't have the request we just
// denied to use as the final redirect URL
rd, ok := a.checkRedirectParam(r)
if ok {
a.log.WithField("rd", rd).Trace("Setting redirect")
st.Redirect = rd
}
}
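// Editor's note (illustration, not part of the diff): in an nginx forward_auth
// setup the reverse proxy builds the auth start URL itself, e.g. something like
//   /outpost.goauthentik.io/auth/start?rd=https%3A%2F%2Fapp.example.com%2Foriginal%2Fpath
// (hostname and path are placeholders), so the rd query parameter is the only
// place the originally requested URL is still available when building the state.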
token := jwt.NewWithClaims(jwt.SigningMethodHS256, st)
tokenString, err := token.SignedString([]byte(a.proxyConfig.GetCookieSecret()))
if err != nil {

View File

@ -8,45 +8,25 @@ import (
"goauthentik.io/api/v3"
)
func TestCheckRedirectParam_None(t *testing.T) {
func TestCheckRedirectParam(t *testing.T) {
a := newTestApplication()
// Test no rd param
req, _ := http.NewRequest("GET", "/outpost.goauthentik.io/auth/start", nil)
rd, ok := a.checkRedirectParam(req)
assert.Equal(t, false, ok)
assert.Equal(t, "", rd)
}
func TestCheckRedirectParam_Invalid(t *testing.T) {
a := newTestApplication()
// Test invalid rd param
req, _ := http.NewRequest("GET", "/outpost.goauthentik.io/auth/start?rd=https://google.com", nil)
req, _ = http.NewRequest("GET", "/outpost.goauthentik.io/auth/start?rd=https://google.com", nil)
rd, ok := a.checkRedirectParam(req)
rd, ok = a.checkRedirectParam(req)
assert.Equal(t, false, ok)
assert.Equal(t, "", rd)
}
func TestCheckRedirectParam_ValidFull(t *testing.T) {
a := newTestApplication()
// Test valid full rd param
req, _ := http.NewRequest("GET", "/outpost.goauthentik.io/auth/start?rd=https://ext.t.goauthentik.io/test?foo", nil)
req, _ = http.NewRequest("GET", "/outpost.goauthentik.io/auth/start?rd=https://ext.t.goauthentik.io/test?foo", nil)
rd, ok := a.checkRedirectParam(req)
assert.Equal(t, true, ok)
assert.Equal(t, "https://ext.t.goauthentik.io/test?foo", rd)
}
func TestCheckRedirectParam_ValidPartial(t *testing.T) {
a := newTestApplication()
// Test valid partial rd param
req, _ := http.NewRequest("GET", "/outpost.goauthentik.io/auth/start?rd=/test?foo", nil)
rd, ok := a.checkRedirectParam(req)
rd, ok = a.checkRedirectParam(req)
assert.Equal(t, true, ok)
assert.Equal(t, "https://ext.t.goauthentik.io/test?foo", rd)

View File

@ -6,7 +6,6 @@ import (
"errors"
"net"
"net/http"
"strings"
"sync"
sentryhttp "github.com/getsentry/sentry-go/http"
@ -71,20 +70,12 @@ func NewProxyServer(ac *ak.APIController) *ProxyServer {
}
func (ps *ProxyServer) HandleHost(rw http.ResponseWriter, r *http.Request) bool {
// Always handle requests for outpost paths that should answer regardless of hostname
if strings.HasPrefix(r.URL.Path, "/outpost.goauthentik.io/ping") ||
strings.HasPrefix(r.URL.Path, "/outpost.goauthentik.io/static") {
ps.mux.ServeHTTP(rw, r)
return true
}
// lookup app by hostname
a, _ := ps.lookupApp(r)
if a == nil {
return false
}
// check if the app should handle this URL, or is setup in proxy mode
if a.ShouldHandleURL(r) || a.Mode() == api.PROXYMODE_PROXY {
ps.mux.ServeHTTP(rw, r)
a.ServeHTTP(rw, r)
return true
}
return false
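
A minimal standalone sketch (editor's illustration, not part of the diff) of the hostname-independent path-prefix check touched in this hunk; the request paths are made-up examples:

package main

import (
	"fmt"
	"strings"
)

func main() {
	for _, p := range []string{
		"/outpost.goauthentik.io/ping",
		"/outpost.goauthentik.io/static/dist/app.js",
		"/some/application/path",
	} {
		// Requests to these outpost-owned prefixes are answered regardless of
		// hostname; everything else first needs a successful application lookup.
		always := strings.HasPrefix(p, "/outpost.goauthentik.io/ping") ||
			strings.HasPrefix(p, "/outpost.goauthentik.io/static")
		fmt.Printf("%-45s always handled: %v\n", p, always)
	}
}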

View File

@ -1,5 +1,5 @@
{
"name": "@goauthentik/authentik",
"version": "2024.8.4",
"version": "2024.6.4",
"private": true
}

poetry.lock (generated, 66 changed lines)
View File

@ -1053,38 +1053,38 @@ toml = ["tomli"]
[[package]]
name = "cryptography"
version = "43.0.1"
version = "43.0.0"
description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
optional = false
python-versions = ">=3.7"
files = [
{file = "cryptography-43.0.1-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:8385d98f6a3bf8bb2d65a73e17ed87a3ba84f6991c155691c51112075f9ffc5d"},
{file = "cryptography-43.0.1-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:27e613d7077ac613e399270253259d9d53872aaf657471473ebfc9a52935c062"},
{file = "cryptography-43.0.1-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:68aaecc4178e90719e95298515979814bda0cbada1256a4485414860bd7ab962"},
{file = "cryptography-43.0.1-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:de41fd81a41e53267cb020bb3a7212861da53a7d39f863585d13ea11049cf277"},
{file = "cryptography-43.0.1-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:f98bf604c82c416bc829e490c700ca1553eafdf2912a91e23a79d97d9801372a"},
{file = "cryptography-43.0.1-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:61ec41068b7b74268fa86e3e9e12b9f0c21fcf65434571dbb13d954bceb08042"},
{file = "cryptography-43.0.1-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:014f58110f53237ace6a408b5beb6c427b64e084eb451ef25a28308270086494"},
{file = "cryptography-43.0.1-cp37-abi3-win32.whl", hash = "sha256:2bd51274dcd59f09dd952afb696bf9c61a7a49dfc764c04dd33ef7a6b502a1e2"},
{file = "cryptography-43.0.1-cp37-abi3-win_amd64.whl", hash = "sha256:666ae11966643886c2987b3b721899d250855718d6d9ce41b521252a17985f4d"},
{file = "cryptography-43.0.1-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:ac119bb76b9faa00f48128b7f5679e1d8d437365c5d26f1c2c3f0da4ce1b553d"},
{file = "cryptography-43.0.1-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1bbcce1a551e262dfbafb6e6252f1ae36a248e615ca44ba302df077a846a8806"},
{file = "cryptography-43.0.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58d4e9129985185a06d849aa6df265bdd5a74ca6e1b736a77959b498e0505b85"},
{file = "cryptography-43.0.1-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:d03a475165f3134f773d1388aeb19c2d25ba88b6a9733c5c590b9ff7bbfa2e0c"},
{file = "cryptography-43.0.1-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:511f4273808ab590912a93ddb4e3914dfd8a388fed883361b02dea3791f292e1"},
{file = "cryptography-43.0.1-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:80eda8b3e173f0f247f711eef62be51b599b5d425c429b5d4ca6a05e9e856baa"},
{file = "cryptography-43.0.1-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:38926c50cff6f533f8a2dae3d7f19541432610d114a70808f0926d5aaa7121e4"},
{file = "cryptography-43.0.1-cp39-abi3-win32.whl", hash = "sha256:a575913fb06e05e6b4b814d7f7468c2c660e8bb16d8d5a1faf9b33ccc569dd47"},
{file = "cryptography-43.0.1-cp39-abi3-win_amd64.whl", hash = "sha256:d75601ad10b059ec832e78823b348bfa1a59f6b8d545db3a24fd44362a1564cb"},
{file = "cryptography-43.0.1-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:ea25acb556320250756e53f9e20a4177515f012c9eaea17eb7587a8c4d8ae034"},
{file = "cryptography-43.0.1-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c1332724be35d23a854994ff0b66530119500b6053d0bd3363265f7e5e77288d"},
{file = "cryptography-43.0.1-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:fba1007b3ef89946dbbb515aeeb41e30203b004f0b4b00e5e16078b518563289"},
{file = "cryptography-43.0.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:5b43d1ea6b378b54a1dc99dd8a2b5be47658fe9a7ce0a58ff0b55f4b43ef2b84"},
{file = "cryptography-43.0.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:88cce104c36870d70c49c7c8fd22885875d950d9ee6ab54df2745f83ba0dc365"},
{file = "cryptography-43.0.1-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:9d3cdb25fa98afdd3d0892d132b8d7139e2c087da1712041f6b762e4f807cc96"},
{file = "cryptography-43.0.1-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e710bf40870f4db63c3d7d929aa9e09e4e7ee219e703f949ec4073b4294f6172"},
{file = "cryptography-43.0.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7c05650fe8023c5ed0d46793d4b7d7e6cd9c04e68eabe5b0aeea836e37bdcec2"},
{file = "cryptography-43.0.1.tar.gz", hash = "sha256:203e92a75716d8cfb491dc47c79e17d0d9207ccffcbcb35f598fbe463ae3444d"},
{file = "cryptography-43.0.0-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:64c3f16e2a4fc51c0d06af28441881f98c5d91009b8caaff40cf3548089e9c74"},
{file = "cryptography-43.0.0-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3dcdedae5c7710b9f97ac6bba7e1052b95c7083c9d0e9df96e02a1932e777895"},
{file = "cryptography-43.0.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d9a1eca329405219b605fac09ecfc09ac09e595d6def650a437523fcd08dd22"},
{file = "cryptography-43.0.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ea9e57f8ea880eeea38ab5abf9fbe39f923544d7884228ec67d666abd60f5a47"},
{file = "cryptography-43.0.0-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:9a8d6802e0825767476f62aafed40532bd435e8a5f7d23bd8b4f5fd04cc80ecf"},
{file = "cryptography-43.0.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:cc70b4b581f28d0a254d006f26949245e3657d40d8857066c2ae22a61222ef55"},
{file = "cryptography-43.0.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:4a997df8c1c2aae1e1e5ac49c2e4f610ad037fc5a3aadc7b64e39dea42249431"},
{file = "cryptography-43.0.0-cp37-abi3-win32.whl", hash = "sha256:6e2b11c55d260d03a8cf29ac9b5e0608d35f08077d8c087be96287f43af3ccdc"},
{file = "cryptography-43.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:31e44a986ceccec3d0498e16f3d27b2ee5fdf69ce2ab89b52eaad1d2f33d8778"},
{file = "cryptography-43.0.0-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:7b3f5fe74a5ca32d4d0f302ffe6680fcc5c28f8ef0dc0ae8f40c0f3a1b4fca66"},
{file = "cryptography-43.0.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac1955ce000cb29ab40def14fd1bbfa7af2017cca696ee696925615cafd0dce5"},
{file = "cryptography-43.0.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:299d3da8e00b7e2b54bb02ef58d73cd5f55fb31f33ebbf33bd00d9aa6807df7e"},
{file = "cryptography-43.0.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ee0c405832ade84d4de74b9029bedb7b31200600fa524d218fc29bfa371e97f5"},
{file = "cryptography-43.0.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:cb013933d4c127349b3948aa8aaf2f12c0353ad0eccd715ca789c8a0f671646f"},
{file = "cryptography-43.0.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:fdcb265de28585de5b859ae13e3846a8e805268a823a12a4da2597f1f5afc9f0"},
{file = "cryptography-43.0.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:2905ccf93a8a2a416f3ec01b1a7911c3fe4073ef35640e7ee5296754e30b762b"},
{file = "cryptography-43.0.0-cp39-abi3-win32.whl", hash = "sha256:47ca71115e545954e6c1d207dd13461ab81f4eccfcb1345eac874828b5e3eaaf"},
{file = "cryptography-43.0.0-cp39-abi3-win_amd64.whl", hash = "sha256:0663585d02f76929792470451a5ba64424acc3cd5227b03921dab0e2f27b1709"},
{file = "cryptography-43.0.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:2c6d112bf61c5ef44042c253e4859b3cbbb50df2f78fa8fae6747a7814484a70"},
{file = "cryptography-43.0.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:844b6d608374e7d08f4f6e6f9f7b951f9256db41421917dfb2d003dde4cd6b66"},
{file = "cryptography-43.0.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:51956cf8730665e2bdf8ddb8da0056f699c1a5715648c1b0144670c1ba00b48f"},
{file = "cryptography-43.0.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:aae4d918f6b180a8ab8bf6511a419473d107df4dbb4225c7b48c5c9602c38c7f"},
{file = "cryptography-43.0.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:232ce02943a579095a339ac4b390fbbe97f5b5d5d107f8a08260ea2768be8cc2"},
{file = "cryptography-43.0.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:5bcb8a5620008a8034d39bce21dc3e23735dfdb6a33a06974739bfa04f853947"},
{file = "cryptography-43.0.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:08a24a7070b2b6804c1940ff0f910ff728932a9d0e80e7814234269f9d46d069"},
{file = "cryptography-43.0.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:e9c5266c432a1e23738d178e51c2c7a5e2ddf790f248be939448c0ba2021f9d1"},
{file = "cryptography-43.0.0.tar.gz", hash = "sha256:b88075ada2d51aa9f18283532c9f60e72170041bba88d7f37e49cbb10275299e"},
]
[package.dependencies]
@ -1097,7 +1097,7 @@ nox = ["nox"]
pep8test = ["check-sdist", "click", "mypy", "ruff"]
sdist = ["build"]
ssh = ["bcrypt (>=3.1.5)"]
test = ["certifi", "cryptography-vectors (==43.0.1)", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-xdist"]
test = ["certifi", "cryptography-vectors (==43.0.0)", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-xdist"]
test-randomorder = ["pytest-randomly"]
[[package]]
@ -1312,17 +1312,17 @@ django = ">=3"
[[package]]
name = "django-pglock"
version = "1.6.0"
version = "1.5.1"
description = "Postgres locking routines and lock table access."
optional = false
python-versions = "<4,>=3.8.0"
files = [
{file = "django_pglock-1.6.0-py3-none-any.whl", hash = "sha256:41c98d0bd3738d11e6eaefcc3e5146028f118a593ac58c13d663b751170f01de"},
{file = "django_pglock-1.6.0.tar.gz", hash = "sha256:724450ecc9886f39af599c477d84ad086545a5373215ef7a670cd25faca25a61"},
{file = "django_pglock-1.5.1-py3-none-any.whl", hash = "sha256:d3b977922abbaffd43968714b69cdab7453866adf2b0695fb497491748d7bc67"},
{file = "django_pglock-1.5.1.tar.gz", hash = "sha256:291903d5d877b68558003e1d64d764ebd5590344ba3b7aa1d5127df5947869b1"},
]
[package.dependencies]
django = ">=4"
django = ">=3"
django-pgactivity = ">=1.2,<2"
[[package]]

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "authentik"
version = "2024.8.4"
version = "2024.6.4"
description = ""
authors = ["authentik Team <hello@goauthentik.io>"]

View File

@ -1,7 +1,7 @@
openapi: 3.0.3
info:
title: authentik
version: 2024.8.4
version: 2024.6.4
description: Making authentication simple.
contact:
email: hello@goauthentik.io
@ -40457,11 +40457,10 @@ components:
items:
type: string
nullable: true
password_fields:
type: boolean
allow_show_password:
type: boolean
default: false
password_stage:
$ref: '#/components/schemas/PasswordChallenge'
captcha_stage:
$ref: '#/components/schemas/CaptchaChallenge'
application_pre:
type: string
flow_designation:
@ -40482,7 +40481,6 @@ components:
type: boolean
required:
- flow_designation
- password_fields
- primary_action
- show_source_labels
- user_fields
@ -40500,6 +40498,8 @@ components:
password:
type: string
nullable: true
captcha:
$ref: '#/components/schemas/CaptchaChallengeResponseRequest'
required:
- uid_field
IdentificationStage:
@ -40545,6 +40545,12 @@ components:
nullable: true
description: When set, shows a password field, instead of showing the password
field as separate step.
captcha_stage:
type: string
format: uuid
nullable: true
description: When set, the captcha element is shown on the identification
stage.
case_insensitive_matching:
type: boolean
description: When enabled, user fields are matched regardless of their casing.
@ -40613,6 +40619,12 @@ components:
nullable: true
description: When set, shows a password field, instead of showing the password
field as separate step.
captcha_stage:
type: string
format: uuid
nullable: true
description: When set, the captcha element is shown on the identification
stage.
case_insensitive_matching:
type: boolean
description: When enabled, user fields are matched regardless of their casing.
@ -45745,6 +45757,12 @@ components:
nullable: true
description: When set, shows a password field, instead of showing the password
field as separate step.
captcha_stage:
type: string
format: uuid
nullable: true
description: When set, the captcha element is shown on the identification
stage.
case_insensitive_matching:
type: boolean
description: When enabled, user fields are matched regardless of their casing.

View File

@ -11,7 +11,6 @@ from ldap3.core.exceptions import LDAPInvalidCredentialsResult
from authentik.blueprints.tests import apply_blueprint, reconcile_app
from authentik.core.models import Application, User
from authentik.core.tests.utils import create_test_user
from authentik.events.models import Event, EventAction
from authentik.flows.models import Flow
from authentik.lib.generators import generate_id
@ -332,83 +331,6 @@ class TestProviderLDAP(SeleniumTestCase):
]
self.assert_list_dict_equal(expected, response)
@retry()
@apply_blueprint(
"default/flow-default-authentication-flow.yaml",
"default/flow-default-invalidation-flow.yaml",
)
@reconcile_app("authentik_tenants")
@reconcile_app("authentik_outposts")
def test_ldap_bind_search_no_perms(self):
"""Test simple bind + search"""
user = create_test_user()
self._prepare()
server = Server("ldap://localhost:3389", get_info=ALL)
_connection = Connection(
server,
raise_exceptions=True,
user=f"cn={user.username},ou=users,dc=ldap,dc=goauthentik,dc=io",
password=user.username,
)
_connection.bind()
self.assertTrue(
Event.objects.filter(
action=EventAction.LOGIN,
user={
"pk": user.pk,
"email": user.email,
"username": user.username,
},
)
)
_connection.search(
"ou=Users,DC=ldaP,dc=goauthentik,dc=io",
"(objectClass=user)",
search_scope=SUBTREE,
attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES],
)
response: list = _connection.response
# Remove raw_attributes to make checking easier
for obj in response:
del obj["raw_attributes"]
del obj["raw_dn"]
obj["attributes"] = dict(obj["attributes"])
expected = [
{
"dn": f"cn={user.username},ou=users,dc=ldap,dc=goauthentik,dc=io",
"attributes": {
"cn": user.username,
"sAMAccountName": user.username,
"uid": user.uid,
"name": user.name,
"displayName": user.name,
"sn": user.name,
"mail": user.email,
"objectClass": [
"top",
"person",
"organizationalPerson",
"inetOrgPerson",
"user",
"posixAccount",
"goauthentik.io/ldap/user",
],
"uidNumber": 2000 + user.pk,
"gidNumber": 2000 + user.pk,
"memberOf": [
f"cn={group.name},ou=groups,dc=ldap,dc=goauthentik,dc=io"
for group in user.ak_groups.all()
],
"homeDirectory": f"/home/{user.username}",
"ak-active": True,
"ak-superuser": False,
},
"type": "searchResEntry",
},
]
self.assert_list_dict_equal(expected, response)
def assert_list_dict_equal(self, expected: list[dict], actual: list[dict], match_key="dn"):
"""Assert a list of dictionaries is identical, ignoring the ordering of items"""
self.assertEqual(len(expected), len(actual))

web/package-lock.json (generated, 9533 changed lines)

File diff suppressed because it is too large

Some files were not shown because too many files have changed in this diff