Compare commits

..

17 Commits

Author SHA1 Message Date
1dc4fbbb2b Attempting coercion to ESM with Vite & Wdio 2024-08-09 10:55:44 -07:00
3332de267d Testing needs to be able to import from dependent packages. 2024-08-09 10:53:20 -07:00
ab366d0ec2 Testing needs to be able to import from dependent packages. 2024-08-09 10:50:51 -07:00
ac162582aa Modernization continues. 2024-08-09 10:43:28 -07:00
7d82e029d5 wdio does not need the storybook cssimport hack. 2024-08-09 10:37:17 -07:00
9b40ecb023 web: restricted all eslints to local checks only and blocked them from cache analysis 2024-08-09 09:56:55 -07:00
0cc0fdaae3 Not ready for primetime. 2024-08-09 09:35:46 -07:00
b55b168718 web: move common into its own package.
```
$ mkdir ./packages/common
$ git mv ./src/common ./packages/common/src
```

... and then added all of the boilerplate needed to drive with Wireit, build with ESlint, typecheck
with TSC, and then spell check documentation and comments, security checks of package.json and
package-lock.json, format.

... and _then_ fix all of the minor, nitpicky things ESLint 9 found in the package.

... and _then_ wire the whole thing into our build so that we can find it as a package, removing
it as an alias from the base package definition and turning it into a workspace.  Although it is
a workspace package, it's currently configured to build completely independently.

It could be published as an independent NPM package, although I don't recommended that at this time.

I've wanted to break the UI up into smaller, more digestible chunks for awhile, but was always
reluctant to, since I didn't want to mess with other teams' mental models of the code layout.
@Beryju, seeing the success of the Simple Flow Executor as an independent package, thought it might
be worthwhile to see what effort it took to break the graph of our independent apps (User, Flow, and
Admin) and their dependencies (Common <- Elements <- Components, Common <- Locales) into packages.

Turns out, it's not too bad.  It's going to be fiddly for awhile until things settle down, but
overall the experiment has been a success.

The `tsconfig.json` doesn't refer to the base because we want this to build independently; tooling
will be needed to ensure all of our `tsconfig` files in the future will be consistent across all
packages.

- We can use the ESLint boilerplate as-is.
- We have to run TSC as a separate (but fortunately parallel) build step, as client code will need
  the built types. Final builds will be fractionally slower, but Wireit can detect when a monorepo
  package is unchanged and can skip rebuilding `common` if it's not needed, so the development loop
  will be faster.
- The ESBuild boilerplate is different for libraries with UI, libraries without UI (like this one),
  and apps, and we'll have to have three different routines for them. Once we are building
  independent _apps_, getting them into the `dist` folder will be an interesting challenge; we may
  end up with two different builds, one to bundle it in *in the app*, and another to bundle it *for
  Django*. That's mostly an issue of targeting and integration, and shouldn't take too much time.
- Spelling, formatting, and package checking aren't affected.
- `Locales` is our biggest challenge, as usual. I have found only [one article on it
  anywhere](https://medium.com/tech-at-zet/streamlining-localization-in-a-monorepo-using-i18n-js-e7c521ff69d4),
  and it recommends creating a single package in which to keep all of the localizations and the
  localization machinery. That seems like a sound approach, but we haven't (yet) gotten there.

`common` is a bit of a junk drawer: there are global utilities in there, there are app-specific
helpers, there are plug-in specific helpers, and so on. Figuring out exactly what does what and
making more specific packages may be in our future.
2024-08-09 08:38:30 -07:00
c46dc8f290 Not sure how that happened. 2024-08-08 16:10:07 -07:00
e48da3520c Fix type checking issues at the TSC level. 2024-08-08 16:03:37 -07:00
1ec4652c60 Fix dependent types needed before attempting typecheck. 2024-08-08 15:51:09 -07:00
e375646705 Made linting the subpackages a requirement of success. 2024-08-08 15:47:54 -07:00
b84652d9d3 Fix eslint so it only lints the local package. Other packages have their own responsibilities. 2024-08-08 15:44:43 -07:00
74b8da28ca Added common:build" to the list of dependencies for building, which is what you want, right? 2024-08-08 15:32:11 -07:00
9084c7c6b4 web: all the basic commands are working: build, build types, lint source, lint lockfile, lint packagefile, lint types, lint spelling, format source, format package. 2024-08-08 15:08:05 -07:00
7a0b227b46 Interim commit 2024-08-08 14:25:14 -07:00
cc9128fd46 Move begun; sfe cleanup completed. 2024-08-08 11:14:50 -07:00
639 changed files with 20036 additions and 42911 deletions

View File

@ -1,5 +1,5 @@
[bumpversion]
current_version = 2024.8.4
current_version = 2024.6.3
tag = True
commit = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))?

View File

@ -29,9 +29,9 @@ outputs:
imageTags:
description: "Docker image tags"
value: ${{ steps.ev.outputs.imageTags }}
attestImageNames:
description: "Docker image names used for attestation"
value: ${{ steps.ev.outputs.attestImageNames }}
imageNames:
description: "Docker image names"
value: ${{ steps.ev.outputs.imageNames }}
imageMainTag:
description: "Docker image main tag"
value: ${{ steps.ev.outputs.imageMainTag }}

View File

@ -51,24 +51,15 @@ else:
]
image_main_tag = image_tags[0].split(":")[-1]
def get_attest_image_names(image_with_tags: list[str]):
"""Attestation only for GHCR"""
image_tags = []
for image_name in set(name.split(":")[0] for name in image_with_tags):
if not image_name.startswith("ghcr.io"):
continue
image_tags.append(image_name)
return ",".join(set(image_tags))
image_tags_rendered = ",".join(image_tags)
image_names_rendered = ",".join(set(name.split(":")[0] for name in image_tags))
with open(os.environ["GITHUB_OUTPUT"], "a+", encoding="utf-8") as _output:
print(f"shouldBuild={should_build}", file=_output)
print(f"sha={sha}", file=_output)
print(f"version={version}", file=_output)
print(f"prerelease={prerelease}", file=_output)
print(f"imageTags={','.join(image_tags)}", file=_output)
print(f"attestImageNames={get_attest_image_names(image_tags)}", file=_output)
print(f"imageTags={image_tags_rendered}", file=_output)
print(f"imageNames={image_names_rendered}", file=_output)
print(f"imageMainTag={image_main_tag}", file=_output)
print(f"imageMainName={image_tags[0]}", file=_output)

View File

@ -58,10 +58,6 @@ updates:
patterns:
- "@rollup/*"
- "rollup-*"
swc:
patterns:
- "@swc/*"
- "swc-*"
wdio:
patterns:
- "@wdio/*"

View File

@ -261,7 +261,7 @@ jobs:
id: attest
if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
with:
subject-name: ${{ steps.ev.outputs.attestImageNames }}
subject-name: ${{ steps.ev.outputs.imageNames }}
subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true
pr-comment:

View File

@ -31,7 +31,7 @@ jobs:
- name: golangci-lint
uses: golangci/golangci-lint-action@v6
with:
version: latest
version: v1.54.2
args: --timeout 5000s --verbose
skip-cache: true
test-unittest:
@ -115,7 +115,7 @@ jobs:
id: attest
if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
with:
subject-name: ${{ steps.ev.outputs.attestImageNames }}
subject-name: ${{ steps.ev.outputs.imageNames }}
subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true
build-binary:

View File

@ -92,4 +92,4 @@ jobs:
run: make gen-client-ts
- name: test
working-directory: web/
run: npm run test || exit 0
run: npm run test

View File

@ -51,14 +51,12 @@ jobs:
secrets: |
GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }}
GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }}
build-args: |
VERSION=${{ github.ref }}
tags: ${{ steps.ev.outputs.imageTags }}
platforms: linux/amd64,linux/arm64
- uses: actions/attest-build-provenance@v1
id: attest
with:
subject-name: ${{ steps.ev.outputs.attestImageNames }}
subject-name: ${{ steps.ev.outputs.imageNames }}
subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true
build-outpost:
@ -113,8 +111,6 @@ jobs:
id: push
with:
push: true
build-args: |
VERSION=${{ github.ref }}
tags: ${{ steps.ev.outputs.imageTags }}
file: ${{ matrix.type }}.Dockerfile
platforms: linux/amd64,linux/arm64
@ -122,7 +118,7 @@ jobs:
- uses: actions/attest-build-provenance@v1
id: attest
with:
subject-name: ${{ steps.ev.outputs.attestImageNames }}
subject-name: ${{ steps.ev.outputs.imageNames }}
subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true
build-outpost-binary:

View File

@ -1,7 +1,7 @@
# syntax=docker/dockerfile:1
# Stage 1: Build website
FROM --platform=${BUILDPLATFORM} docker.io/library/node:22 AS website-builder
FROM --platform=${BUILDPLATFORM} docker.io/library/node:22 as website-builder
ENV NODE_ENV=production
@ -20,7 +20,7 @@ COPY ./SECURITY.md /work/
RUN npm run build-bundled
# Stage 2: Build webui
FROM --platform=${BUILDPLATFORM} docker.io/library/node:22 AS web-builder
FROM --platform=${BUILDPLATFORM} docker.io/library/node:22 as web-builder
ARG GIT_BUILD_HASH
ENV GIT_BUILD_HASH=$GIT_BUILD_HASH
@ -43,7 +43,7 @@ COPY ./gen-ts-api /work/web/node_modules/@goauthentik/api
RUN npm run build
# Stage 3: Build go proxy
FROM --platform=${BUILDPLATFORM} mcr.microsoft.com/oss/go/microsoft/golang:1.23-fips-bookworm AS go-builder
FROM --platform=${BUILDPLATFORM} mcr.microsoft.com/oss/go/microsoft/golang:1.22-fips-bookworm AS go-builder
ARG TARGETOS
ARG TARGETARCH
@ -80,7 +80,7 @@ RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \
go build -o /go/authentik ./cmd/server
# Stage 4: MaxMind GeoIP
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.0.1 AS geoip
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.0.1 as geoip
ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City GeoLite2-ASN"
ENV GEOIPUPDATE_VERBOSE="1"
@ -94,10 +94,7 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
/bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
# Stage 5: Python dependencies
FROM ghcr.io/goauthentik/fips-python:3.12.5-slim-bookworm-fips-full AS python-deps
ARG TARGETARCH
ARG TARGETVARIANT
FROM ghcr.io/goauthentik/fips-python:3.12.3-slim-bookworm-fips-full AS python-deps
WORKDIR /ak-root/poetry
@ -124,17 +121,17 @@ RUN --mount=type=bind,target=./pyproject.toml,src=./pyproject.toml \
pip install --force-reinstall /wheels/*"
# Stage 6: Run
FROM ghcr.io/goauthentik/fips-python:3.12.5-slim-bookworm-fips-full AS final-image
FROM ghcr.io/goauthentik/fips-python:3.12.3-slim-bookworm-fips-full AS final-image
ARG VERSION
ARG GIT_BUILD_HASH
ARG VERSION
ENV GIT_BUILD_HASH=$GIT_BUILD_HASH
LABEL org.opencontainers.image.url=https://goauthentik.io
LABEL org.opencontainers.image.description="goauthentik.io Main server image, see https://goauthentik.io for more info."
LABEL org.opencontainers.image.source=https://github.com/goauthentik/authentik
LABEL org.opencontainers.image.version=${VERSION}
LABEL org.opencontainers.image.revision=${GIT_BUILD_HASH}
LABEL org.opencontainers.image.url https://goauthentik.io
LABEL org.opencontainers.image.description goauthentik.io Main server image, see https://goauthentik.io for more info.
LABEL org.opencontainers.image.source https://github.com/goauthentik/authentik
LABEL org.opencontainers.image.version ${VERSION}
LABEL org.opencontainers.image.revision ${GIT_BUILD_HASH}
WORKDIR /

View File

@ -43,7 +43,7 @@ help: ## Show this help
sort
@echo ""
go-test:
test-go:
go test -timeout 0 -v -race -cover ./...
test-docker: ## Run all tests in a docker-compose
@ -210,9 +210,6 @@ web: web-lint-fix web-lint web-check-compile ## Automatically fix formatting is
web-install: ## Install the necessary libraries to build the Authentik UI
cd web && npm ci
web-test: ## Run tests for the Authentik UI
cd web && npm run test
web-watch: ## Build and watch the Authentik UI for changes, updating automatically
rm -rf web/dist/
mkdir web/dist/

View File

@ -15,9 +15,7 @@
## What is authentik?
authentik is an open-source Identity Provider that emphasizes flexibility and versatility, with support for a wide set of protocols.
Our [enterprise offer](https://goauthentik.io/pricing) can also be used as a self-hosted replacement for large-scale deployments of Okta/Auth0, Entra ID, Ping Identity, or other legacy IdPs for employees and B2B2C use.
authentik is an open-source Identity Provider that emphasizes flexibility and versatility. It can be seamlessly integrated into existing environments to support new protocols. authentik is also a great solution for implementing sign-up, recovery, and other similar features in your application, saving you the hassle of dealing with them.
## Installation

View File

@ -2,7 +2,7 @@
from os import environ
__version__ = "2024.8.4"
__version__ = "2024.6.3"
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"

View File

@ -73,7 +73,7 @@ class SystemInfoSerializer(PassiveSerializer):
"authentik_version": get_full_version(),
"environment": get_env(),
"openssl_fips_enabled": (
backend._fips_enabled if LicenseKey.get_total().status().is_valid else None
backend._fips_enabled if LicenseKey.get_total().is_valid() else None
),
"openssl_version": OPENSSL_VERSION,
"platform": platform.platform(),

View File

@ -12,7 +12,6 @@ from rest_framework.views import APIView
from authentik import __version__, get_build_hash
from authentik.admin.tasks import VERSION_CACHE_KEY, VERSION_NULL, update_latest_version
from authentik.core.api.utils import PassiveSerializer
from authentik.outposts.models import Outpost
class VersionSerializer(PassiveSerializer):
@ -23,7 +22,6 @@ class VersionSerializer(PassiveSerializer):
version_latest_valid = SerializerMethodField()
build_hash = SerializerMethodField()
outdated = SerializerMethodField()
outpost_outdated = SerializerMethodField()
def get_build_hash(self, _) -> str:
"""Get build hash, if version is not latest or released"""
@ -49,15 +47,6 @@ class VersionSerializer(PassiveSerializer):
"""Check if we're running the latest version"""
return parse(self.get_version_current(instance)) < parse(self.get_version_latest(instance))
def get_outpost_outdated(self, _) -> bool:
"""Check if any outpost is outdated/has a version mismatch"""
any_outdated = False
for outpost in Outpost.objects.all():
for state in outpost.state:
if state.version_outdated:
any_outdated = True
return any_outdated
class VersionView(APIView):
"""Get running and latest version."""

View File

@ -51,11 +51,9 @@ class BlueprintInstanceSerializer(ModelSerializer):
context = self.instance.context if self.instance else {}
valid, logs = Importer.from_string(content, context).validate()
if not valid:
text_logs = "\n".join([x["event"] for x in logs])
raise ValidationError(
[
_("Failed to validate blueprint"),
*[f"- {x.event}" for x in logs],
]
_("Failed to validate blueprint: {logs}".format_map({"logs": text_logs}))
)
return content

View File

@ -78,5 +78,5 @@ class TestBlueprintsV1API(APITestCase):
self.assertEqual(res.status_code, 400)
self.assertJSONEqual(
res.content.decode(),
{"content": ["Failed to validate blueprint", "- Invalid blueprint version"]},
{"content": ["Failed to validate blueprint: Invalid blueprint version"]},
)

View File

@ -171,7 +171,7 @@ class Importer:
def default_context(self):
"""Default context"""
return {
"goauthentik.io/enterprise/licensed": LicenseKey.get_total().status().is_valid,
"goauthentik.io/enterprise/licensed": LicenseKey.get_total().is_valid(),
"goauthentik.io/rbac/models": rbac_models(),
}
@ -429,7 +429,7 @@ class Importer:
orig_import = deepcopy(self._import)
if self._import.version != 1:
self.logger.warning("Invalid blueprint version")
return False, [LogEvent("Invalid blueprint version", log_level="warning", logger=None)]
return False, [{"event": "Invalid blueprint version"}]
with (
transaction_rollback(),
capture_logs() as logs,

View File

@ -30,10 +30,8 @@ from authentik.core.api.utils import (
PassiveSerializer,
)
from authentik.core.expression.evaluator import PropertyMappingEvaluator
from authentik.core.expression.exceptions import PropertyMappingExpressionException
from authentik.core.models import Group, PropertyMapping, User
from authentik.events.utils import sanitize_item
from authentik.lib.utils.errors import exception_to_string
from authentik.policies.api.exec import PolicyTestSerializer
from authentik.rbac.decorators import permission_required
@ -164,15 +162,12 @@ class PropertyMappingViewSet(
response_data = {"successful": True, "result": ""}
try:
result = mapping.evaluate(dry_run=True, **context)
result = mapping.evaluate(**context)
response_data["result"] = dumps(
sanitize_item(result), indent=(4 if format_result else None)
)
except PropertyMappingExpressionException as exc:
response_data["result"] = exception_to_string(exc.exc)
response_data["successful"] = False
except Exception as exc:
response_data["result"] = exception_to_string(exc)
response_data["result"] = str(exc)
response_data["successful"] = False
response = PropertyMappingTestResultSerializer(response_data)
return Response(response.data)

View File

@ -14,7 +14,6 @@ from rest_framework.request import Request
from rest_framework.response import Response
from authentik.core.api.utils import PassiveSerializer
from authentik.rbac.filters import ObjectFilter
class DeleteAction(Enum):
@ -54,7 +53,7 @@ class UsedByMixin:
@extend_schema(
responses={200: UsedBySerializer(many=True)},
)
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter])
@action(detail=True, pagination_class=None, filter_backends=[])
def used_by(self, request: Request, *args, **kwargs) -> Response:
"""Get a list of all objects that use this object"""
model: Model = self.get_object()

View File

@ -678,13 +678,10 @@ class UserViewSet(UsedByMixin, ModelViewSet):
if not request.tenant.impersonation:
LOGGER.debug("User attempted to impersonate", user=request.user)
return Response(status=401)
user_to_be = self.get_object()
# Check both object-level perms and global perms
if not request.user.has_perm(
"authentik_core.impersonate", user_to_be
) and not request.user.has_perm("authentik_core.impersonate"):
if not request.user.has_perm("impersonate"):
LOGGER.debug("User attempted to impersonate without permissions", user=request.user)
return Response(status=401)
user_to_be = self.get_object()
if user_to_be.pk == self.request.user.pk:
LOGGER.debug("User attempted to impersonate themselves", user=request.user)
return Response(status=401)

View File

@ -9,11 +9,10 @@ class Command(TenantCommand):
def add_arguments(self, parser):
parser.add_argument("--type", type=str, required=True)
parser.add_argument("--all", action="store_true", default=False)
parser.add_argument("usernames", nargs="*", type=str)
parser.add_argument("--all", action="store_true")
parser.add_argument("usernames", nargs="+", type=str)
def handle_per_tenant(self, **options):
print(options)
new_type = UserTypes(options["type"])
qs = (
User.objects.exclude_anonymous()
@ -23,9 +22,6 @@ class Command(TenantCommand):
if options["usernames"] and options["all"]:
self.stderr.write("--all and usernames specified, only one can be specified")
return
if not options["usernames"] and not options["all"]:
self.stderr.write("--all or usernames must be specified")
return
if options["usernames"] and not options["all"]:
qs = qs.filter(username__in=options["usernames"])
updated = qs.update(type=new_type)

View File

@ -466,6 +466,8 @@ class ApplicationQuerySet(QuerySet):
def with_provider(self) -> "QuerySet[Application]":
qs = self.select_related("provider")
for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider):
if LOOKUP_SEP in subclass:
continue
qs = qs.select_related(f"provider__{subclass}")
return qs
@ -543,24 +545,15 @@ class Application(SerializerModel, PolicyBindingModel):
if not self.provider:
return None
candidates = []
base_class = Provider
for subclass in base_class.objects.get_queryset()._get_subclasses_recurse(base_class):
parent = self.provider
for level in subclass.split(LOOKUP_SEP):
try:
parent = getattr(parent, level)
except AttributeError:
break
if parent in candidates:
for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider):
# We don't care about recursion, skip nested models
if LOOKUP_SEP in subclass:
continue
idx = subclass.count(LOOKUP_SEP)
if type(parent) is not base_class:
idx += 1
candidates.insert(idx, parent)
if not candidates:
return None
return candidates[-1]
try:
return getattr(self.provider, subclass)
except AttributeError:
pass
return None
def __str__(self):
return str(self.name)
@ -908,7 +901,7 @@ class PropertyMapping(SerializerModel, ManagedModel):
except ControlFlowException as exc:
raise exc
except Exception as exc:
raise PropertyMappingExpressionException(exc, self) from exc
raise PropertyMappingExpressionException(self, exc) from exc
def __str__(self):
return f"Property Mapping {self.name}"

View File

@ -9,12 +9,9 @@ from rest_framework.test import APITestCase
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.lib.generators import generate_id
from authentik.policies.dummy.models import DummyPolicy
from authentik.policies.models import PolicyBinding
from authentik.providers.oauth2.models import OAuth2Provider
from authentik.providers.proxy.models import ProxyProvider
from authentik.providers.saml.models import SAMLProvider
class TestApplicationsAPI(APITestCase):
@ -225,31 +222,3 @@ class TestApplicationsAPI(APITestCase):
],
},
)
def test_get_provider(self):
"""Ensure that proxy providers (at the time of writing that is the only provider
that inherits from another proxy type (OAuth) instead of inheriting from the root
provider class) is correctly looked up and selected from the database"""
slug = generate_id()
provider = ProxyProvider.objects.create(name=generate_id())
Application.objects.create(
name=generate_id(),
slug=slug,
provider=provider,
)
self.assertEqual(Application.objects.get(slug=slug).get_provider(), provider)
self.assertEqual(
Application.objects.with_provider().get(slug=slug).get_provider(), provider
)
slug = generate_id()
provider = SAMLProvider.objects.create(name=generate_id())
Application.objects.create(
name=generate_id(),
slug=slug,
provider=provider,
)
self.assertEqual(Application.objects.get(slug=slug).get_provider(), provider)
self.assertEqual(
Application.objects.with_provider().get(slug=slug).get_provider(), provider
)

View File

@ -3,10 +3,10 @@
from json import loads
from django.urls import reverse
from guardian.shortcuts import assign_perm
from rest_framework.test import APITestCase
from authentik.core.tests.utils import create_test_admin_user, create_test_user
from authentik.core.models import User
from authentik.core.tests.utils import create_test_admin_user
from authentik.tenants.utils import get_current_tenant
@ -15,7 +15,7 @@ class TestImpersonation(APITestCase):
def setUp(self) -> None:
super().setUp()
self.other_user = create_test_user()
self.other_user = User.objects.create(username="to-impersonate")
self.user = create_test_admin_user()
def test_impersonate_simple(self):
@ -44,46 +44,6 @@ class TestImpersonation(APITestCase):
self.assertEqual(response_body["user"]["username"], self.user.username)
self.assertNotIn("original", response_body)
def test_impersonate_global(self):
"""Test impersonation with global permissions"""
new_user = create_test_user()
assign_perm("authentik_core.impersonate", new_user)
assign_perm("authentik_core.view_user", new_user)
self.client.force_login(new_user)
response = self.client.post(
reverse(
"authentik_api:user-impersonate",
kwargs={"pk": self.other_user.pk},
)
)
self.assertEqual(response.status_code, 201)
response = self.client.get(reverse("authentik_api:user-me"))
response_body = loads(response.content.decode())
self.assertEqual(response_body["user"]["username"], self.other_user.username)
self.assertEqual(response_body["original"]["username"], new_user.username)
def test_impersonate_scoped(self):
"""Test impersonation with scoped permissions"""
new_user = create_test_user()
assign_perm("authentik_core.impersonate", new_user, self.other_user)
assign_perm("authentik_core.view_user", new_user, self.other_user)
self.client.force_login(new_user)
response = self.client.post(
reverse(
"authentik_api:user-impersonate",
kwargs={"pk": self.other_user.pk},
)
)
self.assertEqual(response.status_code, 201)
response = self.client.get(reverse("authentik_api:user-me"))
response_body = loads(response.content.decode())
self.assertEqual(response_body["user"]["username"], self.other_user.username)
self.assertEqual(response_body["original"]["username"], new_user.username)
def test_impersonate_denied(self):
"""test impersonation without permissions"""
self.client.force_login(self.other_user)

View File

@ -35,7 +35,6 @@ from authentik.crypto.builder import CertificateBuilder, PrivateKeyAlg
from authentik.crypto.models import CertificateKeyPair
from authentik.events.models import Event, EventAction
from authentik.rbac.decorators import permission_required
from authentik.rbac.filters import ObjectFilter
LOGGER = get_logger()
@ -266,7 +265,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
],
responses={200: CertificateDataSerializer(many=False)},
)
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter])
@action(detail=True, pagination_class=None, filter_backends=[])
def view_certificate(self, request: Request, pk: str) -> Response:
"""Return certificate-key pairs certificate and log access"""
certificate: CertificateKeyPair = self.get_object()
@ -296,7 +295,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
],
responses={200: CertificateDataSerializer(many=False)},
)
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter])
@action(detail=True, pagination_class=None, filter_backends=[])
def view_private_key(self, request: Request, pk: str) -> Response:
"""Return certificate-key pairs private key and log access"""
certificate: CertificateKeyPair = self.get_object()

View File

@ -214,46 +214,6 @@ class TestCrypto(APITestCase):
self.assertEqual(200, response.status_code)
self.assertIn("Content-Disposition", response)
def test_certificate_download_denied(self):
"""Test certificate export (download)"""
self.client.logout()
keypair = create_test_cert()
response = self.client.get(
reverse(
"authentik_api:certificatekeypair-view-certificate",
kwargs={"pk": keypair.pk},
)
)
self.assertEqual(403, response.status_code)
response = self.client.get(
reverse(
"authentik_api:certificatekeypair-view-certificate",
kwargs={"pk": keypair.pk},
),
data={"download": True},
)
self.assertEqual(403, response.status_code)
def test_private_key_download_denied(self):
"""Test private_key export (download)"""
self.client.logout()
keypair = create_test_cert()
response = self.client.get(
reverse(
"authentik_api:certificatekeypair-view-private-key",
kwargs={"pk": keypair.pk},
)
)
self.assertEqual(403, response.status_code)
response = self.client.get(
reverse(
"authentik_api:certificatekeypair-view-private-key",
kwargs={"pk": keypair.pk},
),
data={"download": True},
)
self.assertEqual(403, response.status_code)
def test_used_by(self):
"""Test used_by endpoint"""
self.client.force_login(create_test_admin_user())
@ -286,26 +246,6 @@ class TestCrypto(APITestCase):
],
)
def test_used_by_denied(self):
"""Test used_by endpoint"""
self.client.logout()
keypair = create_test_cert()
OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
client_secret=generate_key(),
authorization_flow=create_test_flow(),
redirect_uris="http://localhost",
signing_key=keypair,
)
response = self.client.get(
reverse(
"authentik_api:certificatekeypair-used-by",
kwargs={"pk": keypair.pk},
)
)
self.assertEqual(403, response.status_code)
def test_discovery(self):
"""Test certificate discovery"""
name = generate_id()

View File

@ -1,11 +1,12 @@
"""Enterprise API Views"""
from dataclasses import asdict
from datetime import timedelta
from django.utils.timezone import now
from django.utils.translation import gettext as _
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema, inline_serializer
from drf_spectacular.utils import extend_schema, inline_serializer
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
from rest_framework.fields import CharField, IntegerField
@ -29,7 +30,7 @@ class EnterpriseRequiredMixin:
def validate(self, attrs: dict) -> dict:
"""Check that a valid license exists"""
if not LicenseKey.cached_summary().status.is_valid:
if not LicenseKey.cached_summary().has_license:
raise ValidationError(_("Enterprise is required to create/update this object."))
return super().validate(attrs)
@ -86,7 +87,7 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
},
)
@action(detail=False, methods=["GET"])
def install_id(self, request: Request) -> Response:
def get_install_id(self, request: Request) -> Response:
"""Get install_id"""
return Response(
data={
@ -99,22 +100,12 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
responses={
200: LicenseSummarySerializer(),
},
parameters=[
OpenApiParameter(
name="cached",
location=OpenApiParameter.QUERY,
type=OpenApiTypes.BOOL,
default=True,
)
],
)
@action(detail=False, methods=["GET"], permission_classes=[IsAuthenticated])
def summary(self, request: Request) -> Response:
"""Get the total license status"""
summary = LicenseKey.cached_summary()
if request.query_params.get("cached", "true").lower() == "false":
summary = LicenseKey.get_total().summary()
response = LicenseSummarySerializer(instance=summary)
response = LicenseSummarySerializer(data=asdict(LicenseKey.cached_summary()))
response.is_valid(raise_exception=True)
return Response(response.data)
@permission_required(None, ["authentik_enterprise.view_license"])
@ -137,7 +128,7 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
forecast_for_months = 12
response = LicenseForecastSerializer(
data={
"internal_users": LicenseKey.get_internal_user_count(),
"internal_users": LicenseKey.get_default_user_count(),
"external_users": LicenseKey.get_external_user_count(),
"forecasted_internal_users": (internal_in_last_month * forecast_for_months),
"forecasted_external_users": (external_in_last_month * forecast_for_months),

View File

@ -25,4 +25,4 @@ class AuthentikEnterpriseConfig(EnterpriseConfig):
"""Actual enterprise check, cached"""
from authentik.enterprise.license import LicenseKey
return LicenseKey.cached_summary().status.is_valid
return LicenseKey.cached_summary().valid

View File

@ -3,37 +3,24 @@
from base64 import b64decode
from binascii import Error
from dataclasses import asdict, dataclass, field
from datetime import UTC, datetime, timedelta
from datetime import datetime, timedelta
from enum import Enum
from functools import lru_cache
from time import mktime
from cryptography.exceptions import InvalidSignature
from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate
from dacite import DaciteError, from_dict
from dacite import from_dict
from django.core.cache import cache
from django.db.models.query import QuerySet
from django.utils.timezone import now
from jwt import PyJWTError, decode, get_unverified_header
from rest_framework.exceptions import ValidationError
from rest_framework.fields import (
ChoiceField,
DateTimeField,
IntegerField,
ListField,
)
from rest_framework.fields import BooleanField, DateTimeField, IntegerField
from authentik.core.api.utils import PassiveSerializer
from authentik.core.models import User, UserTypes
from authentik.enterprise.models import (
THRESHOLD_READ_ONLY_WEEKS,
THRESHOLD_WARNING_ADMIN_WEEKS,
THRESHOLD_WARNING_EXPIRY_WEEKS,
THRESHOLD_WARNING_USER_WEEKS,
License,
LicenseUsage,
LicenseUsageStatus,
)
from authentik.enterprise.models import License, LicenseUsage
from authentik.tenants.utils import get_unique_identifier
CACHE_KEY_ENTERPRISE_LICENSE = "goauthentik.io/enterprise/license"
@ -55,9 +42,6 @@ def get_license_aud() -> str:
class LicenseFlags(Enum):
"""License flags"""
TRIAL = "trial"
NON_PRODUCTION = "non_production"
@dataclass
class LicenseSummary:
@ -65,9 +49,12 @@ class LicenseSummary:
internal_users: int
external_users: int
status: LicenseUsageStatus
valid: bool
show_admin_warning: bool
show_user_warning: bool
read_only: bool
latest_valid: datetime
license_flags: list[LicenseFlags]
has_license: bool
class LicenseSummarySerializer(PassiveSerializer):
@ -75,9 +62,12 @@ class LicenseSummarySerializer(PassiveSerializer):
internal_users = IntegerField(required=True)
external_users = IntegerField(required=True)
status = ChoiceField(choices=LicenseUsageStatus.choices)
valid = BooleanField()
show_admin_warning = BooleanField()
show_user_warning = BooleanField()
read_only = BooleanField()
latest_valid = DateTimeField()
license_flags = ListField(child=ChoiceField(choices=tuple(x.value for x in LicenseFlags)))
has_license = BooleanField()
@dataclass
@ -90,10 +80,10 @@ class LicenseKey:
name: str
internal_users: int = 0
external_users: int = 0
license_flags: list[LicenseFlags] = field(default_factory=list)
flags: list[LicenseFlags] = field(default_factory=list)
@staticmethod
def validate(jwt: str, check_expiry=True) -> "LicenseKey":
def validate(jwt: str) -> "LicenseKey":
"""Validate the license from a given JWT"""
try:
headers = get_unverified_header(jwt)
@ -117,28 +107,26 @@ class LicenseKey:
our_cert.public_key(),
algorithms=["ES512"],
audience=get_license_aud(),
options={"verify_exp": check_expiry, "verify_signature": check_expiry},
),
)
except PyJWTError:
unverified = decode(jwt, options={"verify_signature": False})
if unverified["aud"] != get_license_aud():
raise ValidationError("Invalid Install ID in license") from None
raise ValidationError("Unable to verify license") from None
return body
@staticmethod
def get_total() -> "LicenseKey":
"""Get a summarized version of all (not expired) licenses"""
active_licenses = License.objects.filter(expiry__gte=now())
total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
for lic in License.objects.all():
for lic in active_licenses:
total.internal_users += lic.internal_users
total.external_users += lic.external_users
exp_ts = int(mktime(lic.expiry.timetuple()))
if total.exp == 0:
total.exp = exp_ts
total.exp = max(total.exp, exp_ts)
total.license_flags.extend(lic.status.license_flags)
if exp_ts <= total.exp:
total.exp = exp_ts
total.flags.extend(lic.status.flags)
return total
@staticmethod
@ -147,7 +135,7 @@ class LicenseKey:
return User.objects.all().exclude_anonymous().exclude(is_active=False)
@staticmethod
def get_internal_user_count():
def get_default_user_count():
"""Get current default user count"""
return LicenseKey.base_user_qs().filter(type=UserTypes.INTERNAL).count()
@ -156,73 +144,59 @@ class LicenseKey:
"""Get current external user count"""
return LicenseKey.base_user_qs().filter(type=UserTypes.EXTERNAL).count()
def _last_valid_date(self):
last_valid_date = (
LicenseUsage.objects.order_by("-record_date")
.filter(status=LicenseUsageStatus.VALID)
.first()
)
if not last_valid_date:
return datetime.fromtimestamp(0, UTC)
return last_valid_date.record_date
def is_valid(self) -> bool:
"""Check if the given license body covers all users
def status(self) -> LicenseUsageStatus:
"""Check if the given license body covers all users, and is valid."""
last_valid = self._last_valid_date()
if self.exp == 0 and not License.objects.exists():
return LicenseUsageStatus.UNLICENSED
_now = now()
# Check limit-exceeded based status
internal_users = self.get_internal_user_count()
external_users = self.get_external_user_count()
if internal_users > self.internal_users or external_users > self.external_users:
if last_valid < _now - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS):
return LicenseUsageStatus.READ_ONLY
if last_valid < _now - timedelta(weeks=THRESHOLD_WARNING_USER_WEEKS):
return LicenseUsageStatus.LIMIT_EXCEEDED_USER
if last_valid < _now - timedelta(weeks=THRESHOLD_WARNING_ADMIN_WEEKS):
return LicenseUsageStatus.LIMIT_EXCEEDED_ADMIN
# Check expiry based status
if datetime.fromtimestamp(self.exp, UTC) < _now:
if datetime.fromtimestamp(self.exp, UTC) < _now - timedelta(
weeks=THRESHOLD_READ_ONLY_WEEKS
):
return LicenseUsageStatus.READ_ONLY
return LicenseUsageStatus.EXPIRED
# Expiry warning
if datetime.fromtimestamp(self.exp, UTC) <= _now + timedelta(
weeks=THRESHOLD_WARNING_EXPIRY_WEEKS
):
return LicenseUsageStatus.EXPIRY_SOON
return LicenseUsageStatus.VALID
Only checks the current count, no historical data is checked"""
default_users = self.get_default_user_count()
if default_users > self.internal_users:
return False
active_users = self.get_external_user_count()
if active_users > self.external_users:
return False
return True
def record_usage(self):
"""Capture the current validity status and metrics and save them"""
threshold = now() - timedelta(hours=8)
usage = (
LicenseUsage.objects.order_by("-record_date").filter(record_date__gte=threshold).first()
)
if not usage:
usage = LicenseUsage.objects.create(
internal_user_count=self.get_internal_user_count(),
if not LicenseUsage.objects.filter(record_date__gte=threshold).exists():
LicenseUsage.objects.create(
user_count=self.get_default_user_count(),
external_user_count=self.get_external_user_count(),
status=self.status(),
within_limits=self.is_valid(),
)
summary = asdict(self.summary())
# Also cache the latest summary for the middleware
cache.set(CACHE_KEY_ENTERPRISE_LICENSE, summary, timeout=CACHE_EXPIRY_ENTERPRISE_LICENSE)
return usage
return summary
@staticmethod
def last_valid_date() -> datetime:
"""Get the last date the license was valid"""
usage: LicenseUsage = (
LicenseUsage.filter_not_expired(within_limits=True).order_by("-record_date").first()
)
if not usage:
return now()
return usage.record_date
def summary(self) -> LicenseSummary:
"""Summary of license status"""
status = self.status()
has_license = License.objects.all().count() > 0
last_valid = LicenseKey.last_valid_date()
show_admin_warning = last_valid < now() - timedelta(weeks=2)
show_user_warning = last_valid < now() - timedelta(weeks=4)
read_only = last_valid < now() - timedelta(weeks=6)
latest_valid = datetime.fromtimestamp(self.exp)
return LicenseSummary(
show_admin_warning=show_admin_warning and has_license,
show_user_warning=show_user_warning and has_license,
read_only=read_only and has_license,
latest_valid=latest_valid,
internal_users=self.internal_users,
external_users=self.external_users,
status=status,
license_flags=self.license_flags,
valid=self.is_valid(),
has_license=has_license,
)
@staticmethod
@ -231,8 +205,4 @@ class LicenseKey:
summary = cache.get(CACHE_KEY_ENTERPRISE_LICENSE)
if not summary:
return LicenseKey.get_total().summary()
try:
return from_dict(LicenseSummary, summary)
except DaciteError:
cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)
return LicenseKey.get_total().summary()
return from_dict(LicenseSummary, summary)

View File

@ -8,7 +8,6 @@ from structlog.stdlib import BoundLogger, get_logger
from authentik.enterprise.api import LicenseViewSet
from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import LicenseUsageStatus
from authentik.flows.views.executor import FlowExecutorView
from authentik.lib.utils.reflection import class_to_path
@ -44,7 +43,7 @@ class EnterpriseMiddleware:
cached_status = LicenseKey.cached_summary()
if not cached_status:
return True
if cached_status.status == LicenseUsageStatus.READ_ONLY:
if cached_status.read_only:
return False
return True
@ -54,10 +53,10 @@ class EnterpriseMiddleware:
if request.method.lower() in ["get", "head", "options", "trace"]:
return True
# Always allow requests to manage licenses
if request.resolver_match._func_path == class_to_path(LicenseViewSet):
if class_to_path(request.resolver_match.func) == class_to_path(LicenseViewSet):
return True
# Flow executor is mounted as an API path but explicitly allowed
if request.resolver_match._func_path == class_to_path(FlowExecutorView):
if class_to_path(request.resolver_match.func) == class_to_path(FlowExecutorView):
return True
# Only apply these restrictions to the API
if "authentik_api" not in request.resolver_match.app_names:

View File

@ -1,68 +0,0 @@
# Generated by Django 5.0.8 on 2024-08-08 14:15
from django.db import migrations, models
from django.apps.registry import Apps
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
def migrate_license_usage(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
LicenseUsage = apps.get_model("authentik_enterprise", "licenseusage")
db_alias = schema_editor.connection.alias
for usage in LicenseUsage.objects.using(db_alias).all():
usage.status = "valid" if usage.within_limits else "limit_exceeded_admin"
usage.save()
class Migration(migrations.Migration):
dependencies = [
("authentik_enterprise", "0002_rename_users_license_internal_users_and_more"),
]
operations = [
migrations.AddField(
model_name="licenseusage",
name="status",
field=models.TextField(
choices=[
("unlicensed", "Unlicensed"),
("valid", "Valid"),
("expired", "Expired"),
("expiry_soon", "Expiry Soon"),
("limit_exceeded_admin", "Limit Exceeded Admin"),
("limit_exceeded_user", "Limit Exceeded User"),
("read_only", "Read Only"),
],
default=None,
null=True,
),
preserve_default=False,
),
migrations.RunPython(migrate_license_usage),
migrations.RemoveField(
model_name="licenseusage",
name="within_limits",
),
migrations.AlterField(
model_name="licenseusage",
name="status",
field=models.TextField(
choices=[
("unlicensed", "Unlicensed"),
("valid", "Valid"),
("expired", "Expired"),
("expiry_soon", "Expiry Soon"),
("limit_exceeded_admin", "Limit Exceeded Admin"),
("limit_exceeded_user", "Limit Exceeded User"),
("read_only", "Read Only"),
],
),
preserve_default=False,
),
migrations.RenameField(
model_name="licenseusage",
old_name="user_count",
new_name="internal_user_count",
),
]

View File

@ -17,17 +17,6 @@ if TYPE_CHECKING:
from authentik.enterprise.license import LicenseKey
def usage_expiry():
"""Keep license usage records for 3 months"""
return now() + timedelta(days=30 * 3)
THRESHOLD_WARNING_ADMIN_WEEKS = 2
THRESHOLD_WARNING_USER_WEEKS = 4
THRESHOLD_WARNING_EXPIRY_WEEKS = 2
THRESHOLD_READ_ONLY_WEEKS = 6
class License(SerializerModel):
"""An authentik enterprise license"""
@ -50,7 +39,7 @@ class License(SerializerModel):
"""Get parsed license status"""
from authentik.enterprise.license import LicenseKey
return LicenseKey.validate(self.key, check_expiry=False)
return LicenseKey.validate(self.key)
class Meta:
indexes = (HashIndex(fields=("key",)),)
@ -58,23 +47,9 @@ class License(SerializerModel):
verbose_name_plural = _("Licenses")
class LicenseUsageStatus(models.TextChoices):
"""License states an instance/tenant can be in"""
UNLICENSED = "unlicensed"
VALID = "valid"
EXPIRED = "expired"
EXPIRY_SOON = "expiry_soon"
# User limit exceeded, 2 week threshold, show message in admin interface
LIMIT_EXCEEDED_ADMIN = "limit_exceeded_admin"
# User limit exceeded, 4 week threshold, show message in user interface
LIMIT_EXCEEDED_USER = "limit_exceeded_user"
READ_ONLY = "read_only"
@property
def is_valid(self) -> bool:
"""Quickly check if a license is valid"""
return self in [LicenseUsageStatus.VALID, LicenseUsageStatus.EXPIRY_SOON]
def usage_expiry():
"""Keep license usage records for 3 months"""
return now() + timedelta(days=30 * 3)
class LicenseUsage(ExpiringModel):
@ -84,9 +59,9 @@ class LicenseUsage(ExpiringModel):
usage_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
internal_user_count = models.BigIntegerField()
user_count = models.BigIntegerField()
external_user_count = models.BigIntegerField()
status = models.TextField(choices=LicenseUsageStatus.choices)
within_limits = models.BooleanField()
record_date = models.DateTimeField(auto_now_add=True)

View File

@ -13,7 +13,7 @@ class EnterprisePolicyAccessView(PolicyAccessView):
def check_license(self):
"""Check license"""
if not LicenseKey.get_total().status().is_valid:
if not LicenseKey.get_total().is_valid():
return PolicyResult(False, _("Enterprise required to access this feature."))
if self.request.user.type != UserTypes.INTERNAL:
return PolicyResult(False, _("Feature only accessible for internal users."))

View File

@ -6,10 +6,7 @@ from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.enterprise.api import EnterpriseRequiredMixin
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
from authentik.enterprise.providers.google_workspace.tasks import (
google_workspace_sync,
google_workspace_sync_objects,
)
from authentik.enterprise.providers.google_workspace.tasks import google_workspace_sync
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
@ -55,4 +52,3 @@ class GoogleWorkspaceProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixi
search_fields = ["name"]
ordering = ["name"]
sync_single_task = google_workspace_sync
sync_objects_task = google_workspace_sync_objects

View File

@ -181,7 +181,7 @@ class GoogleWorkspaceProviderMapping(PropertyMapping):
@property
def component(self) -> str:
return "ak-property-mapping-provider-google-workspace-form"
return "ak-property-mapping-google-workspace-form"
@property
def serializer(self) -> type[Serializer]:

View File

@ -6,10 +6,7 @@ from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.enterprise.api import EnterpriseRequiredMixin
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider
from authentik.enterprise.providers.microsoft_entra.tasks import (
microsoft_entra_sync,
microsoft_entra_sync_objects,
)
from authentik.enterprise.providers.microsoft_entra.tasks import microsoft_entra_sync
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
@ -53,4 +50,3 @@ class MicrosoftEntraProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin
search_fields = ["name"]
ordering = ["name"]
sync_single_task = microsoft_entra_sync
sync_objects_task = microsoft_entra_sync_objects

View File

@ -170,7 +170,7 @@ class MicrosoftEntraProviderMapping(PropertyMapping):
@property
def component(self) -> str:
return "ak-property-mapping-provider-microsoft-entra-form"
return "ak-property-mapping-microsoft-entra-form"
@property
def serializer(self) -> type[Serializer]:

View File

@ -1,20 +0,0 @@
# Generated by Django 5.0.8 on 2024-08-12 12:54
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("authentik_providers_rac", "0004_alter_connectiontoken_expires"),
]
operations = [
migrations.AlterModelOptions(
name="racpropertymapping",
options={
"verbose_name": "RAC Provider Property Mapping",
"verbose_name_plural": "RAC Provider Property Mappings",
},
),
]

View File

@ -125,7 +125,7 @@ class RACPropertyMapping(PropertyMapping):
@property
def component(self) -> str:
return "ak-property-mapping-provider-rac-form"
return "ak-property-mapping-rac-form"
@property
def serializer(self) -> type[Serializer]:
@ -136,8 +136,8 @@ class RACPropertyMapping(PropertyMapping):
return RACPropertyMappingSerializer
class Meta:
verbose_name = _("RAC Provider Property Mapping")
verbose_name_plural = _("RAC Provider Property Mappings")
verbose_name = _("RAC Property Mapping")
verbose_name_plural = _("RAC Property Mappings")
class ConnectionToken(ExpiringModel):

View File

@ -44,7 +44,7 @@ websocket_urlpatterns = [
api_urlpatterns = [
("providers/rac", RACProviderViewSet),
("propertymappings/provider/rac", RACPropertyMappingViewSet),
("propertymappings/rac", RACPropertyMappingViewSet),
("rac/endpoints", EndpointViewSet),
("rac/connection_tokens", ConnectionTokenViewSet),
]

View File

@ -3,7 +3,7 @@
from datetime import datetime
from django.core.cache import cache
from django.db.models.signals import post_delete, post_save, pre_save
from django.db.models.signals import post_save, pre_save
from django.dispatch import receiver
from django.utils.timezone import get_current_timezone
@ -27,9 +27,3 @@ def post_save_license(sender: type[License], instance: License, **_):
"""Trigger license usage calculation when license is saved"""
cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)
enterprise_update_usage.delay()
@receiver(post_delete, sender=License)
def post_delete_license(sender: type[License], instance: License, **_):
"""Clear license cache when license is deleted"""
cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)

View File

@ -9,26 +9,10 @@ from django.utils.timezone import now
from rest_framework.exceptions import ValidationError
from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import (
THRESHOLD_READ_ONLY_WEEKS,
THRESHOLD_WARNING_ADMIN_WEEKS,
THRESHOLD_WARNING_USER_WEEKS,
License,
LicenseUsage,
LicenseUsageStatus,
)
from authentik.enterprise.models import License
from authentik.lib.generators import generate_id
# Valid license expiry
expiry_valid = int(mktime((now() + timedelta(days=3000)).timetuple()))
# Valid license expiry, expires soon
expiry_soon = int(mktime((now() + timedelta(hours=10)).timetuple()))
# Invalid license expiry, recently expired
expiry_expired = int(mktime((now() - timedelta(hours=10)).timetuple()))
# Invalid license expiry, expired longer ago
expiry_expired_read_only = int(
mktime((now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)).timetuple())
)
_exp = int(mktime((now() + timedelta(days=3000)).timetuple()))
class TestEnterpriseLicense(TestCase):
@ -39,7 +23,7 @@ class TestEnterpriseLicense(TestCase):
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
exp=_exp,
name=generate_id(),
internal_users=100,
external_users=100,
@ -49,7 +33,7 @@ class TestEnterpriseLicense(TestCase):
def test_valid(self):
"""Check license verification"""
lic = License.objects.create(key=generate_id())
self.assertTrue(lic.status.status().is_valid)
self.assertTrue(lic.status.is_valid())
self.assertEqual(lic.internal_users, 100)
def test_invalid(self):
@ -62,7 +46,7 @@ class TestEnterpriseLicense(TestCase):
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
exp=_exp,
name=generate_id(),
internal_users=100,
external_users=100,
@ -72,186 +56,11 @@ class TestEnterpriseLicense(TestCase):
def test_valid_multiple(self):
"""Check license verification"""
lic = License.objects.create(key=generate_id())
self.assertTrue(lic.status.status().is_valid)
self.assertTrue(lic.status.is_valid())
lic2 = License.objects.create(key=generate_id())
self.assertTrue(lic2.status.status().is_valid)
self.assertTrue(lic2.status.is_valid())
total = LicenseKey.get_total()
self.assertEqual(total.internal_users, 200)
self.assertEqual(total.external_users, 200)
self.assertEqual(total.exp, expiry_valid)
self.assertTrue(total.status().is_valid)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_limit_exceeded_read_only(self):
"""Check license verification"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)
usage.save(update_fields=["record_date"])
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.READ_ONLY)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_limit_exceeded_user_warning(self):
"""Check license verification"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_WARNING_USER_WEEKS + 1)
usage.save(update_fields=["record_date"])
self.assertEqual(
LicenseKey.get_total().summary().status, LicenseUsageStatus.LIMIT_EXCEEDED_USER
)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_limit_exceeded_admin_warning(self):
"""Check license verification"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_WARNING_ADMIN_WEEKS + 1)
usage.save(update_fields=["record_date"])
self.assertEqual(
LicenseKey.get_total().summary().status, LicenseUsageStatus.LIMIT_EXCEEDED_ADMIN
)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_expired_read_only,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_expiry_read_only(self):
"""Check license verification"""
License.objects.create(key=generate_id())
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.READ_ONLY)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_expired,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_expiry_expired(self):
"""Check license verification"""
License.objects.create(key=generate_id())
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.EXPIRED)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_soon,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_expiry_soon(self):
"""Check license verification"""
License.objects.create(key=generate_id())
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.EXPIRY_SOON)
self.assertEqual(total.exp, _exp)
self.assertTrue(total.is_valid())

View File

@ -1,217 +0,0 @@
"""read only tests"""
from datetime import timedelta
from unittest.mock import MagicMock, patch
from django.urls import reverse
from django.utils.timezone import now
from authentik.core.tests.utils import create_test_admin_user, create_test_flow, create_test_user
from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import (
THRESHOLD_READ_ONLY_WEEKS,
License,
LicenseUsage,
LicenseUsageStatus,
)
from authentik.enterprise.tests.test_license import expiry_valid
from authentik.flows.models import (
FlowDesignation,
FlowStageBinding,
)
from authentik.flows.tests import FlowTestCase
from authentik.lib.generators import generate_id
from authentik.stages.identification.models import IdentificationStage, UserFields
from authentik.stages.password import BACKEND_INBUILT
from authentik.stages.password.models import PasswordStage
from authentik.stages.user_login.models import UserLoginStage
class TestReadOnly(FlowTestCase):
"""Test read_only"""
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_login(self):
"""Test flow, ensure login is still possible with read only mode"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)
usage.save(update_fields=["record_date"])
flow = create_test_flow(
FlowDesignation.AUTHENTICATION,
)
ident_stage = IdentificationStage.objects.create(
name=generate_id(),
user_fields=[UserFields.E_MAIL],
pretend_user_exists=False,
)
FlowStageBinding.objects.create(
target=flow,
stage=ident_stage,
order=0,
)
password_stage = PasswordStage.objects.create(
name=generate_id(), backends=[BACKEND_INBUILT]
)
FlowStageBinding.objects.create(
target=flow,
stage=password_stage,
order=1,
)
login_stage = UserLoginStage.objects.create(
name=generate_id(),
)
FlowStageBinding.objects.create(
target=flow,
stage=login_stage,
order=2,
)
user = create_test_user()
exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
response = self.client.get(exec_url)
self.assertStageResponse(
response,
flow,
component="ak-stage-identification",
password_fields=False,
primary_action="Log in",
sources=[],
show_source_labels=False,
user_fields=[UserFields.E_MAIL],
)
response = self.client.post(exec_url, {"uid_field": user.email}, follow=True)
self.assertStageResponse(response, flow, component="ak-stage-password")
response = self.client.post(exec_url, {"password": user.username}, follow=True)
self.assertStageRedirects(response, reverse("authentik_core:root-redirect"))
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_manage_licenses(self):
"""Test that managing licenses is still possible"""
license = License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)
usage.save(update_fields=["record_date"])
admin = create_test_admin_user()
self.client.force_login(admin)
# Reading is always allowed
response = self.client.get(reverse("authentik_api:license-list"))
self.assertEqual(response.status_code, 200)
# Writing should also be allowed
response = self.client.patch(
reverse("authentik_api:license-detail", kwargs={"pk": license.pk})
)
self.assertEqual(response.status_code, 200)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_manage_flows(self):
"""Test flow"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)
usage.save(update_fields=["record_date"])
admin = create_test_admin_user()
self.client.force_login(admin)
# Read only is still allowed
response = self.client.get(reverse("authentik_api:flow-list"))
self.assertEqual(response.status_code, 200)
flow = create_test_flow()
# Writing is not
response = self.client.patch(
reverse("authentik_api:flow-detail", kwargs={"slug": flow.slug})
)
self.assertJSONEqual(
response.content,
{"detail": "Request denied due to expired/invalid license.", "code": "denied_license"},
)
self.assertEqual(response.status_code, 400)

View File

@ -69,5 +69,8 @@ class NotificationViewSet(
@action(detail=False, methods=["post"])
def mark_all_seen(self, request: Request) -> Response:
"""Mark all the user's notifications as seen"""
Notification.objects.filter(user=request.user, seen=False).update(seen=True)
notifications = Notification.objects.filter(user=request.user)
for notification in notifications:
notification.seen = True
Notification.objects.bulk_update(notifications, ["seen"])
return Response({}, status=204)

View File

@ -49,7 +49,6 @@ from authentik.policies.models import PolicyBindingModel
from authentik.root.middleware import ClientIPMiddleware
from authentik.stages.email.utils import TemplateEmailMessage
from authentik.tenants.models import Tenant
from authentik.tenants.utils import get_current_tenant
LOGGER = get_logger()
DISCORD_FIELD_LIMIT = 25
@ -59,11 +58,7 @@ NOTIFICATION_SUMMARY_LENGTH = 75
def default_event_duration():
"""Default duration an Event is saved.
This is used as a fallback when no brand is available"""
try:
tenant = get_current_tenant()
return now() + timedelta_from_string(tenant.event_retention)
except Tenant.DoesNotExist:
return now() + timedelta(days=365)
return now() + timedelta(days=365)
def default_brand():
@ -250,6 +245,12 @@ class Event(SerializerModel, ExpiringModel):
if QS_QUERY in self.context["http_request"]["args"]:
wrapped = self.context["http_request"]["args"][QS_QUERY]
self.context["http_request"]["args"] = cleanse_dict(QueryDict(wrapped))
if hasattr(request, "tenant"):
tenant: Tenant = request.tenant
# Because self.created only gets set on save, we can't use it's value here
# hence we set self.created to now and then use it
self.created = now()
self.expires = self.created + timedelta_from_string(tenant.event_retention)
if hasattr(request, "brand"):
brand: Brand = request.brand
self.brand = sanitize_dict(model_to_dict(brand))

View File

@ -6,7 +6,6 @@ from django.db.models import Model
from django.test import TestCase
from authentik.core.models import default_token_key
from authentik.events.models import default_event_duration
from authentik.lib.utils.reflection import get_apps
@ -21,7 +20,7 @@ def model_tester_factory(test_model: type[Model]) -> Callable:
allowed = 0
# Token-like objects need to lookup the current tenant to get the default token length
for field in test_model._meta.fields:
if field.default in [default_token_key, default_event_duration]:
if field.default == default_token_key:
allowed += 1
with self.assertNumQueries(allowed):
str(test_model())

View File

@ -2,8 +2,7 @@
from unittest.mock import MagicMock, patch
from django.urls import reverse
from rest_framework.test import APITestCase
from django.test import TestCase
from authentik.core.models import Group, User
from authentik.events.models import (
@ -11,7 +10,6 @@ from authentik.events.models import (
EventAction,
Notification,
NotificationRule,
NotificationSeverity,
NotificationTransport,
NotificationWebhookMapping,
TransportMode,
@ -22,7 +20,7 @@ from authentik.policies.exceptions import PolicyException
from authentik.policies.models import PolicyBinding
class TestEventsNotifications(APITestCase):
class TestEventsNotifications(TestCase):
"""Test Event Notifications"""
def setUp(self) -> None:
@ -133,15 +131,3 @@ class TestEventsNotifications(APITestCase):
Notification.objects.all().delete()
Event.new(EventAction.CUSTOM_PREFIX).save()
self.assertEqual(Notification.objects.first().body, "foo")
def test_api_mark_all_seen(self):
"""Test mark_all_seen"""
self.client.force_login(self.user)
Notification.objects.create(
severity=NotificationSeverity.NOTICE, body="foo", user=self.user, seen=False
)
response = self.client.post(reverse("authentik_api:notification-mark-all-seen"))
self.assertEqual(response.status_code, 204)
self.assertFalse(Notification.objects.filter(body="foo", seen=False).exists())

View File

@ -37,7 +37,6 @@ from authentik.lib.utils.file import (
)
from authentik.lib.views import bad_request_message
from authentik.rbac.decorators import permission_required
from authentik.rbac.filters import ObjectFilter
LOGGER = get_logger()
@ -282,7 +281,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet):
400: OpenApiResponse(description="Flow not applicable"),
},
)
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter])
@action(detail=True, pagination_class=None, filter_backends=[])
def execute(self, request: Request, slug: str):
"""Execute flow for current user"""
# Because we pre-plan the flow here, and not in the planner, we need to manually clear

View File

@ -2,6 +2,7 @@
import re
import socket
from collections.abc import Iterable
from ipaddress import ip_address, ip_network
from textwrap import indent
from types import CodeType
@ -27,12 +28,6 @@ from authentik.stages.authenticator import devices_for_user
LOGGER = get_logger()
ARG_SANITIZE = re.compile(r"[:.-]")
def sanitize_arg(arg_name: str) -> str:
return re.sub(ARG_SANITIZE, "_", arg_name)
class BaseEvaluator:
"""Validate and evaluate python-based expressions"""
@ -182,9 +177,9 @@ class BaseEvaluator:
proc = PolicyProcess(PolicyBinding(policy=policy), request=req, connection=None)
return proc.profiling_wrapper()
def wrap_expression(self, expression: str) -> str:
def wrap_expression(self, expression: str, params: Iterable[str]) -> str:
"""Wrap expression in a function, call it, and save the result as `result`"""
handler_signature = ",".join(sanitize_arg(x) for x in self._context.keys())
handler_signature = ",".join(params)
full_expression = ""
full_expression += f"def handler({handler_signature}):\n"
full_expression += indent(expression, " ")
@ -193,8 +188,8 @@ class BaseEvaluator:
def compile(self, expression: str) -> CodeType:
"""Parse expression. Raises SyntaxError or ValueError if the syntax is incorrect."""
expression = self.wrap_expression(expression)
return compile(expression, self._filename, "exec")
param_keys = self._context.keys()
return compile(self.wrap_expression(expression, param_keys), self._filename, "exec")
def evaluate(self, expression_source: str) -> Any:
"""Parse and evaluate expression. If the syntax is incorrect, a SyntaxError is raised.
@ -210,7 +205,7 @@ class BaseEvaluator:
self.handle_error(exc, expression_source)
raise exc
try:
_locals = {sanitize_arg(x): y for x, y in self._context.items()}
_locals = self._context
# Yes this is an exec, yes it is potentially bad. Since we limit what variables are
# available here, and these policies can only be edited by admins, this is a risk
# we're willing to take.

View File

@ -1,19 +1,16 @@
from celery import Task
from collections.abc import Callable
from django.utils.text import slugify
from drf_spectacular.utils import OpenApiResponse, extend_schema
from guardian.shortcuts import get_objects_for_user
from rest_framework.decorators import action
from rest_framework.fields import BooleanField, CharField, ChoiceField
from rest_framework.fields import BooleanField
from rest_framework.request import Request
from rest_framework.response import Response
from authentik.core.api.utils import ModelSerializer, PassiveSerializer
from authentik.core.models import Group, User
from authentik.events.api.tasks import SystemTaskSerializer
from authentik.events.logs import LogEvent, LogEventSerializer
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
from authentik.lib.utils.reflection import class_to_path
from authentik.rbac.filters import ObjectFilter
class SyncStatusSerializer(PassiveSerializer):
@ -23,29 +20,10 @@ class SyncStatusSerializer(PassiveSerializer):
tasks = SystemTaskSerializer(many=True, read_only=True)
class SyncObjectSerializer(PassiveSerializer):
"""Sync object serializer"""
sync_object_model = ChoiceField(
choices=(
(class_to_path(User), "user"),
(class_to_path(Group), "group"),
)
)
sync_object_id = CharField()
class SyncObjectResultSerializer(PassiveSerializer):
"""Result of a single object sync"""
messages = LogEventSerializer(many=True, read_only=True)
class OutgoingSyncProviderStatusMixin:
"""Common API Endpoints for Outgoing sync providers"""
sync_single_task: type[Task] = None
sync_objects_task: type[Task] = None
sync_single_task: Callable = None
@extend_schema(
responses={
@ -58,7 +36,7 @@ class OutgoingSyncProviderStatusMixin:
detail=True,
pagination_class=None,
url_path="sync/status",
filter_backends=[ObjectFilter],
filter_backends=[],
)
def sync_status(self, request: Request, pk: int) -> Response:
"""Get provider's sync status"""
@ -77,30 +55,6 @@ class OutgoingSyncProviderStatusMixin:
}
return Response(SyncStatusSerializer(status).data)
@extend_schema(
request=SyncObjectSerializer,
responses={200: SyncObjectResultSerializer()},
)
@action(
methods=["POST"],
detail=True,
pagination_class=None,
url_path="sync/object",
filter_backends=[ObjectFilter],
)
def sync_object(self, request: Request, pk: int) -> Response:
"""Sync/Re-sync a single user/group object"""
provider: OutgoingSyncProvider = self.get_object()
params = SyncObjectSerializer(data=request.data)
params.is_valid(raise_exception=True)
res: list[LogEvent] = self.sync_objects_task.delay(
params.validated_data["sync_object_model"],
page=1,
provider_pk=provider.pk,
pk=params.validated_data["sync_object_id"],
).get()
return Response(SyncObjectResultSerializer(instance={"messages": res}).data)
class OutgoingSyncConnectionCreateMixin:
"""Mixin for connection objects that fetches remote data upon creation"""

View File

@ -105,7 +105,7 @@ class SyncTasks:
return
task.set_status(TaskStatus.SUCCESSFUL, *messages)
def sync_objects(self, object_type: str, page: int, provider_pk: int, **filter):
def sync_objects(self, object_type: str, page: int, provider_pk: int):
_object_type = path_to_class(object_type)
self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model),
@ -120,7 +120,7 @@ class SyncTasks:
client = provider.client_for_model(_object_type)
except TransientSyncException:
return messages
paginator = Paginator(provider.get_object_qs(_object_type).filter(**filter), PAGE_SIZE)
paginator = Paginator(provider.get_object_qs(_object_type), PAGE_SIZE)
if client.can_discover:
self.logger.debug("starting discover")
client.discover()

View File

@ -30,11 +30,6 @@ class TestHTTP(TestCase):
request = self.factory.get("/", HTTP_X_FORWARDED_FOR="127.0.0.2")
self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.2")
def test_forward_for_invalid(self):
"""Test invalid forward for"""
request = self.factory.get("/", HTTP_X_FORWARDED_FOR="foobar")
self.assertEqual(ClientIPMiddleware.get_client_ip(request), ClientIPMiddleware.default_ip)
def test_fake_outpost(self):
"""Test faked IP which is overridden by an outpost"""
token = Token.objects.create(
@ -58,17 +53,6 @@ class TestHTTP(TestCase):
},
)
self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.1")
# Invalid, not a real IP
self.user.type = UserTypes.INTERNAL_SERVICE_ACCOUNT
self.user.save()
request = self.factory.get(
"/",
**{
ClientIPMiddleware.outpost_remote_ip_header: "foobar",
ClientIPMiddleware.outpost_token_header: token.key,
},
)
self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.1")
# Valid
self.user.type = UserTypes.INTERNAL_SERVICE_ACCOUNT
self.user.save()

View File

@ -21,14 +21,7 @@ class DebugSession(Session):
def send(self, req: PreparedRequest, *args, **kwargs):
request_id = str(uuid4())
LOGGER.debug(
"HTTP request sent",
uid=request_id,
url=req.url,
method=req.method,
headers=req.headers,
body=req.body,
)
LOGGER.debug("HTTP request sent", uid=request_id, path=req.path_url, headers=req.headers)
resp = super().send(req, *args, **kwargs)
LOGGER.debug(
"HTTP response received",

View File

@ -26,6 +26,7 @@ from authentik.outposts.apps import MANAGED_OUTPOST, MANAGED_OUTPOST_NAME
from authentik.outposts.models import (
Outpost,
OutpostConfig,
OutpostState,
OutpostType,
default_outpost_config,
)
@ -139,7 +140,7 @@ class OutpostHealthSerializer(PassiveSerializer):
def get_fips_enabled(self, obj: dict) -> bool | None:
"""Get FIPS enabled"""
if not LicenseKey.get_total().status().is_valid:
if not LicenseKey.get_total().is_valid():
return None
return obj["fips_enabled"]
@ -181,6 +182,7 @@ class OutpostViewSet(UsedByMixin, ModelViewSet):
outpost: Outpost = self.get_object()
states = []
for state in outpost.state:
state: OutpostState
states.append(
{
"uid": state.uid,

View File

@ -26,7 +26,6 @@ from authentik.outposts.models import (
KubernetesServiceConnection,
OutpostServiceConnection,
)
from authentik.rbac.filters import ObjectFilter
class ServiceConnectionSerializer(ModelSerializer, MetaNameSerializer):
@ -76,7 +75,7 @@ class ServiceConnectionViewSet(
filterset_fields = ["name"]
@extend_schema(responses={200: ServiceConnectionStateSerializer(many=False)})
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter])
@action(detail=True, pagination_class=None, filter_backends=[])
def state(self, request: Request, pk: str) -> Response:
"""Get the service connection's state"""
connection = self.get_object()

View File

@ -451,7 +451,7 @@ class OutpostState:
return False
if self.build_hash != get_build_hash():
return False
return parse(self.version) != OUR_VERSION
return parse(self.version) < OUR_VERSION
@staticmethod
def for_outpost(outpost: Outpost) -> list["OutpostState"]:

View File

@ -214,7 +214,7 @@ def outpost_post_save(model_class: str, model_pk: Any):
if not hasattr(instance, field_name):
continue
LOGGER.debug("triggering outpost update from field", field=field.name)
LOGGER.debug("triggering outpost update from from field", field=field.name)
# Because the Outpost Model has an M2M to Provider,
# we have to iterate over the entire QS
for reverse in getattr(instance, field_name).all():

View File

@ -108,7 +108,7 @@ class EventMatcherPolicy(Policy):
result=result,
)
matches.append(result)
passing = all(x.passing for x in matches)
passing = any(x.passing for x in matches)
messages = chain(*[x.messages for x in matches])
result = PolicyResult(passing, *messages)
result.source_results = matches

View File

@ -77,24 +77,11 @@ class TestEventMatcherPolicy(TestCase):
request = PolicyRequest(get_anonymous_user())
request.context["event"] = event
policy: EventMatcherPolicy = EventMatcherPolicy.objects.create(
client_ip="1.2.3.5", app="foo"
client_ip="1.2.3.5", app="bar"
)
response = policy.passes(request)
self.assertFalse(response.passing)
def test_multiple(self):
"""Test multiple"""
event = Event.new(EventAction.LOGIN)
event.app = "foo"
event.client_ip = "1.2.3.4"
request = PolicyRequest(get_anonymous_user())
request.context["event"] = event
policy: EventMatcherPolicy = EventMatcherPolicy.objects.create(
client_ip="1.2.3.4", app="foo"
)
response = policy.passes(request)
self.assertTrue(response.passing)
def test_invalid(self):
"""Test passing event"""
request = PolicyRequest(get_anonymous_user())

View File

@ -36,7 +36,7 @@ def update_score(request: HttpRequest, identifier: str, amount: int):
if not created:
reputation.score = F("score") + amount
reputation.save()
LOGGER.info("Updated score", amount=amount, for_user=identifier, for_ip=remote_ip)
LOGGER.debug("Updated score", amount=amount, for_user=identifier, for_ip=remote_ip)
@receiver(login_failed)

View File

@ -2,25 +2,15 @@
from django.db.models import QuerySet
from django.db.models.query import Q
from django.shortcuts import get_object_or_404
from django_filters.filters import BooleanFilter
from django_filters.filterset import FilterSet
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema
from rest_framework.decorators import action
from rest_framework.fields import BooleanField, CharField, ListField, SerializerMethodField
from rest_framework.fields import CharField, ListField, SerializerMethodField
from rest_framework.mixins import ListModelMixin
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import GenericViewSet, ModelViewSet
from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import ModelSerializer, PassiveSerializer
from authentik.core.models import Application
from authentik.policies.api.exec import PolicyTestResultSerializer
from authentik.policies.engine import PolicyEngine
from authentik.policies.types import PolicyResult
from authentik.core.api.utils import ModelSerializer
from authentik.providers.ldap.models import LDAPProvider
@ -33,6 +23,7 @@ class LDAPProviderSerializer(ProviderSerializer):
model = LDAPProvider
fields = ProviderSerializer.Meta.fields + [
"base_dn",
"search_group",
"certificate",
"tls_server_name",
"uid_start_number",
@ -64,6 +55,8 @@ class LDAPProviderFilter(FilterSet):
"name": ["iexact"],
"authorization_flow__slug": ["iexact"],
"base_dn": ["iexact"],
"search_group__group_uuid": ["iexact"],
"search_group__name": ["iexact"],
"certificate__kp_uuid": ["iexact"],
"certificate__name": ["iexact"],
"tls_server_name": ["iexact"],
@ -102,6 +95,7 @@ class LDAPOutpostConfigSerializer(ModelSerializer):
"base_dn",
"bind_flow_slug",
"application_slug",
"search_group",
"certificate",
"tls_server_name",
"uid_start_number",
@ -122,33 +116,3 @@ class LDAPOutpostConfigViewSet(ListModelMixin, GenericViewSet):
ordering = ["name"]
search_fields = ["name"]
filterset_fields = ["name"]
class LDAPCheckAccessSerializer(PassiveSerializer):
has_search_permission = BooleanField(required=False)
access = PolicyTestResultSerializer()
@extend_schema(
request=None,
parameters=[OpenApiParameter("app_slug", OpenApiTypes.STR)],
responses={
200: LDAPCheckAccessSerializer(),
},
operation_id="outposts_ldap_access_check",
)
@action(detail=True)
def check_access(self, request: Request, pk) -> Response:
"""Check access to a single application by slug"""
provider = get_object_or_404(LDAPProvider, pk=pk)
application = get_object_or_404(Application, slug=request.query_params["app_slug"])
engine = PolicyEngine(application, request.user, request)
engine.use_cache = False
engine.build()
result = engine.result
access_response = PolicyResult(result.passing)
response = self.LDAPCheckAccessSerializer(
instance={
"has_search_permission": request.user.has_perm("search_full_directory", provider),
"access": access_response,
}
)
return Response(response.data)

View File

@ -1,66 +0,0 @@
# Generated by Django 5.0.7 on 2024-07-25 14:59
from django.apps.registry import Apps
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db import migrations
def migrate_search_group(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
from authentik.core.models import User
from django.apps import apps as real_apps
from django.contrib.auth.management import create_permissions
from guardian.shortcuts import UserObjectPermission
db_alias = schema_editor.connection.alias
# Permissions are only created _after_ migrations are run
# - https://github.com/django/django/blob/43cdfa8b20e567a801b7d0a09ec67ddd062d5ea4/django/contrib/auth/apps.py#L19
# - https://stackoverflow.com/a/72029063/1870445
create_permissions(real_apps.get_app_config("authentik_providers_ldap"), using=db_alias)
LDAPProvider = apps.get_model("authentik_providers_ldap", "ldapprovider")
Permission = apps.get_model("auth", "Permission")
UserObjectPermission = apps.get_model("guardian", "UserObjectPermission")
ContentType = apps.get_model("contenttypes", "ContentType")
new_prem = Permission.objects.using(db_alias).get(codename="search_full_directory")
ct = ContentType.objects.using(db_alias).get(
app_label="authentik_providers_ldap",
model="ldapprovider",
)
for provider in LDAPProvider.objects.using(db_alias).all():
if not provider.search_group:
continue
for user in provider.search_group.users.using(db_alias).all():
UserObjectPermission.objects.using(db_alias).create(
user=user,
permission=new_prem,
object_pk=provider.pk,
content_type=ct,
)
class Migration(migrations.Migration):
dependencies = [
("authentik_providers_ldap", "0003_ldapprovider_mfa_support_and_more"),
("guardian", "0002_generic_permissions_index"),
]
operations = [
migrations.AlterModelOptions(
name="ldapprovider",
options={
"permissions": [("search_full_directory", "Search full LDAP directory")],
"verbose_name": "LDAP Provider",
"verbose_name_plural": "LDAP Providers",
},
),
migrations.RunPython(migrate_search_group),
migrations.RemoveField(
model_name="ldapprovider",
name="search_group",
),
]

View File

@ -7,7 +7,7 @@ from django.templatetags.static import static
from django.utils.translation import gettext_lazy as _
from rest_framework.serializers import Serializer
from authentik.core.models import BackchannelProvider
from authentik.core.models import BackchannelProvider, Group
from authentik.crypto.models import CertificateKeyPair
from authentik.outposts.models import OutpostModel
@ -27,6 +27,17 @@ class LDAPProvider(OutpostModel, BackchannelProvider):
help_text=_("DN under which objects are accessible."),
)
search_group = models.ForeignKey(
Group,
null=True,
default=None,
on_delete=models.SET_DEFAULT,
help_text=_(
"Users in this group can do search queries. "
"If not set, every user can execute search queries."
),
)
tls_server_name = models.TextField(
default="",
blank=True,
@ -102,6 +113,3 @@ class LDAPProvider(OutpostModel, BackchannelProvider):
class Meta:
verbose_name = _("LDAP Provider")
verbose_name_plural = _("LDAP Providers")
permissions = [
("search_full_directory", _("Search full LDAP directory")),
]

View File

@ -105,7 +105,7 @@ class ScopeMapping(PropertyMapping):
@property
def component(self) -> str:
return "ak-property-mapping-provider-scope-form"
return "ak-property-mapping-scope-form"
@property
def serializer(self) -> type[Serializer]:

View File

@ -29,6 +29,7 @@ class TesOAuth2Introspection(OAuthTestCase):
self.app = Application.objects.create(
name=generate_id(), slug=generate_id(), provider=self.provider
)
self.app.save()
self.user = create_test_admin_user()
self.auth = b64encode(
f"{self.provider.client_id}:{self.provider.client_secret}".encode()
@ -113,41 +114,6 @@ class TesOAuth2Introspection(OAuthTestCase):
},
)
def test_introspect_invalid_provider(self):
"""Test introspection (mismatched provider and token)"""
provider: OAuth2Provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
redirect_uris="",
signing_key=create_test_cert(),
)
auth = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
token: AccessToken = AccessToken.objects.create(
provider=self.provider,
user=self.user,
token=generate_id(),
auth_time=timezone.now(),
_scope="openid user profile",
_id_token=json.dumps(
asdict(
IDToken("foo", "bar"),
)
),
)
res = self.client.post(
reverse("authentik_providers_oauth2:token-introspection"),
HTTP_AUTHORIZATION=f"Basic {auth}",
data={"token": token.token},
)
self.assertEqual(res.status_code, 200)
self.assertJSONEqual(
res.content.decode(),
{
"active": False,
},
)
def test_introspect_invalid_auth(self):
"""Test introspect (invalid auth)"""
res = self.client.post(

View File

@ -62,7 +62,7 @@ urlpatterns = [
api_urlpatterns = [
("providers/oauth2", OAuth2ProviderViewSet),
("propertymappings/provider/scope", ScopeMappingViewSet),
("propertymappings/scope", ScopeMappingViewSet),
("oauth2/authorization_codes", AuthorizationCodeViewSet),
("oauth2/refresh_tokens", RefreshTokenViewSet),
("oauth2/access_tokens", AccessTokenViewSet),

View File

@ -46,10 +46,10 @@ class TokenIntrospectionParams:
if not provider:
raise TokenIntrospectionError
access_token = AccessToken.objects.filter(token=raw_token, provider=provider).first()
access_token = AccessToken.objects.filter(token=raw_token).first()
if access_token:
return TokenIntrospectionParams(access_token, provider)
refresh_token = RefreshToken.objects.filter(token=raw_token, provider=provider).first()
refresh_token = RefreshToken.objects.filter(token=raw_token).first()
if refresh_token:
return TokenIntrospectionParams(refresh_token, provider)
LOGGER.debug("Token does not exist", token=raw_token)

View File

@ -433,20 +433,20 @@ class TokenParams:
app = Application.objects.filter(provider=self.provider).first()
if not app or not app.provider:
raise TokenError("invalid_grant")
with audit_ignore():
self.user, _ = User.objects.update_or_create(
# trim username to ensure the entire username is max 150 chars
# (22 chars being the length of the "template")
username=f"ak-{self.provider.name[:150-22]}-client_credentials",
defaults={
"last_login": timezone.now(),
"name": f"Autogenerated user from application {app.name} (client credentials)",
"path": f"{USER_PATH_SYSTEM_PREFIX}/apps/{app.slug}",
"type": UserTypes.SERVICE_ACCOUNT,
self.user, _ = User.objects.update_or_create(
# trim username to ensure the entire username is max 150 chars
# (22 chars being the length of the "template")
username=f"ak-{self.provider.name[:150-22]}-client_credentials",
defaults={
"attributes": {
USER_ATTRIBUTE_GENERATED: True,
},
)
self.user.attributes[USER_ATTRIBUTE_GENERATED] = True
self.user.save()
"last_login": timezone.now(),
"name": f"Autogenerated user from application {app.name} (client credentials)",
"path": f"{USER_PATH_SYSTEM_PREFIX}/apps/{app.slug}",
"type": UserTypes.SERVICE_ACCOUNT,
},
)
self.__check_policy_access(app, request)
Event.new(
@ -470,6 +470,9 @@ class TokenParams:
self.user, created = User.objects.update_or_create(
username=f"{self.provider.name}-{token.get('sub')}",
defaults={
"attributes": {
USER_ATTRIBUTE_GENERATED: True,
},
"last_login": timezone.now(),
"name": (
f"Autogenerated user from application {app.name} (client credentials JWT)"
@ -478,8 +481,6 @@ class TokenParams:
"type": UserTypes.SERVICE_ACCOUNT,
},
)
self.user.attributes[USER_ATTRIBUTE_GENERATED] = True
self.user.save()
exp = token.get("exp")
if created and exp:
self.user.attributes[USER_ATTRIBUTE_EXPIRES] = exp

View File

@ -28,7 +28,7 @@ class ProxyDockerController(DockerController):
labels = super()._get_labels()
labels["traefik.enable"] = "true"
labels[f"traefik.http.routers.{traefik_name}-router.rule"] = (
f"({' || '.join([f'Host({host})' for host in hosts])})"
f"({' || '.join([f'Host(`{host}`)' for host in hosts])})"
f" && PathPrefix(`/outpost.goauthentik.io`)"
)
labels[f"traefik.http.routers.{traefik_name}-router.tls"] = "true"

View File

@ -154,7 +154,6 @@ class RadiusOutpostConfigViewSet(ListModelMixin, GenericViewSet):
responses={
200: RadiusCheckAccessSerializer(),
},
operation_id="outposts_radius_access_check",
)
@action(detail=True)
def check_access(self, request: Request, pk) -> Response:

View File

@ -1,20 +0,0 @@
# Generated by Django 5.0.8 on 2024-08-12 12:54
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("authentik_providers_radius", "0003_radiusproviderpropertymapping"),
]
operations = [
migrations.AlterModelOptions(
name="radiusproviderpropertymapping",
options={
"verbose_name": "Radius Provider Property Mapping",
"verbose_name_plural": "Radius Provider Property Mappings",
},
),
]

View File

@ -70,7 +70,7 @@ class RadiusProviderPropertyMapping(PropertyMapping):
@property
def component(self) -> str:
return "ak-property-mapping-provider-radius-form"
return "ak-property-mapping-radius-form"
@property
def serializer(self) -> type[Serializer]:
@ -81,8 +81,8 @@ class RadiusProviderPropertyMapping(PropertyMapping):
return RadiusProviderPropertyMappingSerializer
def __str__(self):
return f"Radius Provider Property Mapping {self.name}"
return f"Radius Property Mapping {self.name}"
class Meta:
verbose_name = _("Radius Provider Property Mapping")
verbose_name_plural = _("Radius Provider Property Mappings")
verbose_name = _("Radius Property Mapping")
verbose_name_plural = _("Radius Property Mappings")

View File

@ -7,7 +7,7 @@ from authentik.providers.radius.api.providers import (
)
api_urlpatterns = [
("propertymappings/provider/radius", RadiusProviderPropertyMappingViewSet),
("propertymappings/radius", RadiusProviderPropertyMappingViewSet),
("outposts/radius", RadiusOutpostConfigViewSet, "radiusprovideroutpost"),
("providers/radius", RadiusProviderViewSet),
]

View File

@ -133,17 +133,6 @@ class SAMLProviderSerializer(ProviderSerializer):
except Provider.application.RelatedObjectDoesNotExist:
return "-"
def validate(self, attrs: dict):
if attrs.get("signing_kp"):
if not attrs.get("sign_assertion") and not attrs.get("sign_response"):
raise ValidationError(
_(
"With a signing keypair selected, at least one of 'Sign assertion' "
"and 'Sign Response' must be selected."
)
)
return super().validate(attrs)
class Meta:
model = SAMLProvider
fields = ProviderSerializer.Meta.fields + [
@ -159,9 +148,6 @@ class SAMLProviderSerializer(ProviderSerializer):
"signature_algorithm",
"signing_kp",
"verification_kp",
"encryption_kp",
"sign_assertion",
"sign_response",
"sp_binding",
"default_relay_state",
"url_download_metadata",

View File

@ -1,20 +0,0 @@
# Generated by Django 5.0.8 on 2024-08-12 12:54
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("authentik_providers_saml", "0014_alter_samlprovider_digest_algorithm_and_more"),
]
operations = [
migrations.AlterModelOptions(
name="samlpropertymapping",
options={
"verbose_name": "SAML Provider Property Mapping",
"verbose_name_plural": "SAML Provider Property Mappings",
},
),
]

View File

@ -1,39 +0,0 @@
# Generated by Django 5.0.8 on 2024-08-15 14:52
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_crypto", "0004_alter_certificatekeypair_name"),
("authentik_providers_saml", "0015_alter_samlpropertymapping_options"),
]
operations = [
migrations.AddField(
model_name="samlprovider",
name="encryption_kp",
field=models.ForeignKey(
blank=True,
default=None,
help_text="When selected, incoming assertions are encrypted by the IdP using the public key of the encryption keypair. The assertion is decrypted by the SP using the the private key.",
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="+",
to="authentik_crypto.certificatekeypair",
verbose_name="Encryption Keypair",
),
),
migrations.AddField(
model_name="samlprovider",
name="sign_assertion",
field=models.BooleanField(default=True),
),
migrations.AddField(
model_name="samlprovider",
name="sign_response",
field=models.BooleanField(default=False),
),
]

View File

@ -144,28 +144,11 @@ class SAMLProvider(Provider):
on_delete=models.SET_NULL,
verbose_name=_("Signing Keypair"),
)
encryption_kp = models.ForeignKey(
CertificateKeyPair,
default=None,
null=True,
blank=True,
help_text=_(
"When selected, incoming assertions are encrypted by the IdP using the public "
"key of the encryption keypair. The assertion is decrypted by the SP using the "
"the private key."
),
on_delete=models.SET_NULL,
verbose_name=_("Encryption Keypair"),
related_name="+",
)
default_relay_state = models.TextField(
default="", blank=True, help_text=_("Default relay_state value for IDP-initiated logins")
)
sign_assertion = models.BooleanField(default=True)
sign_response = models.BooleanField(default=False)
@property
def launch_url(self) -> str | None:
"""Use IDP-Initiated SAML flow as launch URL"""
@ -208,7 +191,7 @@ class SAMLPropertyMapping(PropertyMapping):
@property
def component(self) -> str:
return "ak-property-mapping-provider-saml-form"
return "ak-property-mapping-saml-form"
@property
def serializer(self) -> type[Serializer]:
@ -221,8 +204,8 @@ class SAMLPropertyMapping(PropertyMapping):
return f"{self.name} ({name})"
class Meta:
verbose_name = _("SAML Provider Property Mapping")
verbose_name_plural = _("SAML Provider Property Mappings")
verbose_name = _("SAML Property Mapping")
verbose_name_plural = _("SAML Property Mappings")
class SAMLProviderImportModel(CreatableType, Provider):

View File

@ -18,11 +18,7 @@ from authentik.providers.saml.processors.authn_request_parser import AuthNReques
from authentik.providers.saml.utils import get_random_id
from authentik.providers.saml.utils.time import get_time_string
from authentik.sources.ldap.auth import LDAP_DISTINGUISHED_NAME
from authentik.sources.saml.exceptions import (
InvalidEncryption,
InvalidSignature,
UnsupportedNameIDFormat,
)
from authentik.sources.saml.exceptions import InvalidSignature, UnsupportedNameIDFormat
from authentik.sources.saml.processors.constants import (
DIGEST_ALGORITHM_TRANSLATION_MAP,
NS_MAP,
@ -50,7 +46,6 @@ class AssertionProcessor:
_issue_instant: str
_assertion_id: str
_response_id: str
_valid_not_before: str
_session_not_on_or_after: str
@ -63,7 +58,6 @@ class AssertionProcessor:
self._issue_instant = get_time_string()
self._assertion_id = get_random_id()
self._response_id = get_random_id()
self._valid_not_before = get_time_string(
timedelta_from_string(self.provider.assertion_valid_not_before)
@ -132,9 +126,7 @@ class AssertionProcessor:
"""Generate AuthnStatement with AuthnContext and ContextClassRef Elements."""
auth_n_statement = Element(f"{{{NS_SAML_ASSERTION}}}AuthnStatement")
auth_n_statement.attrib["AuthnInstant"] = self._valid_not_before
auth_n_statement.attrib["SessionIndex"] = sha256(
self.http_request.session.session_key.encode("ascii")
).hexdigest()
auth_n_statement.attrib["SessionIndex"] = self._assertion_id
auth_n_statement.attrib["SessionNotOnOrAfter"] = self._session_not_on_or_after
auth_n_context = SubElement(auth_n_statement, f"{{{NS_SAML_ASSERTION}}}AuthnContext")
@ -264,17 +256,9 @@ class AssertionProcessor:
assertion,
xmlsec.constants.TransformExclC14N,
sign_algorithm_transform,
ns=xmlsec.constants.DSigNs,
ns="ds", # type: ignore
)
assertion.append(signature)
if self.provider.encryption_kp:
encryption = xmlsec.template.encrypted_data_create(
assertion,
xmlsec.constants.TransformAes128Cbc,
self._assertion_id,
ns=xmlsec.constants.DSigNs,
)
assertion.append(encryption)
assertion.append(self.get_assertion_subject())
assertion.append(self.get_assertion_conditions())
@ -289,7 +273,7 @@ class AssertionProcessor:
response.attrib["Version"] = "2.0"
response.attrib["IssueInstant"] = self._issue_instant
response.attrib["Destination"] = self.provider.acs_url
response.attrib["ID"] = self._response_id
response.attrib["ID"] = get_random_id()
if self.auth_n_request.id:
response.attrib["InResponseTo"] = self.auth_n_request.id
@ -302,86 +286,41 @@ class AssertionProcessor:
response.append(self.get_assertion())
return response
def _sign(self, element: Element):
"""Sign an XML element based on the providers' configured signing settings"""
digest_algorithm_transform = DIGEST_ALGORITHM_TRANSLATION_MAP.get(
self.provider.digest_algorithm, xmlsec.constants.TransformSha1
)
xmlsec.tree.add_ids(element, ["ID"])
signature_node = xmlsec.tree.find_node(element, xmlsec.constants.NodeSignature)
ref = xmlsec.template.add_reference(
signature_node,
digest_algorithm_transform,
uri="#" + element.attrib["ID"],
)
xmlsec.template.add_transform(ref, xmlsec.constants.TransformEnveloped)
xmlsec.template.add_transform(ref, xmlsec.constants.TransformExclC14N)
key_info = xmlsec.template.ensure_key_info(signature_node)
xmlsec.template.add_x509_data(key_info)
ctx = xmlsec.SignatureContext()
key = xmlsec.Key.from_memory(
self.provider.signing_kp.key_data,
xmlsec.constants.KeyDataFormatPem,
None,
)
key.load_cert_from_memory(
self.provider.signing_kp.certificate_data,
xmlsec.constants.KeyDataFormatCertPem,
)
ctx.key = key
try:
ctx.sign(signature_node)
except xmlsec.Error as exc:
raise InvalidSignature() from exc
def _encrypt(self, element: Element, parent: Element):
"""Encrypt SAMLResponse EncryptedAssertion Element"""
manager = xmlsec.KeysManager()
key = xmlsec.Key.from_memory(
self.provider.encryption_kp.key_data,
xmlsec.constants.KeyDataFormatPem,
)
key.load_cert_from_memory(
self.provider.encryption_kp.certificate_data,
xmlsec.constants.KeyDataFormatCertPem,
)
manager.add_key(key)
encryption_context = xmlsec.EncryptionContext(manager)
encryption_context.key = xmlsec.Key.generate(
xmlsec.constants.KeyDataAes, 128, xmlsec.constants.KeyDataTypeSession
)
container = SubElement(parent, f"{{{NS_SAML_ASSERTION}}}EncryptedAssertion")
enc_data = xmlsec.template.encrypted_data_create(
container, xmlsec.Transform.AES128, type=xmlsec.EncryptionType.ELEMENT, ns="xenc"
)
xmlsec.template.encrypted_data_ensure_cipher_value(enc_data)
key_info = xmlsec.template.encrypted_data_ensure_key_info(enc_data, ns="ds")
enc_key = xmlsec.template.add_encrypted_key(key_info, xmlsec.Transform.RSA_OAEP)
xmlsec.template.encrypted_data_ensure_cipher_value(enc_key)
try:
enc_data = encryption_context.encrypt_xml(enc_data, element)
except xmlsec.Error as exc:
raise InvalidEncryption() from exc
parent.remove(enc_data)
container.append(enc_data)
def build_response(self) -> str:
"""Build string XML Response and sign if signing is enabled."""
root_response = self.get_response()
if self.provider.signing_kp:
if self.provider.sign_assertion:
assertion = root_response.xpath("//saml:Assertion", namespaces=NS_MAP)[0]
self._sign(assertion)
if self.provider.sign_response:
response = root_response.xpath("//samlp:Response", namespaces=NS_MAP)[0]
self._sign(response)
if self.provider.encryption_kp:
digest_algorithm_transform = DIGEST_ALGORITHM_TRANSLATION_MAP.get(
self.provider.digest_algorithm, xmlsec.constants.TransformSha1
)
assertion = root_response.xpath("//saml:Assertion", namespaces=NS_MAP)[0]
self._encrypt(assertion, root_response)
xmlsec.tree.add_ids(assertion, ["ID"])
signature_node = xmlsec.tree.find_node(assertion, xmlsec.constants.NodeSignature)
ref = xmlsec.template.add_reference(
signature_node,
digest_algorithm_transform,
uri="#" + self._assertion_id,
)
xmlsec.template.add_transform(ref, xmlsec.constants.TransformEnveloped)
xmlsec.template.add_transform(ref, xmlsec.constants.TransformExclC14N)
key_info = xmlsec.template.ensure_key_info(signature_node)
xmlsec.template.add_x509_data(key_info)
ctx = xmlsec.SignatureContext()
key = xmlsec.Key.from_memory(
self.provider.signing_kp.key_data,
xmlsec.constants.KeyDataFormatPem,
None,
)
key.load_cert_from_memory(
self.provider.signing_kp.certificate_data,
xmlsec.constants.KeyDataFormatCertPem,
)
ctx.key = key
try:
ctx.sign(signature_node)
except xmlsec.Error as exc:
raise InvalidSignature() from exc
return etree.tostring(root_response).decode("utf-8") # nosec

View File

@ -126,7 +126,7 @@ class MetadataProcessor:
entity_descriptor,
xmlsec.constants.TransformExclC14N,
sign_algorithm_transform,
ns=xmlsec.constants.DSigNs,
ns="ds", # type: ignore
)
entity_descriptor.append(signature)

View File

@ -8,7 +8,7 @@ from rest_framework.test import APITestCase
from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.flows.models import FlowDesignation
from authentik.lib.generators import generate_id
from authentik.lib.tests.utils import load_fixture
@ -29,52 +29,12 @@ class TestSAMLProviderAPI(APITestCase):
name=generate_id(),
authorization_flow=create_test_flow(),
)
response = self.client.get(
reverse("authentik_api:samlprovider-detail", kwargs={"pk": provider.pk}),
)
self.assertEqual(200, response.status_code)
Application.objects.create(name=generate_id(), provider=provider, slug=generate_id())
response = self.client.get(
reverse("authentik_api:samlprovider-detail", kwargs={"pk": provider.pk}),
)
self.assertEqual(200, response.status_code)
def test_create_validate_signing_kp(self):
"""Test create"""
cert = create_test_cert()
response = self.client.post(
reverse("authentik_api:samlprovider-list"),
data={
"name": generate_id(),
"authorization_flow": create_test_flow().pk,
"acs_url": "http://localhost",
"signing_kp": cert.pk,
},
)
self.assertEqual(400, response.status_code)
self.assertJSONEqual(
response.content,
{
"non_field_errors": [
(
"With a signing keypair selected, at least one "
"of 'Sign assertion' and 'Sign Response' must be selected."
)
]
},
)
response = self.client.post(
reverse("authentik_api:samlprovider-list"),
data={
"name": generate_id(),
"authorization_flow": create_test_flow().pk,
"acs_url": "http://localhost",
"signing_kp": cert.pk,
"sign_assertion": True,
},
)
self.assertEqual(201, response.status_code)
def test_metadata(self):
"""Test metadata export (normal)"""
self.client.logout()

View File

@ -78,12 +78,12 @@ class TestAuthNRequest(TestCase):
@apply_blueprint("system/providers-saml.yaml")
def setUp(self):
self.cert = create_test_cert()
cert = create_test_cert()
self.provider: SAMLProvider = SAMLProvider.objects.create(
authorization_flow=create_test_flow(),
acs_url="http://testserver/source/saml/provider/acs/",
signing_kp=self.cert,
verification_kp=self.cert,
signing_kp=cert,
verification_kp=cert,
)
self.provider.property_mappings.set(SAMLPropertyMapping.objects.all())
self.provider.save()
@ -91,8 +91,8 @@ class TestAuthNRequest(TestCase):
slug="provider",
issuer="authentik",
pre_authentication_flow=create_test_flow(),
signing_kp=self.cert,
verification_kp=self.cert,
signing_kp=cert,
verification_kp=cert,
)
def test_signed_valid(self):
@ -112,34 +112,7 @@ class TestAuthNRequest(TestCase):
self.assertEqual(parsed_request.id, request_proc.request_id)
self.assertEqual(parsed_request.relay_state, "test_state")
def test_request_encrypt(self):
"""Test full SAML Request/Response flow, fully encrypted"""
self.provider.encryption_kp = self.cert
self.provider.save()
self.source.encryption_kp = self.cert
self.source.save()
http_request = get_request("/")
# First create an AuthNRequest
request_proc = RequestProcessor(self.source, http_request, "test_state")
request = request_proc.build_auth_n()
# To get an assertion we need a parsed request (parsed by provider)
parsed_request = AuthNRequestParser(self.provider).parse(
b64encode(request.encode()).decode(), "test_state"
)
# Now create a response and convert it to string (provider)
response_proc = AssertionProcessor(self.provider, http_request, parsed_request)
response = response_proc.build_response()
# Now parse the response (source)
http_request.POST = QueryDict(mutable=True)
http_request.POST["SAMLResponse"] = b64encode(response.encode()).decode()
response_parser = ResponseProcessor(self.source, http_request)
response_parser.parse()
def test_request_signed(self):
def test_request_full_signed(self):
"""Test full SAML Request/Response flow, fully signed"""
http_request = get_request("/")
@ -162,36 +135,6 @@ class TestAuthNRequest(TestCase):
response_parser = ResponseProcessor(self.source, http_request)
response_parser.parse()
def test_request_signed_both(self):
"""Test full SAML Request/Response flow, fully signed"""
self.provider.sign_assertion = True
self.provider.sign_response = True
self.provider.save()
http_request = get_request("/")
# First create an AuthNRequest
request_proc = RequestProcessor(self.source, http_request, "test_state")
request = request_proc.build_auth_n()
# To get an assertion we need a parsed request (parsed by provider)
parsed_request = AuthNRequestParser(self.provider).parse(
b64encode(request.encode()).decode(), "test_state"
)
# Now create a response and convert it to string (provider)
response_proc = AssertionProcessor(self.provider, http_request, parsed_request)
response = response_proc.build_response()
# Ensure both response and assertion ID are in the response twice (once as ID attribute,
# once as ds:Reference URI)
self.assertEqual(response.count(response_proc._assertion_id), 2)
self.assertEqual(response.count(response_proc._response_id), 2)
# Now parse the response (source)
http_request.POST = QueryDict(mutable=True)
http_request.POST["SAMLResponse"] = b64encode(response.encode()).decode()
response_parser = ResponseProcessor(self.source, http_request)
response_parser.parse()
def test_request_id_invalid(self):
"""Test generated AuthNRequest with invalid request ID"""
http_request = get_request("/")

View File

@ -54,11 +54,7 @@ class TestServiceProviderMetadataParser(TestCase):
request = self.factory.get("/")
metadata = lxml_from_string(MetadataProcessor(provider, request).build_entity_descriptor())
schema = etree.XMLSchema(
etree.parse(
source="schemas/saml-schema-metadata-2.0.xsd", parser=etree.XMLParser()
) # nosec
)
schema = etree.XMLSchema(etree.parse("schemas/saml-schema-metadata-2.0.xsd")) # nosec
self.assertTrue(schema.validate(metadata))
def test_schema_want_authn_requests_signed(self):

View File

@ -47,9 +47,7 @@ class TestSchema(TestCase):
metadata = lxml_from_string(request)
schema = etree.XMLSchema(
etree.parse("schemas/saml-schema-protocol-2.0.xsd", parser=etree.XMLParser()) # nosec
)
schema = etree.XMLSchema(etree.parse("schemas/saml-schema-protocol-2.0.xsd")) # nosec
self.assertTrue(schema.validate(metadata))
def test_response_schema(self):
@ -70,7 +68,5 @@ class TestSchema(TestCase):
metadata = lxml_from_string(response)
schema = etree.XMLSchema(
etree.parse("schemas/saml-schema-protocol-2.0.xsd", parser=etree.XMLParser()) # nosec
)
schema = etree.XMLSchema(etree.parse("schemas/saml-schema-protocol-2.0.xsd")) # nosec
self.assertTrue(schema.validate(metadata))

View File

@ -44,6 +44,6 @@ urlpatterns = [
]
api_urlpatterns = [
("propertymappings/provider/saml", SAMLPropertyMappingViewSet),
("propertymappings/saml", SAMLPropertyMappingViewSet),
("providers/saml", SAMLProviderViewSet),
]

View File

@ -6,7 +6,7 @@ from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
from authentik.providers.scim.models import SCIMProvider
from authentik.providers.scim.tasks import scim_sync, scim_sync_objects
from authentik.providers.scim.tasks import scim_sync
class SCIMProviderSerializer(ProviderSerializer):
@ -42,4 +42,3 @@ class SCIMProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin, ModelVie
search_fields = ["name", "url"]
ordering = ["name", "url"]
sync_single_task = scim_sync
sync_objects_task = scim_sync_objects

View File

@ -1,11 +1,8 @@
"""Group client"""
from itertools import batched
from django.db import transaction
from pydantic import ValidationError
from pydanticscim.group import GroupMember
from pydanticscim.responses import PatchOp
from pydanticscim.responses import PatchOp, PatchOperation
from authentik.core.models import Group
from authentik.lib.sync.mapper import PropertyMappingManager
@ -20,7 +17,7 @@ from authentik.providers.scim.clients.base import SCIMClient
from authentik.providers.scim.clients.exceptions import (
SCIMRequestException,
)
from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOperation, PatchRequest
from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchRequest
from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema
from authentik.providers.scim.models import (
SCIMMapping,
@ -59,22 +56,17 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
if not scim_group.externalId:
scim_group.externalId = str(obj.pk)
if not self._config.patch.supported:
users = list(obj.users.order_by("id").values_list("id", flat=True))
connections = SCIMProviderUser.objects.filter(
provider=self.provider, user__pk__in=users
)
members = []
for user in connections:
members.append(
GroupMember(
value=user.scim_id,
)
users = list(obj.users.order_by("id").values_list("id", flat=True))
connections = SCIMProviderUser.objects.filter(provider=self.provider, user__pk__in=users)
members = []
for user in connections:
members.append(
GroupMember(
value=user.scim_id,
)
if members:
scim_group.members = members
else:
del scim_group.members
)
if members:
scim_group.members = members
return scim_group
def delete(self, obj: Group):
@ -101,53 +93,16 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
scim_id = response.get("id")
if not scim_id or scim_id == "":
raise StopSync("SCIM Response with missing or invalid `id`")
connection = SCIMProviderGroup.objects.create(
return SCIMProviderGroup.objects.create(
provider=self.provider, group=group, scim_id=scim_id
)
users = list(group.users.order_by("id").values_list("id", flat=True))
self._patch_add_users(connection, users)
return connection
def update(self, group: Group, connection: SCIMProviderGroup):
"""Update existing group"""
scim_group = self.to_schema(group, connection)
scim_group.id = connection.scim_id
try:
if self._config.patch.supported:
return self._update_patch(group, scim_group, connection)
return self._update_put(group, scim_group, connection)
except NotFoundSyncException:
# Resource missing is handled by self.write, which will re-create the group
raise
def _update_patch(
self, group: Group, scim_group: SCIMGroupSchema, connection: SCIMProviderGroup
):
"""Update a group via PATCH request"""
# Patch group's attributes instead of replacing it and re-adding users if we can
self._request(
"PATCH",
f"/Groups/{connection.scim_id}",
json=PatchRequest(
Operations=[
PatchOperation(
op=PatchOp.replace,
path=None,
value=scim_group.model_dump(mode="json", exclude_unset=True),
)
]
).model_dump(
mode="json",
exclude_unset=True,
exclude_none=True,
),
)
return self.patch_compare_users(group)
def _update_put(self, group: Group, scim_group: SCIMGroupSchema, connection: SCIMProviderGroup):
"""Update a group via PUT request"""
try:
self._request(
return self._request(
"PUT",
f"/Groups/{connection.scim_id}",
json=scim_group.model_dump(
@ -155,25 +110,31 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
exclude_unset=True,
),
)
return self.patch_compare_users(group)
except NotFoundSyncException:
# Resource missing is handled by self.write, which will re-create the group
raise
except (SCIMRequestException, ObjectExistsSyncException):
# Some providers don't support PUT on groups, so this is mainly a fix for the initial
# sync, send patch add requests for all the users the group currently has
return self._update_patch(group, scim_group, connection)
users = list(group.users.order_by("id").values_list("id", flat=True))
self._patch_add_users(group, users)
# Also update the group name
return self._patch(
scim_group.id,
PatchOperation(
op=PatchOp.replace,
path="displayName",
value=scim_group.displayName,
),
)
def update_group(self, group: Group, action: Direction, users_set: set[int]):
"""Update a group, either using PUT to replace it or PATCH if supported"""
scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first()
if not scim_group:
self.logger.warning(
"could not sync group membership, group does not exist", group=group
)
return
if self._config.patch.supported:
if action == Direction.add:
return self._patch_add_users(scim_group, users_set)
return self._patch_add_users(group, users_set)
if action == Direction.remove:
return self._patch_remove_users(scim_group, users_set)
return self._patch_remove_users(group, users_set)
try:
return self.write(group)
except SCIMRequestException as exc:
@ -181,98 +142,35 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
# Assume that provider does not support PUT and also doesn't support
# ServiceProviderConfig, so try PATCH as a fallback
if action == Direction.add:
return self._patch_add_users(scim_group, users_set)
return self._patch_add_users(group, users_set)
if action == Direction.remove:
return self._patch_remove_users(scim_group, users_set)
return self._patch_remove_users(group, users_set)
raise exc
def _patch_chunked(
def _patch(
self,
group_id: str,
*ops: PatchOperation,
):
"""Helper function that chunks patch requests based on the maxOperations attribute.
This is not strictly according to specs but there's nothing in the schema that allows the
us to know what the maximum patch operations per request should be."""
chunk_size = self._config.bulk.maxOperations
if chunk_size < 1:
chunk_size = len(ops)
if len(ops) < 1:
return
for chunk in batched(ops, chunk_size):
req = PatchRequest(Operations=list(chunk))
self._request(
"PATCH",
f"/Groups/{group_id}",
json=req.model_dump(
mode="json",
),
)
req = PatchRequest(Operations=ops)
self._request(
"PATCH",
f"/Groups/{group_id}",
json=req.model_dump(
mode="json",
),
)
@transaction.atomic
def patch_compare_users(self, group: Group):
"""Compare users with a SCIM group and add/remove any differences"""
# Get scim group first
def _patch_add_users(self, group: Group, users_set: set[int]):
"""Add users in users_set to group"""
if len(users_set) < 1:
return
scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first()
if not scim_group:
self.logger.warning(
"could not sync group membership, group does not exist", group=group
)
return
# Get a list of all users in the authentik group
raw_users_should = list(group.users.order_by("id").values_list("id", flat=True))
# Lookup the SCIM IDs of the users
users_should: list[str] = list(
SCIMProviderUser.objects.filter(
user__pk__in=raw_users_should, provider=self.provider
).values_list("scim_id", flat=True)
)
if len(raw_users_should) != len(users_should):
self.logger.warning(
"User count mismatch, not all users in the group are synced to SCIM yet.",
group=group,
)
# Get current group status
current_group = SCIMGroupSchema.model_validate(
self._request("GET", f"/Groups/{scim_group.scim_id}")
)
users_to_add = []
users_to_remove = []
# Check users currently in group and if they shouldn't be in the group and remove them
for user in current_group.members or []:
if user.value not in users_should:
users_to_remove.append(user.value)
# Check users that should be in the group and add them
for user in users_should:
if len([x for x in current_group.members if x.value == user]) < 1:
users_to_add.append(user)
# Only send request if we need to make changes
if len(users_to_add) < 1 and len(users_to_remove) < 1:
return
return self._patch_chunked(
scim_group.scim_id,
*[
PatchOperation(
op=PatchOp.add,
path="members",
value=[{"value": x}],
)
for x in users_to_add
],
*[
PatchOperation(
op=PatchOp.remove,
path="members",
value=[{"value": x}],
)
for x in users_to_remove
],
)
def _patch_add_users(self, scim_group: SCIMProviderGroup, users_set: set[int]):
"""Add users in users_set to group"""
if len(users_set) < 1:
return
user_ids = list(
SCIMProviderUser.objects.filter(
user__pk__in=users_set, provider=self.provider
@ -280,22 +178,25 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
)
if len(user_ids) < 1:
return
self._patch_chunked(
self._patch(
scim_group.scim_id,
*[
PatchOperation(
op=PatchOp.add,
path="members",
value=[{"value": x}],
)
for x in user_ids
],
PatchOperation(
op=PatchOp.add,
path="members",
value=[{"value": x} for x in user_ids],
),
)
def _patch_remove_users(self, scim_group: SCIMProviderGroup, users_set: set[int]):
def _patch_remove_users(self, group: Group, users_set: set[int]):
"""Remove users in users_set from group"""
if len(users_set) < 1:
return
scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first()
if not scim_group:
self.logger.warning(
"could not sync group membership, group does not exist", group=group
)
return
user_ids = list(
SCIMProviderUser.objects.filter(
user__pk__in=users_set, provider=self.provider
@ -303,14 +204,11 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
)
if len(user_ids) < 1:
return
self._patch_chunked(
self._patch(
scim_group.scim_id,
*[
PatchOperation(
op=PatchOp.remove,
path="members",
value=[{"value": x}],
)
for x in user_ids
],
PatchOperation(
op=PatchOp.remove,
path="members",
value=[{"value": x} for x in user_ids],
),
)

View File

@ -1,12 +1,9 @@
"""Custom SCIM schemas"""
from pydantic import Field
from pydanticscim.group import Group as BaseGroup
from pydanticscim.responses import PatchOperation as BasePatchOperation
from pydanticscim.responses import PatchRequest as BasePatchRequest
from pydanticscim.responses import SCIMError as BaseSCIMError
from pydanticscim.service_provider import Bulk as BaseBulk
from pydanticscim.service_provider import ChangePassword, Filter, Patch, Sort
from pydanticscim.service_provider import Bulk, ChangePassword, Filter, Patch, Sort
from pydanticscim.service_provider import (
ServiceProviderConfiguration as BaseServiceProviderConfiguration,
)
@ -32,16 +29,10 @@ class Group(BaseGroup):
meta: dict | None = None
class Bulk(BaseBulk):
maxOperations: int = Field()
class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
"""ServiceProviderConfig with fallback"""
_is_fallback: bool | None = False
bulk: Bulk = Field(..., description="A complex type that specifies bulk configuration options.")
@property
def is_fallback(self) -> bool:
@ -54,7 +45,7 @@ class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
"""Get default configuration, which doesn't support any optional features as fallback"""
return ServiceProviderConfiguration(
patch=Patch(supported=False),
bulk=Bulk(supported=False, maxOperations=0),
bulk=Bulk(supported=False),
filter=Filter(supported=False),
changePassword=ChangePassword(supported=False),
sort=Sort(supported=False),
@ -69,12 +60,6 @@ class PatchRequest(BasePatchRequest):
schemas: tuple[str] = ("urn:ietf:params:scim:api:messages:2.0:PatchOp",)
class PatchOperation(BasePatchOperation):
"""PatchOperation with optional path"""
path: str | None
class SCIMError(BaseSCIMError):
"""SCIM error with optional status code"""

View File

@ -1,20 +0,0 @@
# Generated by Django 5.0.8 on 2024-08-12 12:54
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("authentik_providers_scim", "0008_rename_scimgroup_scimprovidergroup_and_more"),
]
operations = [
migrations.AlterModelOptions(
name="scimmapping",
options={
"verbose_name": "SCIM Provider Mapping",
"verbose_name_plural": "SCIM Provider Mappings",
},
),
]

View File

@ -133,7 +133,7 @@ class SCIMMapping(PropertyMapping):
@property
def component(self) -> str:
return "ak-property-mapping-provider-scim-form"
return "ak-property-mapping-scim-form"
@property
def serializer(self) -> type[Serializer]:
@ -142,8 +142,8 @@ class SCIMMapping(PropertyMapping):
return SCIMMappingSerializer
def __str__(self):
return f"SCIM Provider Mapping {self.name}"
return f"SCIM Mapping {self.name}"
class Meta:
verbose_name = _("SCIM Provider Mapping")
verbose_name_plural = _("SCIM Provider Mappings")
verbose_name = _("SCIM Mapping")
verbose_name_plural = _("SCIM Mappings")

View File

@ -252,118 +252,3 @@ class SCIMMembershipTests(TestCase):
],
},
)
def test_member_add_save(self):
"""Test member add + save"""
config = ServiceProviderConfiguration.default()
config.patch.supported = True
user_scim_id = generate_id()
group_scim_id = generate_id()
uid = generate_id()
group = Group.objects.create(
name=uid,
)
user = User.objects.create(username=generate_id())
# Test initial sync of group creation
with Mocker() as mocker:
mocker.get(
"https://localhost/ServiceProviderConfig",
json=config.model_dump(),
)
mocker.post(
"https://localhost/Users",
json={
"id": user_scim_id,
},
)
mocker.post(
"https://localhost/Groups",
json={
"id": group_scim_id,
},
)
self.configure()
sync_tasks.trigger_single_task(self.provider, scim_sync).get()
self.assertEqual(mocker.call_count, 6)
self.assertEqual(mocker.request_history[0].method, "GET")
self.assertEqual(mocker.request_history[1].method, "GET")
self.assertEqual(mocker.request_history[2].method, "GET")
self.assertEqual(mocker.request_history[3].method, "POST")
self.assertEqual(mocker.request_history[4].method, "GET")
self.assertEqual(mocker.request_history[5].method, "POST")
self.assertJSONEqual(
mocker.request_history[3].body,
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"emails": [],
"active": True,
"externalId": user.uid,
"name": {"familyName": " ", "formatted": " ", "givenName": ""},
"displayName": "",
"userName": user.username,
},
)
self.assertJSONEqual(
mocker.request_history[5].body,
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"externalId": str(group.pk),
"displayName": group.name,
},
)
with Mocker() as mocker:
mocker.get(
"https://localhost/ServiceProviderConfig",
json=config.model_dump(),
)
mocker.get(
f"https://localhost/Groups/{group_scim_id}",
json={},
)
mocker.patch(
f"https://localhost/Groups/{group_scim_id}",
json={},
)
group.users.add(user)
group.save()
self.assertEqual(mocker.call_count, 5)
self.assertEqual(mocker.request_history[0].method, "GET")
self.assertEqual(mocker.request_history[1].method, "PATCH")
self.assertEqual(mocker.request_history[2].method, "GET")
self.assertEqual(mocker.request_history[3].method, "PATCH")
self.assertEqual(mocker.request_history[4].method, "GET")
self.assertJSONEqual(
mocker.request_history[1].body,
{
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
"Operations": [
{
"op": "add",
"path": "members",
"value": [{"value": user_scim_id}],
}
],
},
)
self.assertJSONEqual(
mocker.request_history[3].body,
{
"Operations": [
{
"op": "replace",
"value": {
"id": group_scim_id,
"displayName": group.name,
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"externalId": str(group.pk),
},
}
]
},
)

View File

@ -13,5 +13,5 @@ api_urlpatterns = [
("providers/scim", SCIMProviderViewSet),
("providers/scim_users", SCIMProviderUserViewSet),
("providers/scim_groups", SCIMProviderGroupViewSet),
("propertymappings/provider/scim", SCIMMappingViewSet),
("propertymappings/scim", SCIMMappingViewSet),
]

View File

@ -2,7 +2,7 @@
from uuid import uuid4
from django.contrib.auth.management import _get_all_permissions
from django.contrib.auth.models import Permission
from django.db import models
from django.db.transaction import atomic
from django.utils.translation import gettext_lazy as _
@ -10,26 +10,28 @@ from guardian.shortcuts import assign_perm
from rest_framework.serializers import BaseSerializer
from authentik.lib.models import SerializerModel
from authentik.lib.utils.reflection import get_apps
def get_permission_choices():
all_perms = []
for app in get_apps():
for model in app.get_models():
for perm, _desc in _get_all_permissions(model._meta):
all_perms.append((model, perm))
return sorted(
[
(
f"{model._meta.app_label}.{perm}",
f"{model._meta.app_label}.{perm}",
)
for model, perm in all_perms
]
def get_permissions():
return (
Permission.objects.all()
.select_related("content_type")
.filter(
content_type__app_label__startswith="authentik",
)
)
def get_permission_choices() -> list[tuple[str, str]]:
return [
(
f"{x.content_type.app_label}.{x.codename}",
f"{x.content_type.app_label}.{x.codename}",
)
for x in get_permissions()
]
class Role(SerializerModel):
"""RBAC role, which can have different permissions (both global and per-object) attached
to it."""

View File

@ -87,11 +87,7 @@ def task_error_hook(task_id: str, exception: Exception, traceback, *args, **kwar
def _get_startup_tasks_default_tenant() -> list[Callable]:
"""Get all tasks to be run on startup for the default tenant"""
from authentik.outposts.tasks import outpost_connection_discovery
return [
outpost_connection_discovery,
]
return []
def _get_startup_tasks_all_tenants() -> list[Callable]:

View File

@ -2,7 +2,6 @@
from collections.abc import Callable
from hashlib import sha512
from ipaddress import ip_address
from time import perf_counter, time
from typing import Any
@ -175,7 +174,6 @@ class ClientIPMiddleware:
def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]):
self.get_response = get_response
self.logger = get_logger().bind()
def _get_client_ip_from_meta(self, meta: dict[str, Any]) -> str:
"""Attempt to get the client's IP by checking common HTTP Headers.
@ -187,16 +185,11 @@ class ClientIPMiddleware:
"HTTP_X_FORWARDED_FOR",
"REMOTE_ADDR",
)
try:
for _header in headers:
if _header in meta:
ips: list[str] = meta.get(_header).split(",")
# Ensure the IP parses as a valid IP
return str(ip_address(ips[0].strip()))
return self.default_ip
except ValueError as exc:
self.logger.debug("Invalid remote IP", exc=exc)
return self.default_ip
for _header in headers:
if _header in meta:
ips: list[str] = meta.get(_header).split(",")
return ips[0].strip()
return self.default_ip
# FIXME: this should probably not be in `root` but rather in a middleware in `outposts`
# but for now it's fine
@ -233,11 +226,7 @@ class ClientIPMiddleware:
Scope.get_isolation_scope().set_user(user)
# Set the outpost service account on the request
setattr(request, self.request_attr_outpost_user, user)
try:
return str(ip_address(delegated_ip))
except ValueError as exc:
self.logger.debug("Invalid remote IP from Outpost", exc=exc)
return None
return delegated_ip
def _get_client_ip(self, request: HttpRequest | None) -> str:
"""Attempt to get the client's IP by checking common HTTP Headers.

View File

@ -9,7 +9,6 @@ import orjson
from celery.schedules import crontab
from django.conf import ImproperlyConfigured
from sentry_sdk import set_tag
from xmlsec import enable_debug_trace
from authentik import __version__
from authentik.lib.config import CONFIG, redis_url
@ -521,7 +520,6 @@ if DEBUG:
"rest_framework.renderers.BrowsableAPIRenderer"
)
SHARED_APPS.insert(SHARED_APPS.index("django.contrib.staticfiles"), "daphne")
enable_debug_trace(True)
TENANT_APPS.append("authentik.core")

View File

@ -1,7 +1,6 @@
"""authentik storage backends"""
import os
from urllib.parse import parse_qsl, urlsplit
from django.conf import settings
from django.core.exceptions import SuspiciousOperation
@ -111,34 +110,3 @@ class S3Storage(BaseS3Storage):
if self.querystring_auth:
return url
return self._strip_signing_parameters(url)
def _strip_signing_parameters(self, url):
# Boto3 does not currently support generating URLs that are unsigned. Instead
# we take the signed URLs and strip any querystring params related to signing
# and expiration.
# Note that this may end up with URLs that are still invalid, especially if
# params are passed in that only work with signed URLs, e.g. response header
# params.
# The code attempts to strip all query parameters that match names of known
# parameters from v2 and v4 signatures, regardless of the actual signature
# version used.
split_url = urlsplit(url)
qs = parse_qsl(split_url.query, keep_blank_values=True)
blacklist = {
"x-amz-algorithm",
"x-amz-credential",
"x-amz-date",
"x-amz-expires",
"x-amz-signedheaders",
"x-amz-signature",
"x-amz-security-token",
"awsaccesskeyid",
"expires",
"signature",
}
filtered_qs = ((key, val) for key, val in qs if key.lower() not in blacklist)
# Note: Parameters that did not have a value in the original query string will
# have an '=' sign appended to it, e.g ?foo&bar becomes ?foo=&bar=
joined_qs = ("=".join(keyval) for keyval in filtered_qs)
split_url = split_url._replace(query="&".join(joined_qs))
return split_url.geturl()

View File

@ -3,7 +3,6 @@
from typing import Any
from django.core.cache import cache
from django.utils.translation import gettext_lazy as _
from drf_spectacular.utils import extend_schema, inline_serializer
from guardian.shortcuts import get_objects_for_user
from rest_framework.decorators import action
@ -40,8 +39,9 @@ class LDAPSourceSerializer(SourceSerializer):
"""Get cached source connectivity"""
return cache.get(CACHE_KEY_STATUS + source.slug, None)
def validate_sync_users_password(self, sync_users_password: bool) -> bool:
def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
"""Check that only a single source has password_sync on"""
sync_users_password = attrs.get("sync_users_password", True)
if sync_users_password:
sources = LDAPSource.objects.filter(sync_users_password=True)
if self.instance:
@ -49,31 +49,11 @@ class LDAPSourceSerializer(SourceSerializer):
if sources.exists():
raise ValidationError(
{
"sync_users_password": _(
"sync_users_password": (
"Only a single LDAP Source with password synchronization is allowed"
)
}
)
return sync_users_password
def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
"""Validate property mappings with sync_ flags"""
types = ["user", "group"]
for type in types:
toggle_value = attrs.get(f"sync_{type}s", False)
mappings_field = f"{type}_property_mappings"
mappings_value = attrs.get(mappings_field, [])
if toggle_value and len(mappings_value) == 0:
raise ValidationError(
{
mappings_field: _(
(
"When 'Sync {type}s' is enabled, '{type}s property "
"mappings' cannot be empty."
).format(type=type)
)
}
)
return super().validate(attrs)
class Meta:
@ -186,12 +166,11 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet):
for sync_class in SYNC_CLASSES:
class_name = sync_class.name()
all_objects.setdefault(class_name, [])
for page in sync_class(source).get_objects(size_limit=10):
for obj in page:
obj: dict
obj.pop("raw_attributes", None)
obj.pop("raw_dn", None)
all_objects[class_name].append(obj)
for obj in sync_class(source).get_objects(size_limit=10):
obj: dict
obj.pop("raw_attributes", None)
obj.pop("raw_dn", None)
all_objects[class_name].append(obj)
return Response(data=all_objects)

View File

@ -290,7 +290,7 @@ class LDAPSourcePropertyMapping(PropertyMapping):
@property
def component(self) -> str:
return "ak-property-mapping-source-ldap-form"
return "ak-property-mapping-ldap-source-form"
@property
def serializer(self) -> type[Serializer]:

View File

@ -26,16 +26,17 @@ def sync_ldap_source_on_save(sender, instance: LDAPSource, **_):
"""Ensure that source is synced on save (if enabled)"""
if not instance.enabled:
return
ldap_connectivity_check.delay(instance.pk)
# Don't sync sources when they don't have any property mappings. This will only happen if:
# - the user forgets to set them or
# - the source is newly created, this is the first save event
# and the mappings are created with an m2m event
if instance.sync_users and not instance.user_property_mappings.exists():
return
if instance.sync_groups and not instance.group_property_mappings.exists():
if (
not instance.user_property_mappings.exists()
or not instance.group_property_mappings.exists()
):
return
ldap_sync_single.delay(instance.pk)
ldap_connectivity_check.delay(instance.pk)
@receiver(password_validate)

View File

@ -38,11 +38,7 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
search_base=self.base_dn_groups,
search_filter=self._source.group_object_filter,
search_scope=SUBTREE,
attributes=[
ALL_ATTRIBUTES,
ALL_OPERATIONAL_ATTRIBUTES,
self._source.object_uniqueness_field,
],
attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES],
**kwargs,
)
@ -57,9 +53,9 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
continue
attributes = group.get("attributes", {})
group_dn = flatten(flatten(group.get("entryDN", group.get("dn"))))
if not attributes.get(self._source.object_uniqueness_field):
if self._source.object_uniqueness_field not in attributes:
self.message(
f"Uniqueness field not found/not set in attributes: '{group_dn}'",
f"Cannot find uniqueness field in attributes: '{group_dn}'",
attributes=attributes.keys(),
dn=group_dn,
)

Some files were not shown because too many files have changed in this diff Show More