Compare commits

...

4 Commits

Author SHA1 Message Date
8128d8dab5 fix rac cache missing key
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-07-01 21:49:31 +02:00
f4a68c7878 use nested for RAC
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-07-01 21:46:35 +02:00
7ab17822e3 add support for nested routes
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-07-01 21:46:16 +02:00
76da77f26e fix ql schema
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-07-01 21:20:10 +02:00
14 changed files with 716 additions and 560 deletions

View File

@ -0,0 +1,67 @@
from rest_framework.routers import DefaultRouter as UpstreamDefaultRouter
from rest_framework.viewsets import ViewSet
from rest_framework_nested.routers import NestedMixin
class DefaultRouter(UpstreamDefaultRouter):
include_format_suffixes = False
class NestedRouter(DefaultRouter):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.nested_routers = []
class nested:
def __init__(self, parent: "NestedRouter", prefix: str):
self.parent = parent
self.prefix = prefix
self.inner = None
def nested(self, lookup: str, prefix: str, viewset: type[ViewSet]):
if not self.inner:
self.inner = NestedDefaultRouter(self.parent, self.prefix, lookup=lookup)
self.inner.register(prefix, viewset)
return self
@property
def urls(self):
return self.parent.urls
def register(self, prefix, viewset, basename=None):
super().register(prefix, viewset, basename)
nested_router = self.nested(self, prefix)
self.nested_routers.append(nested_router)
return nested_router
def get_urls(self):
urls = super().get_urls()
for nested in self.nested_routers:
if not nested.inner:
continue
urls.extend(nested.inner.urls)
return urls
class NestedDefaultRouter(NestedMixin, DefaultRouter):
...
# def __init__(self, *args, **kwargs):
# self.args = args
# self.kwargs = kwargs
# self.routes = []
# def register(self, *args, **kwargs):
# self.routes.append((args, kwargs))
# @property
# def urls(self):
# class r(NestedMixin, DefaultRouter):
# ...
# router = r(*self.args, **self.kwargs)
# for route_args, route_kwrags in self.routes:
# router.register(*route_args, **route_kwrags)
# return router
root_router = DefaultRouter()

View File

@ -6,18 +6,15 @@ from django.urls import path
from django.urls.resolvers import URLPattern
from django.views.decorators.cache import cache_page
from drf_spectacular.views import SpectacularAPIView
from rest_framework import routers
from structlog.stdlib import get_logger
from authentik.api.v3.config import ConfigView
from authentik.api.v3.routers import root_router
from authentik.api.views import APIBrowserView
from authentik.lib.utils.reflection import get_apps
LOGGER = get_logger()
router = routers.DefaultRouter()
router.include_format_suffixes = False
_other_urls = []
for _authentik_app in get_apps():
try:
@ -38,7 +35,7 @@ for _authentik_app in get_apps():
if isinstance(url, URLPattern):
_other_urls.append(url)
else:
router.register(*url)
root_router.register(*url)
LOGGER.debug(
"Mounted API URLs",
app_name=_authentik_app.name,
@ -49,7 +46,7 @@ urlpatterns = (
[
path("", APIBrowserView.as_view(), name="schema-browser"),
]
+ router.urls
+ root_router.urls
+ _other_urls
+ [
path("root/config/", ConfigView.as_view(), name="config"),

View File

@ -6,7 +6,7 @@ from djangoql.ast import Name
from djangoql.exceptions import DjangoQLError
from djangoql.queryset import apply_search
from djangoql.schema import DjangoQLSchema
from rest_framework.filters import BaseFilterBackend, SearchFilter
from rest_framework.filters import SearchFilter
from rest_framework.request import Request
from structlog.stdlib import get_logger
@ -39,7 +39,8 @@ class BaseSchema(DjangoQLSchema):
return super().resolve_name(name)
class QLSearch(BaseFilterBackend):
# Inherits from SearchFilter to keep the schema correctly
class QLSearch(SearchFilter):
"""rest_framework search filter which uses DjangoQL"""
def __init__(self):

View File

@ -40,9 +40,16 @@ class ConnectionTokenViewSet(
):
"""ConnectionToken Viewset"""
queryset = ConnectionToken.objects.all().select_related("session", "endpoint")
queryset = ConnectionToken.objects.none()
serializer_class = ConnectionTokenSerializer
filterset_fields = ["endpoint", "session__user", "provider"]
search_fields = ["endpoint__name", "provider__name"]
ordering = ["endpoint__name", "provider__name"]
filterset_fields = ["endpoint", "session__user"]
search_fields = ["endpoint__name", "session__user__username"]
ordering = ["endpoint__name", "session__user__username"]
owner_field = "session__user"
def get_queryset(self):
return (
ConnectionToken.objects.all()
.select_related("session", "endpoint")
.filter(provider=self.kwargs["provider_pk"])
)

View File

@ -22,9 +22,9 @@ from authentik.rbac.filters import ObjectFilter
LOGGER = get_logger()
def user_endpoint_cache_key(user_pk: str) -> str:
def user_endpoint_cache_key(user_pk: str, provider_pk: str) -> str:
"""Cache key where endpoint list for user is saved"""
return f"goauthentik.io/providers/rac/endpoint_access/{user_pk}"
return f"goauthentik.io/providers/rac/endpoint_access/{user_pk}/{provider_pk}"
class EndpointSerializer(ModelSerializer):
@ -65,12 +65,15 @@ class EndpointSerializer(ModelSerializer):
class EndpointViewSet(UsedByMixin, ModelViewSet):
"""Endpoint Viewset"""
queryset = Endpoint.objects.all()
queryset = Endpoint.objects.none()
serializer_class = EndpointSerializer
filterset_fields = ["name", "provider"]
filterset_fields = ["name"]
search_fields = ["name", "protocol"]
ordering = ["name", "protocol"]
def get_queryset(self):
return Endpoint.objects.filter(provider=self.kwargs["provider_pk"])
def _filter_queryset_for_list(self, queryset: QuerySet) -> QuerySet:
"""Custom filter_queryset method which ignores guardian, but still supports sorting"""
for backend in list(self.filter_backends):
@ -120,14 +123,11 @@ class EndpointViewSet(UsedByMixin, ModelViewSet):
if not should_cache:
allowed_endpoints = self._get_allowed_endpoints(queryset)
if should_cache:
allowed_endpoints = cache.get(user_endpoint_cache_key(self.request.user.pk))
key = user_endpoint_cache_key(self.request.user.pk, self.kwargs["provider_pk"])
allowed_endpoints = cache.get(key)
if not allowed_endpoints:
LOGGER.debug("Caching allowed endpoint list")
allowed_endpoints = self._get_allowed_endpoints(queryset)
cache.set(
user_endpoint_cache_key(self.request.user.pk),
allowed_endpoints,
timeout=86400,
)
cache.set(key, allowed_endpoints, timeout=86400)
serializer = self.get_serializer(allowed_endpoints, many=True)
return self.get_paginated_response(serializer.data)

View File

@ -43,5 +43,5 @@ def pre_delete_connection_token_disconnect(sender, instance: ConnectionToken, **
@receiver([post_save, post_delete], sender=Endpoint)
def post_save_post_delete_endpoint(**_):
"""Clear user's endpoint cache upon endpoint creation or deletion"""
keys = cache.keys(user_endpoint_cache_key("*"))
keys = cache.keys(user_endpoint_cache_key("*", "*"))
cache.delete_many(keys)

View File

@ -2,6 +2,7 @@
from django.urls import path
from authentik.api.v3.routers import NestedRouter
from authentik.outposts.channels import TokenOutpostMiddleware
from authentik.providers.rac.api.connection_tokens import ConnectionTokenViewSet
from authentik.providers.rac.api.endpoints import EndpointViewSet
@ -38,8 +39,10 @@ websocket_urlpatterns = [
]
api_urlpatterns = [
("providers/rac", RACProviderViewSet),
*NestedRouter()
.register("providers/rac", RACProviderViewSet)
.nested("provider", "endpoints", EndpointViewSet)
.nested("provider", "connection_tokens", ConnectionTokenViewSet)
.urls,
("propertymappings/provider/rac", RACPropertyMappingViewSet),
("rac/endpoints", EndpointViewSet),
("rac/connection_tokens", ConnectionTokenViewSet),
]

View File

@ -28,6 +28,7 @@ dependencies = [
"djangorestframework-guardian==0.3.0",
"djangorestframework==3.16.0",
"docker==7.1.0",
"drf-nested-routers==0.94.2",
"drf-orjson-renderer==1.7.3",
"drf-spectacular==0.28.0",
"dumb-init==1.2.5.post1",

1092
schema.yml

File diff suppressed because it is too large Load Diff

15
uv.lock generated
View File

@ -191,6 +191,7 @@ dependencies = [
{ name = "djangorestframework" },
{ name = "djangorestframework-guardian" },
{ name = "docker" },
{ name = "drf-nested-routers" },
{ name = "drf-orjson-renderer" },
{ name = "drf-spectacular" },
{ name = "dumb-init" },
@ -290,6 +291,7 @@ requires-dist = [
{ name = "djangorestframework", git = "https://github.com/goauthentik/django-rest-framework?rev=896722bab969fabc74a08b827da59409cf9f1a4e" },
{ name = "djangorestframework-guardian", specifier = "==0.3.0" },
{ name = "docker", specifier = "==7.1.0" },
{ name = "drf-nested-routers", specifier = "==0.94.2" },
{ name = "drf-orjson-renderer", specifier = "==1.7.3" },
{ name = "drf-spectacular", specifier = "==0.28.0" },
{ name = "dumb-init", specifier = "==1.2.5.post1" },
@ -1190,6 +1192,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/2f/71/1f500097efe09e04c3be862ab26c997314237a8b0a16dc3e3047fee23f4c/drf_jsonschema_serializer-3.0.0-py3-none-any.whl", hash = "sha256:d0e5cce095a5638b0bb7867aa060ed59ab9eed2f54ba5058dd9b483c9c887ed5", size = 8994, upload-time = "2024-06-26T13:09:59.929Z" },
]
[[package]]
name = "drf-nested-routers"
version = "0.94.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "django" },
{ name = "djangorestframework" },
]
sdist = { url = "https://files.pythonhosted.org/packages/f6/98/2d29f3ecd337255bc2775b9addef347b6fd30ff7b3757649d0e50602ba08/drf_nested_routers-0.94.2.tar.gz", hash = "sha256:aa70923b716dc47cd93b8129b06be6c15706b405cf5f718f59cb8eed01de59cc", size = 22845, upload-time = "2025-05-14T17:03:50.896Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/62/dc/6bdb857a631fe6558db18a009c93ae16c3ad94fef0b7be7a3aa35c3264fa/drf_nested_routers-0.94.2-py2.py3-none-any.whl", hash = "sha256:74dbdceeae2a32f8668ba0df8e3eeabeb9b1c64d2621d914901ae653e4e3bcff", size = 36367, upload-time = "2025-05-14T17:03:49.257Z" },
]
[[package]]
name = "drf-orjson-renderer"
version = "1.7.3"

View File

@ -12,7 +12,7 @@ import { customElement, property } from "lit/decorators.js";
import PFDescriptionList from "@patternfly/patternfly/components/DescriptionList/description-list.css";
import { ConnectionToken, RACProvider, RacApi } from "@goauthentik/api";
import { ConnectionToken, ProvidersApi, RACProvider } from "@goauthentik/api";
@customElement("ak-rac-connection-token-list")
export class ConnectionTokenListPage extends Table<ConnectionToken> {
@ -37,9 +37,9 @@ export class ConnectionTokenListPage extends Table<ConnectionToken> {
}
async apiEndpoint(): Promise<PaginatedResponse<ConnectionToken>> {
return new RacApi(DEFAULT_CONFIG).racConnectionTokensList({
return new ProvidersApi(DEFAULT_CONFIG).providersRacConnectionTokensList({
...(await this.defaultEndpointConfig()),
provider: this.provider?.pk,
providerPk: this.provider!.pk,
sessionUser: this.userId,
});
}
@ -56,12 +56,14 @@ export class ConnectionTokenListPage extends Table<ConnectionToken> {
];
}}
.usedBy=${(item: ConnectionToken) => {
return new RacApi(DEFAULT_CONFIG).racConnectionTokensUsedByList({
return new ProvidersApi(DEFAULT_CONFIG).providersRacConnectionTokensUsedByList({
providerPk: this.provider!.pk,
connectionTokenUuid: item.pk || "",
});
}}
.delete=${(item: ConnectionToken) => {
return new RacApi(DEFAULT_CONFIG).racConnectionTokensDestroy({
return new ProvidersApi(DEFAULT_CONFIG).providersRacConnectionTokensDestroy({
providerPk: this.provider!.pk,
connectionTokenUuid: item.pk || "",
});
}}

View File

@ -12,7 +12,7 @@ import { TemplateResult, html } from "lit";
import { customElement, property } from "lit/decorators.js";
import { ifDefined } from "lit/directives/if-defined.js";
import { AuthModeEnum, Endpoint, ProtocolEnum, RacApi } from "@goauthentik/api";
import { AuthModeEnum, Endpoint, ProtocolEnum, ProvidersApi } from "@goauthentik/api";
import { propertyMappingsProvider, propertyMappingsSelector } from "./RACProviderFormHelpers.js";
@ -22,7 +22,8 @@ export class EndpointForm extends ModelForm<Endpoint, string> {
providerID?: number;
loadInstance(pk: string): Promise<Endpoint> {
return new RacApi(DEFAULT_CONFIG).racEndpointsRetrieve({
return new ProvidersApi(DEFAULT_CONFIG).providersRacEndpointsRetrieve({
providerPk: this.providerID!,
pbmUuid: pk,
});
}
@ -41,12 +42,14 @@ export class EndpointForm extends ModelForm<Endpoint, string> {
data.provider = this.instance.provider;
}
if (this.instance) {
return new RacApi(DEFAULT_CONFIG).racEndpointsPartialUpdate({
return new ProvidersApi(DEFAULT_CONFIG).providersRacEndpointsPartialUpdate({
providerPk: this.providerID!,
pbmUuid: this.instance.pk || "",
patchedEndpointRequest: data,
});
}
return new RacApi(DEFAULT_CONFIG).racEndpointsCreate({
return new ProvidersApi(DEFAULT_CONFIG).providersRacEndpointsCreate({
providerPk: this.providerID!,
endpointRequest: data,
});
}

View File

@ -17,8 +17,8 @@ import PFDescriptionList from "@patternfly/patternfly/components/DescriptionList
import {
Endpoint,
ProvidersApi,
RACProvider,
RacApi,
RbacPermissionsAssignedByUsersListModelEnum,
} from "@goauthentik/api";
@ -43,9 +43,9 @@ export class EndpointListPage extends Table<Endpoint> {
}
async apiEndpoint(): Promise<PaginatedResponse<Endpoint>> {
return new RacApi(DEFAULT_CONFIG).racEndpointsList({
return new ProvidersApi(DEFAULT_CONFIG).providersRacEndpointsList({
...(await this.defaultEndpointConfig()),
provider: this.provider?.pk,
providerPk: this.provider!.pk,
superuserFullList: true,
});
}
@ -70,12 +70,14 @@ export class EndpointListPage extends Table<Endpoint> {
];
}}
.usedBy=${(item: Endpoint) => {
return new RacApi(DEFAULT_CONFIG).racEndpointsUsedByList({
return new ProvidersApi(DEFAULT_CONFIG).providersRacEndpointsUsedByList({
providerPk: this.provider!.pk,
pbmUuid: item.pk,
});
}}
.delete=${(item: Endpoint) => {
return new RacApi(DEFAULT_CONFIG).racEndpointsDestroy({
return new ProvidersApi(DEFAULT_CONFIG).providersRacEndpointsDestroy({
providerPk: this.provider!.pk,
pbmUuid: item.pk,
});
}}

View File

@ -6,7 +6,7 @@ import { msg } from "@lit/localize";
import { TemplateResult, html } from "lit";
import { customElement, property } from "lit/decorators.js";
import { Application, Endpoint, RacApi } from "@goauthentik/api";
import { Application, Endpoint, ProvidersApi } from "@goauthentik/api";
@customElement("ak-library-rac-endpoint-launch")
export class RACLaunchEndpointModal extends TableModal<Endpoint> {
@ -30,9 +30,9 @@ export class RACLaunchEndpointModal extends TableModal<Endpoint> {
app?: Application;
async apiEndpoint(): Promise<PaginatedResponse<Endpoint>> {
const endpoints = await new RacApi(DEFAULT_CONFIG).racEndpointsList({
const endpoints = await new ProvidersApi(DEFAULT_CONFIG).providersRacEndpointsList({
...(await this.defaultEndpointConfig()),
provider: this.app?.provider || 0,
providerPk: this.app?.provider || 0,
});
if (this.open && endpoints.pagination.count === 1) {
this.clickHandler(endpoints.results[0]);