Compare commits
223 Commits
static-con
...
revert-rev
Author | SHA1 | Date | |
---|---|---|---|
de1ec20d7a | |||
430678326d | |||
8fb7b5160d | |||
14a6430e21 | |||
7c32fc8c17 | |||
ed0a9d6a0a | |||
53143e0c40 | |||
178e010ed4 | |||
49b666fbde | |||
c343e3a7f4 | |||
5febf3ce5b | |||
b8c5bd678b | |||
4dd5eccbaa | |||
2410884006 | |||
3cb921b0f9 | |||
535f92981f | |||
955d69d5b7 | |||
fb01d8e96a | |||
6d39efd3e3 | |||
3020c31bcd | |||
22412729e2 | |||
a02868a27d | |||
bfbb4a8ebc | |||
6c0e827677 | |||
29884cbf81 | |||
0f02985b0c | |||
2244e026c2 | |||
429c03021c | |||
f47e8d9d72 | |||
3e7d2587c4 | |||
55a38d4a36 | |||
6021bb932d | |||
54a5d95717 | |||
a0a1275452 | |||
919aa5df59 | |||
cedf7cf683 | |||
cbc5a1c39d | |||
5f6b69c998 | |||
cf065db3d5 | |||
86c65325ce | |||
2b8e10e979 | |||
9298807275 | |||
ed56d6ac50 | |||
8c07b385ad | |||
880db7a86c | |||
99c1250ba5 | |||
5ce126ac83 | |||
dfa21d0725 | |||
e7e4af3894 | |||
931d6ec579 | |||
ff45acb25c | |||
c96557ff2d | |||
734feac4ae | |||
b17a9ed145 | |||
2bef7695db | |||
df472dd842 | |||
98d201d34c | |||
47e89602ab | |||
ceb0851452 | |||
cac2593658 | |||
1c9705bfaa | |||
9e2566cec4 | |||
5bdef1c4f6 | |||
ae41ccd862 | |||
337956672f | |||
cf160f800d | |||
e9822cd937 | |||
5244f64be4 | |||
0df4824fd4 | |||
ea22abc75d | |||
b09bab7543 | |||
5aedc8a5f2 | |||
2f3ae0f607 | |||
e3674426b7 | |||
df915d3a5e | |||
4949c31860 | |||
4580dec06b | |||
56de969640 | |||
413902508d | |||
64af0ccba6 | |||
673db53777 | |||
8df7716d90 | |||
19bb2de13f | |||
a218fd7628 | |||
78cfb50a90 | |||
2033d52dc2 | |||
be00f47ddc | |||
2cc5f4b273 | |||
4e8f3407a4 | |||
7f861cc2a1 | |||
7bf58d0ba2 | |||
fffcb00f39 | |||
77ee868573 | |||
6aaec08496 | |||
cc15584650 | |||
e55e446b89 | |||
76088e48b5 | |||
4165a0a6b2 | |||
647fefe5ce | |||
723dccdae3 | |||
c82f747e5e | |||
43406e2464 | |||
a0ff0bef85 | |||
bedf548a5f | |||
976e81c1dd | |||
ad733033d7 | |||
ba686f6a93 | |||
dc50be1e13 | |||
205686d252 | |||
6d589013e6 | |||
2d6433ca9a | |||
b5f07acb26 | |||
ea8702077c | |||
6593357115 | |||
6daed865c1 | |||
c48a21707a | |||
e857770c0a | |||
add74c8799 | |||
12d854035d | |||
57dd4ae91d | |||
37fbc98177 | |||
14f216eb40 | |||
1209dd022e | |||
c96f13ac66 | |||
5e6874cc1f | |||
fb5053ec83 | |||
6f7dc2c543 | |||
542b69b224 | |||
c15c0cbe86 | |||
c6fe0c1d85 | |||
07f0666a6f | |||
51609d696d | |||
c0d08df161 | |||
643a97f0a5 | |||
155a31fd70 | |||
c6f9d5df7b | |||
ea85331a7e | |||
4f4c5253dd | |||
83b2fc36df | |||
d99d2b8bdc | |||
9b96d04b3a | |||
ca5b99eb16 | |||
4c1676e97c | |||
81855cf2fe | |||
bd904027be | |||
0ffc97db15 | |||
2c515b1e17 | |||
f8900fbaf3 | |||
0f4a98d9c6 | |||
8853f25b45 | |||
1c40f7b95a | |||
9b5d6ec1af | |||
36d29a9ae1 | |||
0606b1aba4 | |||
03d5dad867 | |||
38a9e46af3 | |||
5eb848e376 | |||
61a293daad | |||
edf3300944 | |||
5d9c40eac8 | |||
6ebfbcb66e | |||
bf0235c113 | |||
895cd23b57 | |||
c908d9e95e | |||
a07fd8d54b | |||
39a46a6dc4 | |||
ad71960d77 | |||
2a384511f5 | |||
4dcc104947 | |||
71fe526e47 | |||
03e3f516ac | |||
3b59333246 | |||
4e800c14cb | |||
789b29a3e7 | |||
857b6e63a0 | |||
edc937dd78 | |||
d98b6f29d4 | |||
53ba2a0ca8 | |||
ae364292e6 | |||
f15bc2df97 | |||
b27d49e55f | |||
e0d2beb225 | |||
2313b4755b | |||
1cffadecb0 | |||
5e163d6da1 | |||
0626e18674 | |||
e986a62a12 | |||
e25afcb84a | |||
bb95613104 | |||
89dfac2f57 | |||
31462b55e6 | |||
60337c1cf0 | |||
343d3bb1fb | |||
11fe86c4f6 | |||
963ce085e4 | |||
3642b89ab0 | |||
8cfb371ed3 | |||
6e74edb9f2 | |||
397905f8f0 | |||
7fd35b1dfc | |||
9ba03f5439 | |||
1139d6d27c | |||
077fd966c2 | |||
bd41822a57 | |||
dfd3d76434 | |||
397e98906d | |||
65d8da8c64 | |||
5b435297c5 | |||
f792fd42f6 | |||
70c0fdd5fa | |||
9b636eba01 | |||
a982224502 | |||
6a16cccb40 | |||
6dac91e2b4 | |||
3e2d0532d1 | |||
4e1300650b | |||
06b3ed0c9c | |||
395ad722b7 | |||
9917d81246 | |||
2a87687d34 | |||
a726c2260a | |||
44e0bfd4ef | |||
8d0b362c9c |
@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 2025.2.4
|
||||
current_version = 2025.4.0
|
||||
tag = True
|
||||
commit = True
|
||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))?
|
||||
|
1
.github/workflows/api-py-publish.yml
vendored
1
.github/workflows/api-py-publish.yml
vendored
@ -30,7 +30,6 @@ jobs:
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version-file: "pyproject.toml"
|
||||
cache: "poetry"
|
||||
- name: Generate API Client
|
||||
run: make gen-client-py
|
||||
- name: Publish package
|
||||
|
13
.github/workflows/ci-main.yml
vendored
13
.github/workflows/ci-main.yml
vendored
@ -70,22 +70,18 @@ jobs:
|
||||
- name: checkout stable
|
||||
run: |
|
||||
# Copy current, latest config to local
|
||||
# Temporarly comment the .github backup while migrating to uv
|
||||
cp authentik/lib/default.yml local.env.yml
|
||||
# cp -R .github ..
|
||||
cp -R .github ..
|
||||
cp -R scripts ..
|
||||
git checkout $(git tag --sort=version:refname | grep '^version/' | grep -vE -- '-rc[0-9]+$' | tail -n1)
|
||||
# rm -rf .github/ scripts/
|
||||
# mv ../.github ../scripts .
|
||||
rm -rf scripts/
|
||||
mv ../scripts .
|
||||
rm -rf .github/ scripts/
|
||||
mv ../.github ../scripts .
|
||||
- name: Setup authentik env (stable)
|
||||
uses: ./.github/actions/setup
|
||||
with:
|
||||
postgresql_version: ${{ matrix.psql }}
|
||||
continue-on-error: true
|
||||
- name: run migrations to stable
|
||||
run: poetry run python -m lifecycle.migrate
|
||||
run: uv run python -m lifecycle.migrate
|
||||
- name: checkout current code
|
||||
run: |
|
||||
set -x
|
||||
@ -232,7 +228,6 @@ jobs:
|
||||
needs:
|
||||
- lint
|
||||
- test-migrations
|
||||
- test-migrations-from-stable
|
||||
- test-unittest
|
||||
- test-integration
|
||||
- test-e2e
|
||||
|
45
.github/workflows/packages-npm-publish.yml
vendored
Normal file
45
.github/workflows/packages-npm-publish.yml
vendored
Normal file
@ -0,0 +1,45 @@
|
||||
name: authentik-packages-npm-publish
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
paths:
|
||||
- packages/docusaurus-config/**
|
||||
- packages/eslint-config/**
|
||||
- packages/prettier-config/**
|
||||
- packages/tsconfig/**
|
||||
workflow_dispatch:
|
||||
jobs:
|
||||
publish:
|
||||
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
package:
|
||||
- docusaurus-config
|
||||
- eslint-config
|
||||
- prettier-config
|
||||
- tsconfig
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version-file: packages/${{ matrix.package }}/package.json
|
||||
registry-url: "https://registry.npmjs.org"
|
||||
- name: Get changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@ed68ef82c095e0d48ec87eccea555d944a631a4c
|
||||
with:
|
||||
files: |
|
||||
packages/${{ matrix.package }}/package.json
|
||||
- name: Publish package
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: packages/${{ matrix.package}}
|
||||
run: |
|
||||
npm ci
|
||||
npm run build
|
||||
npm publish
|
||||
env:
|
||||
NODE_AUTH_TOKEN: ${{ secrets.NPM_PUBLISH_TOKEN }}
|
4
.gitignore
vendored
4
.gitignore
vendored
@ -11,6 +11,10 @@ local_settings.py
|
||||
db.sqlite3
|
||||
media
|
||||
|
||||
# Node
|
||||
|
||||
node_modules
|
||||
|
||||
# If your build process includes running collectstatic, then you probably don't need or want to include staticfiles/
|
||||
# in your Git repository. Update and uncomment the following line accordingly.
|
||||
# <django-project-name>/staticfiles/
|
||||
|
47
.prettierignore
Normal file
47
.prettierignore
Normal file
@ -0,0 +1,47 @@
|
||||
# Prettier Ignorefile
|
||||
|
||||
## Static Files
|
||||
**/LICENSE
|
||||
|
||||
authentik/stages/**/*
|
||||
|
||||
## Build asset directories
|
||||
coverage
|
||||
dist
|
||||
out
|
||||
.docusaurus
|
||||
website/docs/developer-docs/api/**/*
|
||||
|
||||
## Environment
|
||||
*.env
|
||||
|
||||
## Secrets
|
||||
*.secrets
|
||||
|
||||
## Yarn
|
||||
.yarn/**/*
|
||||
|
||||
## Node
|
||||
node_modules
|
||||
coverage
|
||||
|
||||
## Configs
|
||||
*.log
|
||||
*.yaml
|
||||
*.yml
|
||||
|
||||
# Templates
|
||||
# TODO: Rename affected files to *.template.* or similar.
|
||||
*.html
|
||||
*.mdx
|
||||
*.md
|
||||
|
||||
## Import order matters
|
||||
poly.ts
|
||||
src/locale-codes.ts
|
||||
src/locales/
|
||||
|
||||
# Storybook
|
||||
storybook-static/
|
||||
.storybook/css-import-maps*
|
||||
|
@ -23,6 +23,8 @@ docker-compose.yml @goauthentik/infrastructure
|
||||
Makefile @goauthentik/infrastructure
|
||||
.editorconfig @goauthentik/infrastructure
|
||||
CODEOWNERS @goauthentik/infrastructure
|
||||
# Web packages
|
||||
packages/ @goauthentik/frontend
|
||||
# Web
|
||||
web/ @goauthentik/frontend
|
||||
tests/wdio/ @goauthentik/frontend
|
||||
|
@ -40,7 +40,8 @@ COPY ./web /work/web/
|
||||
COPY ./website /work/website/
|
||||
COPY ./gen-ts-api /work/web/node_modules/@goauthentik/api
|
||||
|
||||
RUN npm run build
|
||||
RUN npm run build && \
|
||||
npm run build:sfe
|
||||
|
||||
# Stage 3: Build go proxy
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/library/golang:1.24-bookworm AS go-builder
|
||||
@ -94,9 +95,9 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
|
||||
/bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
|
||||
|
||||
# Stage 5: Download uv
|
||||
FROM ghcr.io/astral-sh/uv:0.6.14 AS uv
|
||||
FROM ghcr.io/astral-sh/uv:0.6.16 AS uv
|
||||
# Stage 6: Base python image
|
||||
FROM ghcr.io/goauthentik/fips-python:3.12.9-slim-bookworm-fips AS python-base
|
||||
FROM ghcr.io/goauthentik/fips-python:3.12.10-slim-bookworm-fips AS python-base
|
||||
|
||||
ENV VENV_PATH="/ak-root/.venv" \
|
||||
PATH="/lifecycle:/ak-root/.venv/bin:$PATH" \
|
||||
|
@ -20,8 +20,8 @@ Even if the issue is not a CVE, we still greatly appreciate your help in hardeni
|
||||
|
||||
| Version | Supported |
|
||||
| --------- | --------- |
|
||||
| 2024.12.x | ✅ |
|
||||
| 2025.2.x | ✅ |
|
||||
| 2025.4.x | ✅ |
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
from os import environ
|
||||
|
||||
__version__ = "2025.2.4"
|
||||
__version__ = "2025.4.0"
|
||||
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
|
||||
|
||||
|
||||
|
@ -7,7 +7,7 @@ from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.fields import CharField, DateTimeField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.serializers import ListSerializer, ModelSerializer
|
||||
from rest_framework.serializers import ListSerializer
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.blueprints.models import BlueprintInstance
|
||||
@ -15,7 +15,7 @@ from authentik.blueprints.v1.importer import Importer
|
||||
from authentik.blueprints.v1.oci import OCI_PREFIX
|
||||
from authentik.blueprints.v1.tasks import apply_blueprint, blueprints_find_dict
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import JSONDictField, PassiveSerializer
|
||||
from authentik.core.api.utils import JSONDictField, ModelSerializer, PassiveSerializer
|
||||
from authentik.rbac.decorators import permission_required
|
||||
|
||||
|
||||
|
@ -36,6 +36,7 @@ from authentik.core.models import (
|
||||
GroupSourceConnection,
|
||||
PropertyMapping,
|
||||
Provider,
|
||||
Session,
|
||||
Source,
|
||||
User,
|
||||
UserSourceConnection,
|
||||
@ -108,6 +109,7 @@ def excluded_models() -> list[type[Model]]:
|
||||
Policy,
|
||||
PolicyBindingModel,
|
||||
# Classes that have other dependencies
|
||||
Session,
|
||||
AuthenticatedSession,
|
||||
# Classes which are only internally managed
|
||||
# FIXME: these shouldn't need to be explicitly listed, but rather based off of a mixin
|
||||
|
@ -16,7 +16,7 @@ def migrate_custom_css(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||
if not path.exists():
|
||||
return
|
||||
css = path.read_text()
|
||||
Brand.objects.using(db_alias).update(branding_custom_css=css)
|
||||
Brand.objects.using(db_alias).all().update(branding_custom_css=css)
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
@ -5,6 +5,7 @@ from typing import TypedDict
|
||||
from rest_framework import mixins
|
||||
from rest_framework.fields import SerializerMethodField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.serializers import CharField, DateTimeField, IPAddressField
|
||||
from rest_framework.viewsets import GenericViewSet
|
||||
from ua_parser import user_agent_parser
|
||||
|
||||
@ -54,6 +55,11 @@ class UserAgentDict(TypedDict):
|
||||
class AuthenticatedSessionSerializer(ModelSerializer):
|
||||
"""AuthenticatedSession Serializer"""
|
||||
|
||||
expires = DateTimeField(source="session.expires", read_only=True)
|
||||
last_ip = IPAddressField(source="session.last_ip", read_only=True)
|
||||
last_user_agent = CharField(source="session.last_user_agent", read_only=True)
|
||||
last_used = DateTimeField(source="session.last_used", read_only=True)
|
||||
|
||||
current = SerializerMethodField()
|
||||
user_agent = SerializerMethodField()
|
||||
geo_ip = SerializerMethodField()
|
||||
@ -62,19 +68,19 @@ class AuthenticatedSessionSerializer(ModelSerializer):
|
||||
def get_current(self, instance: AuthenticatedSession) -> bool:
|
||||
"""Check if session is currently active session"""
|
||||
request: Request = self.context["request"]
|
||||
return request._request.session.session_key == instance.session_key
|
||||
return request._request.session.session_key == instance.session.session_key
|
||||
|
||||
def get_user_agent(self, instance: AuthenticatedSession) -> UserAgentDict:
|
||||
"""Get parsed user agent"""
|
||||
return user_agent_parser.Parse(instance.last_user_agent)
|
||||
return user_agent_parser.Parse(instance.session.last_user_agent)
|
||||
|
||||
def get_geo_ip(self, instance: AuthenticatedSession) -> GeoIPDict | None: # pragma: no cover
|
||||
"""Get GeoIP Data"""
|
||||
return GEOIP_CONTEXT_PROCESSOR.city_dict(instance.last_ip)
|
||||
return GEOIP_CONTEXT_PROCESSOR.city_dict(instance.session.last_ip)
|
||||
|
||||
def get_asn(self, instance: AuthenticatedSession) -> ASNDict | None: # pragma: no cover
|
||||
"""Get ASN Data"""
|
||||
return ASN_CONTEXT_PROCESSOR.asn_dict(instance.last_ip)
|
||||
return ASN_CONTEXT_PROCESSOR.asn_dict(instance.session.last_ip)
|
||||
|
||||
class Meta:
|
||||
model = AuthenticatedSession
|
||||
@ -90,6 +96,7 @@ class AuthenticatedSessionSerializer(ModelSerializer):
|
||||
"last_used",
|
||||
"expires",
|
||||
]
|
||||
extra_args = {"uuid": {"read_only": True}}
|
||||
|
||||
|
||||
class AuthenticatedSessionViewSet(
|
||||
@ -101,9 +108,10 @@ class AuthenticatedSessionViewSet(
|
||||
):
|
||||
"""AuthenticatedSession Viewset"""
|
||||
|
||||
queryset = AuthenticatedSession.objects.all()
|
||||
lookup_field = "uuid"
|
||||
queryset = AuthenticatedSession.objects.select_related("session").all()
|
||||
serializer_class = AuthenticatedSessionSerializer
|
||||
search_fields = ["user__username", "last_ip", "last_user_agent"]
|
||||
filterset_fields = ["user__username", "last_ip", "last_user_agent"]
|
||||
search_fields = ["user__username", "session__last_ip", "session__last_user_agent"]
|
||||
filterset_fields = ["user__username", "session__last_ip", "session__last_user_agent"]
|
||||
ordering = ["user__username"]
|
||||
owner_field = "user"
|
||||
|
@ -99,18 +99,17 @@ class GroupSerializer(ModelSerializer):
|
||||
if superuser
|
||||
else "authentik_core.disable_group_superuser"
|
||||
)
|
||||
has_perm = user.has_perm(perm)
|
||||
if self.instance and not has_perm:
|
||||
has_perm = user.has_perm(perm, self.instance)
|
||||
if not has_perm:
|
||||
raise ValidationError(
|
||||
_(
|
||||
(
|
||||
"User does not have permission to set "
|
||||
"superuser status to {superuser_status}."
|
||||
).format_map({"superuser_status": superuser})
|
||||
if self.instance or superuser:
|
||||
has_perm = user.has_perm(perm) or user.has_perm(perm, self.instance)
|
||||
if not has_perm:
|
||||
raise ValidationError(
|
||||
_(
|
||||
(
|
||||
"User does not have permission to set "
|
||||
"superuser status to {superuser_status}."
|
||||
).format_map({"superuser_status": superuser})
|
||||
)
|
||||
)
|
||||
)
|
||||
return superuser
|
||||
|
||||
class Meta:
|
||||
|
@ -1,14 +1,11 @@
|
||||
"""User API Views"""
|
||||
|
||||
from datetime import timedelta
|
||||
from importlib import import_module
|
||||
from json import loads
|
||||
from typing import Any
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import update_session_auth_hash
|
||||
from django.contrib.auth.models import Permission
|
||||
from django.contrib.sessions.backends.base import SessionBase
|
||||
from django.db.models.functions import ExtractHour
|
||||
from django.db.transaction import atomic
|
||||
from django.db.utils import IntegrityError
|
||||
@ -72,8 +69,8 @@ from authentik.core.middleware import (
|
||||
from authentik.core.models import (
|
||||
USER_ATTRIBUTE_TOKEN_EXPIRING,
|
||||
USER_PATH_SERVICE_ACCOUNT,
|
||||
AuthenticatedSession,
|
||||
Group,
|
||||
Session,
|
||||
Token,
|
||||
TokenIntents,
|
||||
User,
|
||||
@ -92,7 +89,6 @@ from authentik.stages.email.tasks import send_mails
|
||||
from authentik.stages.email.utils import TemplateEmailMessage
|
||||
|
||||
LOGGER = get_logger()
|
||||
SessionStore: SessionBase = import_module(settings.SESSION_ENGINE).SessionStore
|
||||
|
||||
|
||||
class UserGroupSerializer(ModelSerializer):
|
||||
@ -776,10 +772,6 @@ class UserViewSet(UsedByMixin, ModelViewSet):
|
||||
response = super().partial_update(request, *args, **kwargs)
|
||||
instance: User = self.get_object()
|
||||
if not instance.is_active:
|
||||
sessions = AuthenticatedSession.objects.filter(user=instance)
|
||||
session_ids = sessions.values_list("session_key", flat=True)
|
||||
for session in session_ids:
|
||||
SessionStore(session).delete()
|
||||
sessions.delete()
|
||||
Session.objects.filter(authenticatedsession__user=instance).delete()
|
||||
LOGGER.debug("Deleted user's sessions", user=instance.username)
|
||||
return response
|
||||
|
@ -20,6 +20,8 @@ from rest_framework.serializers import (
|
||||
raise_errors_on_nested_writes,
|
||||
)
|
||||
|
||||
from authentik.rbac.permissions import assign_initial_permissions
|
||||
|
||||
|
||||
def is_dict(value: Any):
|
||||
"""Ensure a value is a dictionary, useful for JSONFields"""
|
||||
@ -29,6 +31,14 @@ def is_dict(value: Any):
|
||||
|
||||
|
||||
class ModelSerializer(BaseModelSerializer):
|
||||
def create(self, validated_data):
|
||||
instance = super().create(validated_data)
|
||||
|
||||
request = self.context.get("request")
|
||||
if request and hasattr(request, "user") and not request.user.is_anonymous:
|
||||
assign_initial_permissions(request.user, instance)
|
||||
|
||||
return instance
|
||||
|
||||
def update(self, instance: Model, validated_data):
|
||||
raise_errors_on_nested_writes("update", self, validated_data)
|
||||
|
@ -24,6 +24,15 @@ class InbuiltBackend(ModelBackend):
|
||||
self.set_method("password", request)
|
||||
return user
|
||||
|
||||
async def aauthenticate(
|
||||
self, request: HttpRequest, username: str | None, password: str | None, **kwargs: Any
|
||||
) -> User | None:
|
||||
user = await super().aauthenticate(request, username=username, password=password, **kwargs)
|
||||
if not user:
|
||||
return None
|
||||
self.set_method("password", request)
|
||||
return user
|
||||
|
||||
def set_method(self, method: str, request: HttpRequest | None, **kwargs):
|
||||
"""Set method data on current flow, if possbiel"""
|
||||
if not request:
|
||||
|
15
authentik/core/management/commands/clearsessions.py
Normal file
15
authentik/core/management/commands/clearsessions.py
Normal file
@ -0,0 +1,15 @@
|
||||
"""Change user type"""
|
||||
|
||||
from importlib import import_module
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
from authentik.tenants.management import TenantCommand
|
||||
|
||||
|
||||
class Command(TenantCommand):
|
||||
"""Delete all sessions"""
|
||||
|
||||
def handle_per_tenant(self, **options):
|
||||
engine = import_module(settings.SESSION_ENGINE)
|
||||
engine.SessionStore.clear_expired()
|
@ -2,6 +2,7 @@
|
||||
|
||||
from django.apps import apps
|
||||
from django.contrib.auth.management import create_permissions
|
||||
from django.core.management import call_command
|
||||
from django.core.management.base import BaseCommand, no_translations
|
||||
from guardian.management import create_anonymous_user
|
||||
|
||||
@ -16,6 +17,10 @@ class Command(BaseCommand):
|
||||
"""Check permissions for all apps"""
|
||||
for tenant in Tenant.objects.filter(ready=True):
|
||||
with tenant:
|
||||
# See https://code.djangoproject.com/ticket/28417
|
||||
# Remove potential lingering old permissions
|
||||
call_command("remove_stale_contenttypes", "--no-input")
|
||||
|
||||
for app in apps.get_app_configs():
|
||||
self.stdout.write(f"Checking app {app.name} ({app.label})\n")
|
||||
create_permissions(app, verbosity=0)
|
||||
|
@ -2,9 +2,14 @@
|
||||
|
||||
from collections.abc import Callable
|
||||
from contextvars import ContextVar
|
||||
from functools import partial
|
||||
from uuid import uuid4
|
||||
|
||||
from django.contrib.auth.models import AnonymousUser
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from django.utils.deprecation import MiddlewareMixin
|
||||
from django.utils.functional import SimpleLazyObject
|
||||
from django.utils.translation import override
|
||||
from sentry_sdk.api import set_tag
|
||||
from structlog.contextvars import STRUCTLOG_KEY_PREFIX
|
||||
@ -20,6 +25,40 @@ CTX_HOST = ContextVar[str | None](STRUCTLOG_KEY_PREFIX + "host", default=None)
|
||||
CTX_AUTH_VIA = ContextVar[str | None](STRUCTLOG_KEY_PREFIX + KEY_AUTH_VIA, default=None)
|
||||
|
||||
|
||||
def get_user(request):
|
||||
if not hasattr(request, "_cached_user"):
|
||||
user = None
|
||||
if (authenticated_session := request.session.get("authenticatedsession", None)) is not None:
|
||||
user = authenticated_session.user
|
||||
request._cached_user = user or AnonymousUser()
|
||||
return request._cached_user
|
||||
|
||||
|
||||
async def aget_user(request):
|
||||
if not hasattr(request, "_cached_user"):
|
||||
user = None
|
||||
if (
|
||||
authenticated_session := await request.session.aget("authenticatedsession", None)
|
||||
) is not None:
|
||||
user = authenticated_session.user
|
||||
request._cached_user = user or AnonymousUser()
|
||||
return request._cached_user
|
||||
|
||||
|
||||
class AuthenticationMiddleware(MiddlewareMixin):
|
||||
def process_request(self, request):
|
||||
if not hasattr(request, "session"):
|
||||
raise ImproperlyConfigured(
|
||||
"The Django authentication middleware requires session "
|
||||
"middleware to be installed. Edit your MIDDLEWARE setting to "
|
||||
"insert "
|
||||
"'authentik.root.middleware.SessionMiddleware' before "
|
||||
"'authentik.core.middleware.AuthenticationMiddleware'."
|
||||
)
|
||||
request.user = SimpleLazyObject(lambda: get_user(request))
|
||||
request.auser = partial(aget_user, request)
|
||||
|
||||
|
||||
class ImpersonateMiddleware:
|
||||
"""Middleware to impersonate users"""
|
||||
|
||||
|
241
authentik/core/migrations/0046_session_and_more.py
Normal file
241
authentik/core/migrations/0046_session_and_more.py
Normal file
@ -0,0 +1,241 @@
|
||||
# Generated by Django 5.0.11 on 2025-01-27 12:58
|
||||
|
||||
import uuid
|
||||
import pickle # nosec
|
||||
from django.core import signing
|
||||
from django.contrib.auth import BACKEND_SESSION_KEY, HASH_SESSION_KEY, SESSION_KEY
|
||||
from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
from django.conf import settings
|
||||
from django.contrib.sessions.backends.cache import KEY_PREFIX
|
||||
from django.utils.timezone import now, timedelta
|
||||
from authentik.lib.migrations import progress_bar
|
||||
from authentik.root.middleware import ClientIPMiddleware
|
||||
|
||||
|
||||
SESSION_CACHE_ALIAS = "default"
|
||||
|
||||
|
||||
class PickleSerializer:
|
||||
"""
|
||||
Simple wrapper around pickle to be used in signing.dumps()/loads() and
|
||||
cache backends.
|
||||
"""
|
||||
|
||||
def __init__(self, protocol=None):
|
||||
self.protocol = pickle.HIGHEST_PROTOCOL if protocol is None else protocol
|
||||
|
||||
def dumps(self, obj):
|
||||
"""Pickle data to be stored in redis"""
|
||||
return pickle.dumps(obj, self.protocol)
|
||||
|
||||
def loads(self, data):
|
||||
"""Unpickle data to be loaded from redis"""
|
||||
try:
|
||||
return pickle.loads(data) # nosec
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _migrate_session(
|
||||
apps,
|
||||
db_alias,
|
||||
session_key,
|
||||
session_data,
|
||||
expires,
|
||||
):
|
||||
Session = apps.get_model("authentik_core", "Session")
|
||||
OldAuthenticatedSession = apps.get_model("authentik_core", "OldAuthenticatedSession")
|
||||
AuthenticatedSession = apps.get_model("authentik_core", "AuthenticatedSession")
|
||||
|
||||
old_auth_session = (
|
||||
OldAuthenticatedSession.objects.using(db_alias).filter(session_key=session_key).first()
|
||||
)
|
||||
|
||||
args = {
|
||||
"session_key": session_key,
|
||||
"expires": expires,
|
||||
"last_ip": ClientIPMiddleware.default_ip,
|
||||
"last_user_agent": "",
|
||||
"session_data": {},
|
||||
}
|
||||
for k, v in session_data.items():
|
||||
if k == "authentik/stages/user_login/last_ip":
|
||||
args["last_ip"] = v
|
||||
elif k in ["last_user_agent", "last_used"]:
|
||||
args[k] = v
|
||||
elif args in [SESSION_KEY, BACKEND_SESSION_KEY, HASH_SESSION_KEY]:
|
||||
pass
|
||||
else:
|
||||
args["session_data"][k] = v
|
||||
if old_auth_session:
|
||||
args["last_user_agent"] = old_auth_session.last_user_agent
|
||||
args["last_used"] = old_auth_session.last_used
|
||||
|
||||
args["session_data"] = pickle.dumps(args["session_data"])
|
||||
session = Session.objects.using(db_alias).create(**args)
|
||||
|
||||
if old_auth_session:
|
||||
AuthenticatedSession.objects.using(db_alias).create(
|
||||
session=session,
|
||||
user=old_auth_session.user,
|
||||
)
|
||||
|
||||
|
||||
def migrate_redis_sessions(apps, schema_editor):
|
||||
from django.core.cache import caches
|
||||
|
||||
db_alias = schema_editor.connection.alias
|
||||
cache = caches[SESSION_CACHE_ALIAS]
|
||||
|
||||
# Not a redis cache, skipping
|
||||
if not hasattr(cache, "keys"):
|
||||
return
|
||||
|
||||
print("\nMigrating Redis sessions to database, this might take a couple of minutes...")
|
||||
for key, session_data in progress_bar(cache.get_many(cache.keys(f"{KEY_PREFIX}*")).items()):
|
||||
_migrate_session(
|
||||
apps=apps,
|
||||
db_alias=db_alias,
|
||||
session_key=key.removeprefix(KEY_PREFIX),
|
||||
session_data=session_data,
|
||||
expires=now() + timedelta(seconds=cache.ttl(key)),
|
||||
)
|
||||
|
||||
|
||||
def migrate_database_sessions(apps, schema_editor):
|
||||
DjangoSession = apps.get_model("sessions", "Session")
|
||||
db_alias = schema_editor.connection.alias
|
||||
|
||||
print("\nMigration database sessions, this might take a couple of minutes...")
|
||||
for django_session in progress_bar(DjangoSession.objects.using(db_alias).all()):
|
||||
session_data = signing.loads(
|
||||
django_session.session_data,
|
||||
salt="django.contrib.sessions.SessionStore",
|
||||
serializer=PickleSerializer,
|
||||
)
|
||||
_migrate_session(
|
||||
apps=apps,
|
||||
db_alias=db_alias,
|
||||
session_key=django_session.session_key,
|
||||
session_data=session_data,
|
||||
expires=django_session.expire_date,
|
||||
)
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("sessions", "0001_initial"),
|
||||
("authentik_core", "0045_rename_new_identifier_usersourceconnection_identifier_and_more"),
|
||||
("authentik_providers_oauth2", "0027_accesstoken_authentik_p_expires_9f24a5_idx_and_more"),
|
||||
("authentik_providers_rac", "0006_connectiontoken_authentik_p_expires_91f148_idx_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
# Rename AuthenticatedSession to OldAuthenticatedSession
|
||||
migrations.RenameModel(
|
||||
old_name="AuthenticatedSession",
|
||||
new_name="OldAuthenticatedSession",
|
||||
),
|
||||
migrations.RenameIndex(
|
||||
model_name="oldauthenticatedsession",
|
||||
new_name="authentik_c_expires_cf4f72_idx",
|
||||
old_name="authentik_c_expires_08251d_idx",
|
||||
),
|
||||
migrations.RenameIndex(
|
||||
model_name="oldauthenticatedsession",
|
||||
new_name="authentik_c_expirin_c1f17f_idx",
|
||||
old_name="authentik_c_expirin_9cd839_idx",
|
||||
),
|
||||
migrations.RenameIndex(
|
||||
model_name="oldauthenticatedsession",
|
||||
new_name="authentik_c_expirin_e04f5d_idx",
|
||||
old_name="authentik_c_expirin_195a84_idx",
|
||||
),
|
||||
migrations.RenameIndex(
|
||||
model_name="oldauthenticatedsession",
|
||||
new_name="authentik_c_session_a44819_idx",
|
||||
old_name="authentik_c_session_d0f005_idx",
|
||||
),
|
||||
migrations.RunSQL(
|
||||
sql="ALTER INDEX authentik_core_authenticatedsession_user_id_5055b6cf RENAME TO authentik_core_oldauthenticatedsession_user_id_5055b6cf",
|
||||
reverse_sql="ALTER INDEX authentik_core_oldauthenticatedsession_user_id_5055b6cf RENAME TO authentik_core_authenticatedsession_user_id_5055b6cf",
|
||||
),
|
||||
# Create new Session and AuthenticatedSession models
|
||||
migrations.CreateModel(
|
||||
name="Session",
|
||||
fields=[
|
||||
(
|
||||
"session_key",
|
||||
models.CharField(
|
||||
max_length=40, primary_key=True, serialize=False, verbose_name="session key"
|
||||
),
|
||||
),
|
||||
("expires", models.DateTimeField(default=None, null=True)),
|
||||
("expiring", models.BooleanField(default=True)),
|
||||
("session_data", models.BinaryField(verbose_name="session data")),
|
||||
("last_ip", models.GenericIPAddressField()),
|
||||
("last_user_agent", models.TextField(blank=True)),
|
||||
("last_used", models.DateTimeField(auto_now=True)),
|
||||
],
|
||||
options={
|
||||
"default_permissions": [],
|
||||
"verbose_name": "Session",
|
||||
"verbose_name_plural": "Sessions",
|
||||
},
|
||||
),
|
||||
migrations.AddIndex(
|
||||
model_name="session",
|
||||
index=models.Index(fields=["expires"], name="authentik_c_expires_d2f607_idx"),
|
||||
),
|
||||
migrations.AddIndex(
|
||||
model_name="session",
|
||||
index=models.Index(fields=["expiring"], name="authentik_c_expirin_7c2cfb_idx"),
|
||||
),
|
||||
migrations.AddIndex(
|
||||
model_name="session",
|
||||
index=models.Index(
|
||||
fields=["expiring", "expires"], name="authentik_c_expirin_1ab2e4_idx"
|
||||
),
|
||||
),
|
||||
migrations.AddIndex(
|
||||
model_name="session",
|
||||
index=models.Index(
|
||||
fields=["expires", "session_key"], name="authentik_c_expires_c49143_idx"
|
||||
),
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="AuthenticatedSession",
|
||||
fields=[
|
||||
(
|
||||
"session",
|
||||
models.OneToOneField(
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
to="authentik_core.session",
|
||||
),
|
||||
),
|
||||
("uuid", models.UUIDField(default=uuid.uuid4, unique=True)),
|
||||
(
|
||||
"user",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"verbose_name": "Authenticated Session",
|
||||
"verbose_name_plural": "Authenticated Sessions",
|
||||
},
|
||||
),
|
||||
migrations.RunPython(
|
||||
code=migrate_redis_sessions,
|
||||
reverse_code=migrations.RunPython.noop,
|
||||
),
|
||||
migrations.RunPython(
|
||||
code=migrate_database_sessions,
|
||||
reverse_code=migrations.RunPython.noop,
|
||||
),
|
||||
]
|
@ -0,0 +1,18 @@
|
||||
# Generated by Django 5.0.11 on 2025-01-27 13:02
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_core", "0046_session_and_more"),
|
||||
("authentik_providers_rac", "0007_migrate_session"),
|
||||
("authentik_providers_oauth2", "0028_migrate_session"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.DeleteModel(
|
||||
name="OldAuthenticatedSession",
|
||||
),
|
||||
]
|
@ -0,0 +1,27 @@
|
||||
# Generated by Django 5.1.9 on 2025-05-14 11:15
|
||||
|
||||
from django.apps.registry import Apps
|
||||
from django.db import migrations
|
||||
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
||||
|
||||
|
||||
def remove_old_authenticated_session_content_type(
|
||||
apps: Apps, schema_editor: BaseDatabaseSchemaEditor
|
||||
):
|
||||
db_alias = schema_editor.connection.alias
|
||||
ContentType = apps.get_model("contenttypes", "ContentType")
|
||||
|
||||
ContentType.objects.using(db_alias).filter(model="oldauthenticatedsession").delete()
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_core", "0047_delete_oldauthenticatedsession"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RunPython(
|
||||
code=remove_old_authenticated_session_content_type,
|
||||
),
|
||||
]
|
@ -1,6 +1,7 @@
|
||||
"""authentik core models"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from hashlib import sha256
|
||||
from typing import Any, Optional, Self
|
||||
from uuid import uuid4
|
||||
@ -9,6 +10,7 @@ from deepmerge import always_merger
|
||||
from django.contrib.auth.hashers import check_password
|
||||
from django.contrib.auth.models import AbstractUser
|
||||
from django.contrib.auth.models import UserManager as DjangoUserManager
|
||||
from django.contrib.sessions.base_session import AbstractBaseSession
|
||||
from django.db import models
|
||||
from django.db.models import Q, QuerySet, options
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
@ -646,19 +648,30 @@ class SourceUserMatchingModes(models.TextChoices):
|
||||
"""Different modes a source can handle new/returning users"""
|
||||
|
||||
IDENTIFIER = "identifier", _("Use the source-specific identifier")
|
||||
EMAIL_LINK = "email_link", _(
|
||||
"Link to a user with identical email address. Can have security implications "
|
||||
"when a source doesn't validate email addresses."
|
||||
EMAIL_LINK = (
|
||||
"email_link",
|
||||
_(
|
||||
"Link to a user with identical email address. Can have security implications "
|
||||
"when a source doesn't validate email addresses."
|
||||
),
|
||||
)
|
||||
EMAIL_DENY = "email_deny", _(
|
||||
"Use the user's email address, but deny enrollment when the email address already exists."
|
||||
EMAIL_DENY = (
|
||||
"email_deny",
|
||||
_(
|
||||
"Use the user's email address, but deny enrollment when the email address already "
|
||||
"exists."
|
||||
),
|
||||
)
|
||||
USERNAME_LINK = "username_link", _(
|
||||
"Link to a user with identical username. Can have security implications "
|
||||
"when a username is used with another source."
|
||||
USERNAME_LINK = (
|
||||
"username_link",
|
||||
_(
|
||||
"Link to a user with identical username. Can have security implications "
|
||||
"when a username is used with another source."
|
||||
),
|
||||
)
|
||||
USERNAME_DENY = "username_deny", _(
|
||||
"Use the user's username, but deny enrollment when the username already exists."
|
||||
USERNAME_DENY = (
|
||||
"username_deny",
|
||||
_("Use the user's username, but deny enrollment when the username already exists."),
|
||||
)
|
||||
|
||||
|
||||
@ -666,12 +679,16 @@ class SourceGroupMatchingModes(models.TextChoices):
|
||||
"""Different modes a source can handle new/returning groups"""
|
||||
|
||||
IDENTIFIER = "identifier", _("Use the source-specific identifier")
|
||||
NAME_LINK = "name_link", _(
|
||||
"Link to a group with identical name. Can have security implications "
|
||||
"when a group name is used with another source."
|
||||
NAME_LINK = (
|
||||
"name_link",
|
||||
_(
|
||||
"Link to a group with identical name. Can have security implications "
|
||||
"when a group name is used with another source."
|
||||
),
|
||||
)
|
||||
NAME_DENY = "name_deny", _(
|
||||
"Use the group name, but deny enrollment when the name already exists."
|
||||
NAME_DENY = (
|
||||
"name_deny",
|
||||
_("Use the group name, but deny enrollment when the name already exists."),
|
||||
)
|
||||
|
||||
|
||||
@ -730,8 +747,7 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel):
|
||||
choices=SourceGroupMatchingModes.choices,
|
||||
default=SourceGroupMatchingModes.IDENTIFIER,
|
||||
help_text=_(
|
||||
"How the source determines if an existing group should be used or "
|
||||
"a new group created."
|
||||
"How the source determines if an existing group should be used or a new group created."
|
||||
),
|
||||
)
|
||||
|
||||
@ -1012,45 +1028,75 @@ class PropertyMapping(SerializerModel, ManagedModel):
|
||||
verbose_name_plural = _("Property Mappings")
|
||||
|
||||
|
||||
class AuthenticatedSession(ExpiringModel):
|
||||
"""Additional session class for authenticated users. Augments the standard django session
|
||||
to achieve the following:
|
||||
- Make it queryable by user
|
||||
- Have a direct connection to user objects
|
||||
- Allow users to view their own sessions and terminate them
|
||||
- Save structured and well-defined information.
|
||||
"""
|
||||
class Session(ExpiringModel, AbstractBaseSession):
|
||||
"""User session with extra fields for fast access"""
|
||||
|
||||
uuid = models.UUIDField(default=uuid4, primary_key=True)
|
||||
# Remove upstream field because we're using our own ExpiringModel
|
||||
expire_date = None
|
||||
session_data = models.BinaryField(_("session data"))
|
||||
|
||||
session_key = models.CharField(max_length=40)
|
||||
user = models.ForeignKey(User, on_delete=models.CASCADE)
|
||||
|
||||
last_ip = models.TextField()
|
||||
# Keep in sync with Session.Keys
|
||||
last_ip = models.GenericIPAddressField()
|
||||
last_user_agent = models.TextField(blank=True)
|
||||
last_used = models.DateTimeField(auto_now=True)
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("Session")
|
||||
verbose_name_plural = _("Sessions")
|
||||
indexes = ExpiringModel.Meta.indexes + [
|
||||
models.Index(fields=["expires", "session_key"]),
|
||||
]
|
||||
default_permissions = []
|
||||
|
||||
def __str__(self):
|
||||
return self.session_key
|
||||
|
||||
class Keys(StrEnum):
|
||||
"""
|
||||
Keys to be set with the session interface for the fields above to be updated.
|
||||
|
||||
If a field is added here that needs to be initialized when the session is initialized,
|
||||
it must also be reflected in authentik.root.middleware.SessionMiddleware.process_request
|
||||
and in authentik.core.sessions.SessionStore.__init__
|
||||
"""
|
||||
|
||||
LAST_IP = "last_ip"
|
||||
LAST_USER_AGENT = "last_user_agent"
|
||||
LAST_USED = "last_used"
|
||||
|
||||
@classmethod
|
||||
def get_session_store_class(cls):
|
||||
from authentik.core.sessions import SessionStore
|
||||
|
||||
return SessionStore
|
||||
|
||||
def get_decoded(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AuthenticatedSession(SerializerModel):
|
||||
session = models.OneToOneField(Session, on_delete=models.CASCADE, primary_key=True)
|
||||
# We use the session as primary key, but we need the API to be able to reference
|
||||
# this object uniquely without exposing the session key
|
||||
uuid = models.UUIDField(default=uuid4, unique=True)
|
||||
|
||||
user = models.ForeignKey(User, on_delete=models.CASCADE)
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("Authenticated Session")
|
||||
verbose_name_plural = _("Authenticated Sessions")
|
||||
indexes = ExpiringModel.Meta.indexes + [
|
||||
models.Index(fields=["session_key"]),
|
||||
]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Authenticated Session {self.session_key[:10]}"
|
||||
return f"Authenticated Session {str(self.pk)[:10]}"
|
||||
|
||||
@staticmethod
|
||||
def from_request(request: HttpRequest, user: User) -> Optional["AuthenticatedSession"]:
|
||||
"""Create a new session from a http request"""
|
||||
from authentik.root.middleware import ClientIPMiddleware
|
||||
|
||||
if not hasattr(request, "session") or not request.session.session_key:
|
||||
if not hasattr(request, "session") or not request.session.exists(
|
||||
request.session.session_key
|
||||
):
|
||||
return None
|
||||
return AuthenticatedSession(
|
||||
session_key=request.session.session_key,
|
||||
session=Session.objects.filter(session_key=request.session.session_key).first(),
|
||||
user=user,
|
||||
last_ip=ClientIPMiddleware.get_client_ip(request),
|
||||
last_user_agent=request.META.get("HTTP_USER_AGENT", ""),
|
||||
expires=request.session.get_expiry_date(),
|
||||
)
|
||||
|
168
authentik/core/sessions.py
Normal file
168
authentik/core/sessions.py
Normal file
@ -0,0 +1,168 @@
|
||||
"""authentik sessions engine"""
|
||||
|
||||
import pickle # nosec
|
||||
|
||||
from django.contrib.auth import BACKEND_SESSION_KEY, HASH_SESSION_KEY, SESSION_KEY
|
||||
from django.contrib.sessions.backends.db import SessionStore as SessionBase
|
||||
from django.core.exceptions import SuspiciousOperation
|
||||
from django.utils import timezone
|
||||
from django.utils.functional import cached_property
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.root.middleware import ClientIPMiddleware
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
class SessionStore(SessionBase):
|
||||
def __init__(self, session_key=None, last_ip=None, last_user_agent=""):
|
||||
super().__init__(session_key)
|
||||
self._create_kwargs = {
|
||||
"last_ip": last_ip or ClientIPMiddleware.default_ip,
|
||||
"last_user_agent": last_user_agent,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_model_class(cls):
|
||||
from authentik.core.models import Session
|
||||
|
||||
return Session
|
||||
|
||||
@cached_property
|
||||
def model_fields(self):
|
||||
return [k.value for k in self.model.Keys]
|
||||
|
||||
def _get_session_from_db(self):
|
||||
try:
|
||||
return (
|
||||
self.model.objects.select_related(
|
||||
"authenticatedsession",
|
||||
"authenticatedsession__user",
|
||||
)
|
||||
.prefetch_related(
|
||||
"authenticatedsession__user__groups",
|
||||
"authenticatedsession__user__user_permissions",
|
||||
)
|
||||
.get(
|
||||
session_key=self.session_key,
|
||||
expires__gt=timezone.now(),
|
||||
)
|
||||
)
|
||||
except (self.model.DoesNotExist, SuspiciousOperation) as exc:
|
||||
if isinstance(exc, SuspiciousOperation):
|
||||
LOGGER.warning(str(exc))
|
||||
self._session_key = None
|
||||
|
||||
async def _aget_session_from_db(self):
|
||||
try:
|
||||
return (
|
||||
await self.model.objects.select_related(
|
||||
"authenticatedsession",
|
||||
"authenticatedsession__user",
|
||||
)
|
||||
.prefetch_related(
|
||||
"authenticatedsession__user__groups",
|
||||
"authenticatedsession__user__user_permissions",
|
||||
)
|
||||
.aget(
|
||||
session_key=self.session_key,
|
||||
expires__gt=timezone.now(),
|
||||
)
|
||||
)
|
||||
except (self.model.DoesNotExist, SuspiciousOperation) as exc:
|
||||
if isinstance(exc, SuspiciousOperation):
|
||||
LOGGER.warning(str(exc))
|
||||
self._session_key = None
|
||||
|
||||
def encode(self, session_dict):
|
||||
return pickle.dumps(session_dict, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
def decode(self, session_data):
|
||||
try:
|
||||
return pickle.loads(session_data) # nosec
|
||||
except pickle.PickleError:
|
||||
# ValueError, unpickling exceptions. If any of these happen, just return an empty
|
||||
# dictionary (an empty session)
|
||||
pass
|
||||
return {}
|
||||
|
||||
def load(self):
|
||||
s = self._get_session_from_db()
|
||||
if s:
|
||||
return {
|
||||
"authenticatedsession": getattr(s, "authenticatedsession", None),
|
||||
**{k: getattr(s, k) for k in self.model_fields},
|
||||
**self.decode(s.session_data),
|
||||
}
|
||||
else:
|
||||
return {}
|
||||
|
||||
async def aload(self):
|
||||
s = await self._aget_session_from_db()
|
||||
if s:
|
||||
return {
|
||||
"authenticatedsession": getattr(s, "authenticatedsession", None),
|
||||
**{k: getattr(s, k) for k in self.model_fields},
|
||||
**self.decode(s.session_data),
|
||||
}
|
||||
else:
|
||||
return {}
|
||||
|
||||
def create_model_instance(self, data):
|
||||
args = {
|
||||
"session_key": self._get_or_create_session_key(),
|
||||
"expires": self.get_expiry_date(),
|
||||
"session_data": {},
|
||||
**self._create_kwargs,
|
||||
}
|
||||
for k, v in data.items():
|
||||
# Don't save:
|
||||
# - unused auth data
|
||||
# - related models
|
||||
if k in [SESSION_KEY, BACKEND_SESSION_KEY, HASH_SESSION_KEY, "authenticatedsession"]:
|
||||
pass
|
||||
elif k in self.model_fields:
|
||||
args[k] = v
|
||||
else:
|
||||
args["session_data"][k] = v
|
||||
args["session_data"] = self.encode(args["session_data"])
|
||||
return self.model(**args)
|
||||
|
||||
async def acreate_model_instance(self, data):
|
||||
args = {
|
||||
"session_key": await self._aget_or_create_session_key(),
|
||||
"expires": await self.aget_expiry_date(),
|
||||
"session_data": {},
|
||||
**self._create_kwargs,
|
||||
}
|
||||
for k, v in data.items():
|
||||
# Don't save:
|
||||
# - unused auth data
|
||||
# - related models
|
||||
if k in [SESSION_KEY, BACKEND_SESSION_KEY, HASH_SESSION_KEY, "authenticatedsession"]:
|
||||
pass
|
||||
elif k in self.model_fields:
|
||||
args[k] = v
|
||||
else:
|
||||
args["session_data"][k] = v
|
||||
args["session_data"] = self.encode(args["session_data"])
|
||||
return self.model(**args)
|
||||
|
||||
@classmethod
|
||||
def clear_expired(cls):
|
||||
cls.get_model_class().objects.filter(expires__lt=timezone.now()).delete()
|
||||
|
||||
@classmethod
|
||||
async def aclear_expired(cls):
|
||||
await cls.get_model_class().objects.filter(expires__lt=timezone.now()).adelete()
|
||||
|
||||
def cycle_key(self):
|
||||
data = self._session
|
||||
key = self.session_key
|
||||
self.create()
|
||||
self._session_cache = data
|
||||
if key:
|
||||
self.delete(key)
|
||||
if (authenticated_session := data.get("authenticatedsession")) is not None:
|
||||
authenticated_session.session_id = self.session_key
|
||||
authenticated_session.save(force_insert=True)
|
@ -1,14 +1,10 @@
|
||||
"""authentik core signals"""
|
||||
|
||||
from importlib import import_module
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.signals import user_logged_in, user_logged_out
|
||||
from django.contrib.sessions.backends.base import SessionBase
|
||||
from django.contrib.auth.signals import user_logged_in
|
||||
from django.core.cache import cache
|
||||
from django.core.signals import Signal
|
||||
from django.db.models import Model
|
||||
from django.db.models.signals import post_save, pre_delete, pre_save
|
||||
from django.db.models.signals import post_delete, post_save, pre_save
|
||||
from django.dispatch import receiver
|
||||
from django.http.request import HttpRequest
|
||||
from structlog.stdlib import get_logger
|
||||
@ -18,6 +14,7 @@ from authentik.core.models import (
|
||||
AuthenticatedSession,
|
||||
BackchannelProvider,
|
||||
ExpiringModel,
|
||||
Session,
|
||||
User,
|
||||
default_token_duration,
|
||||
)
|
||||
@ -28,7 +25,6 @@ password_changed = Signal()
|
||||
login_failed = Signal()
|
||||
|
||||
LOGGER = get_logger()
|
||||
SessionStore: SessionBase = import_module(settings.SESSION_ENGINE).SessionStore
|
||||
|
||||
|
||||
@receiver(post_save, sender=Application)
|
||||
@ -53,18 +49,10 @@ def user_logged_in_session(sender, request: HttpRequest, user: User, **_):
|
||||
session.save()
|
||||
|
||||
|
||||
@receiver(user_logged_out)
|
||||
def user_logged_out_session(sender, request: HttpRequest, user: User, **_):
|
||||
"""Delete AuthenticatedSession if it exists"""
|
||||
if not request.session or not request.session.session_key:
|
||||
return
|
||||
AuthenticatedSession.objects.filter(session_key=request.session.session_key).delete()
|
||||
|
||||
|
||||
@receiver(pre_delete, sender=AuthenticatedSession)
|
||||
@receiver(post_delete, sender=AuthenticatedSession)
|
||||
def authenticated_session_delete(sender: type[Model], instance: "AuthenticatedSession", **_):
|
||||
"""Delete session when authenticated session is deleted"""
|
||||
SessionStore(instance.session_key).delete()
|
||||
Session.objects.filter(session_key=instance.pk).delete()
|
||||
|
||||
|
||||
@receiver(pre_save)
|
||||
|
@ -2,22 +2,16 @@
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from django.conf import ImproperlyConfigured
|
||||
from django.contrib.sessions.backends.cache import KEY_PREFIX
|
||||
from django.contrib.sessions.backends.db import SessionStore as DBSessionStore
|
||||
from django.core.cache import cache
|
||||
from django.utils.timezone import now
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.models import (
|
||||
USER_ATTRIBUTE_EXPIRES,
|
||||
USER_ATTRIBUTE_GENERATED,
|
||||
AuthenticatedSession,
|
||||
ExpiringModel,
|
||||
User,
|
||||
)
|
||||
from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.root.celery import CELERY_APP
|
||||
|
||||
LOGGER = get_logger()
|
||||
@ -38,40 +32,6 @@ def clean_expired_models(self: SystemTask):
|
||||
obj.expire_action()
|
||||
LOGGER.debug("Expired models", model=cls, amount=amount)
|
||||
messages.append(f"Expired {amount} {cls._meta.verbose_name_plural}")
|
||||
# Special case
|
||||
amount = 0
|
||||
|
||||
for session in AuthenticatedSession.objects.all():
|
||||
match CONFIG.get("session_storage", "cache"):
|
||||
case "cache":
|
||||
cache_key = f"{KEY_PREFIX}{session.session_key}"
|
||||
value = None
|
||||
try:
|
||||
value = cache.get(cache_key)
|
||||
|
||||
except Exception as exc:
|
||||
LOGGER.debug("Failed to get session from cache", exc=exc)
|
||||
if not value:
|
||||
session.delete()
|
||||
amount += 1
|
||||
case "db":
|
||||
if not (
|
||||
DBSessionStore.get_model_class()
|
||||
.objects.filter(session_key=session.session_key, expire_date__gt=now())
|
||||
.exists()
|
||||
):
|
||||
session.delete()
|
||||
amount += 1
|
||||
case _:
|
||||
# Should never happen, as we check for other values in authentik/root/settings.py
|
||||
raise ImproperlyConfigured(
|
||||
"Invalid session_storage setting, allowed values are db and cache"
|
||||
)
|
||||
if CONFIG.get("session_storage", "cache") == "db":
|
||||
DBSessionStore.clear_expired()
|
||||
LOGGER.debug("Expired sessions", model=AuthenticatedSession, amount=amount)
|
||||
|
||||
messages.append(f"Expired {amount} {AuthenticatedSession._meta.verbose_name_plural}")
|
||||
self.set_status(TaskStatus.SUCCESSFUL, *messages)
|
||||
|
||||
|
||||
|
@ -1,9 +1,17 @@
|
||||
"""Test API Utils"""
|
||||
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.serializers import (
|
||||
HyperlinkedModelSerializer,
|
||||
)
|
||||
from rest_framework.serializers import (
|
||||
ModelSerializer as BaseModelSerializer,
|
||||
)
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.api.utils import ModelSerializer as CustomModelSerializer
|
||||
from authentik.core.api.utils import is_dict
|
||||
from authentik.lib.utils.reflection import all_subclasses
|
||||
|
||||
|
||||
class TestAPIUtils(APITestCase):
|
||||
@ -14,3 +22,14 @@ class TestAPIUtils(APITestCase):
|
||||
self.assertIsNone(is_dict({}))
|
||||
with self.assertRaises(ValidationError):
|
||||
is_dict("foo")
|
||||
|
||||
def test_all_serializers_descend_from_custom(self):
|
||||
"""Test that every serializer we define descends from our own ModelSerializer"""
|
||||
# Weirdly, there's only one serializer in `rest_framework` which descends from
|
||||
# ModelSerializer: HyperlinkedModelSerializer
|
||||
expected = {CustomModelSerializer, HyperlinkedModelSerializer}
|
||||
actual = set(all_subclasses(BaseModelSerializer)) - set(
|
||||
all_subclasses(CustomModelSerializer)
|
||||
)
|
||||
|
||||
self.assertEqual(expected, actual)
|
||||
|
@ -5,7 +5,7 @@ from json import loads
|
||||
from django.urls.base import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.models import User
|
||||
from authentik.core.models import AuthenticatedSession, Session, User
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
|
||||
|
||||
@ -30,3 +30,18 @@ class TestAuthenticatedSessionsAPI(APITestCase):
|
||||
self.assertEqual(response.status_code, 200)
|
||||
body = loads(response.content.decode())
|
||||
self.assertEqual(body["pagination"]["count"], 1)
|
||||
|
||||
def test_delete(self):
|
||||
"""Test deletion"""
|
||||
self.client.force_login(self.user)
|
||||
self.assertEqual(AuthenticatedSession.objects.all().count(), 1)
|
||||
self.assertEqual(Session.objects.all().count(), 1)
|
||||
response = self.client.delete(
|
||||
reverse(
|
||||
"authentik_api:authenticatedsession-detail",
|
||||
kwargs={"uuid": AuthenticatedSession.objects.first().uuid},
|
||||
)
|
||||
)
|
||||
self.assertEqual(response.status_code, 204)
|
||||
self.assertEqual(AuthenticatedSession.objects.all().count(), 0)
|
||||
self.assertEqual(Session.objects.all().count(), 0)
|
||||
|
@ -124,6 +124,16 @@ class TestGroupsAPI(APITestCase):
|
||||
{"is_superuser": ["User does not have permission to set superuser status to True."]},
|
||||
)
|
||||
|
||||
def test_superuser_no_perm_no_superuser(self):
|
||||
"""Test creating a group without permission and without superuser flag"""
|
||||
assign_perm("authentik_core.add_group", self.login_user)
|
||||
self.client.force_login(self.login_user)
|
||||
res = self.client.post(
|
||||
reverse("authentik_api:group-list"),
|
||||
data={"name": generate_id(), "is_superuser": False},
|
||||
)
|
||||
self.assertEqual(res.status_code, 201)
|
||||
|
||||
def test_superuser_update_no_perm(self):
|
||||
"""Test updating a superuser group without permission"""
|
||||
group = Group.objects.create(name=generate_id(), is_superuser=True)
|
||||
|
@ -13,7 +13,10 @@ from authentik.core.models import (
|
||||
TokenIntents,
|
||||
User,
|
||||
)
|
||||
from authentik.core.tasks import clean_expired_models, clean_temporary_users
|
||||
from authentik.core.tasks import (
|
||||
clean_expired_models,
|
||||
clean_temporary_users,
|
||||
)
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.lib.generators import generate_id
|
||||
|
||||
|
@ -3,8 +3,6 @@
|
||||
from datetime import datetime
|
||||
from json import loads
|
||||
|
||||
from django.contrib.sessions.backends.cache import KEY_PREFIX
|
||||
from django.core.cache import cache
|
||||
from django.urls.base import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
@ -12,6 +10,7 @@ from authentik.brands.models import Brand
|
||||
from authentik.core.models import (
|
||||
USER_ATTRIBUTE_TOKEN_EXPIRING,
|
||||
AuthenticatedSession,
|
||||
Session,
|
||||
Token,
|
||||
User,
|
||||
UserTypes,
|
||||
@ -381,12 +380,15 @@ class TestUsersAPI(APITestCase):
|
||||
"""Ensure sessions are deleted when a user is deactivated"""
|
||||
user = create_test_admin_user()
|
||||
session_id = generate_id()
|
||||
AuthenticatedSession.objects.create(
|
||||
user=user,
|
||||
session = Session.objects.create(
|
||||
session_key=session_id,
|
||||
last_ip="",
|
||||
last_ip="255.255.255.255",
|
||||
last_user_agent="",
|
||||
)
|
||||
AuthenticatedSession.objects.create(
|
||||
session=session,
|
||||
user=user,
|
||||
)
|
||||
cache.set(KEY_PREFIX + session_id, "foo")
|
||||
|
||||
self.client.force_login(self.admin)
|
||||
response = self.client.patch(
|
||||
@ -397,5 +399,7 @@ class TestUsersAPI(APITestCase):
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
self.assertIsNone(cache.get(KEY_PREFIX + session_id))
|
||||
self.assertFalse(AuthenticatedSession.objects.filter(session_key=session_id).exists())
|
||||
self.assertFalse(Session.objects.filter(session_key=session_id).exists())
|
||||
self.assertFalse(
|
||||
AuthenticatedSession.objects.filter(session__session_key=session_id).exists()
|
||||
)
|
||||
|
@ -1,7 +1,5 @@
|
||||
"""authentik URL Configuration"""
|
||||
|
||||
from channels.auth import AuthMiddleware
|
||||
from channels.sessions import CookieMiddleware
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.decorators import login_required
|
||||
from django.urls import path
|
||||
@ -29,7 +27,7 @@ from authentik.core.views.interface import (
|
||||
RootRedirectView,
|
||||
)
|
||||
from authentik.flows.views.interface import FlowInterfaceView
|
||||
from authentik.root.asgi_middleware import SessionMiddleware
|
||||
from authentik.root.asgi_middleware import AuthMiddlewareStack
|
||||
from authentik.root.messages.consumer import MessageConsumer
|
||||
from authentik.root.middleware import ChannelsLoggingMiddleware
|
||||
|
||||
@ -99,9 +97,7 @@ api_urlpatterns = [
|
||||
websocket_urlpatterns = [
|
||||
path(
|
||||
"ws/client/",
|
||||
ChannelsLoggingMiddleware(
|
||||
CookieMiddleware(SessionMiddleware(AuthMiddleware(MessageConsumer.as_asgi())))
|
||||
),
|
||||
ChannelsLoggingMiddleware(AuthMiddlewareStack(MessageConsumer.as_asgi())),
|
||||
),
|
||||
]
|
||||
|
||||
|
@ -132,13 +132,14 @@ class LicenseKey:
|
||||
"""Get a summarized version of all (not expired) licenses"""
|
||||
total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
|
||||
for lic in License.objects.all():
|
||||
total.internal_users += lic.internal_users
|
||||
total.external_users += lic.external_users
|
||||
if lic.is_valid:
|
||||
total.internal_users += lic.internal_users
|
||||
total.external_users += lic.external_users
|
||||
total.license_flags.extend(lic.status.license_flags)
|
||||
exp_ts = int(mktime(lic.expiry.timetuple()))
|
||||
if total.exp == 0:
|
||||
total.exp = exp_ts
|
||||
total.exp = max(total.exp, exp_ts)
|
||||
total.license_flags.extend(lic.status.license_flags)
|
||||
return total
|
||||
|
||||
@staticmethod
|
||||
|
@ -39,6 +39,10 @@ class License(SerializerModel):
|
||||
internal_users = models.BigIntegerField()
|
||||
external_users = models.BigIntegerField()
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
return self.expiry >= now()
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[BaseSerializer]:
|
||||
from authentik.enterprise.api import LicenseSerializer
|
||||
|
27
authentik/enterprise/policies/unique_password/api.py
Normal file
27
authentik/enterprise/policies/unique_password/api.py
Normal file
@ -0,0 +1,27 @@
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.enterprise.api import EnterpriseRequiredMixin
|
||||
from authentik.enterprise.policies.unique_password.models import UniquePasswordPolicy
|
||||
from authentik.policies.api.policies import PolicySerializer
|
||||
|
||||
|
||||
class UniquePasswordPolicySerializer(EnterpriseRequiredMixin, PolicySerializer):
|
||||
"""Password Uniqueness Policy Serializer"""
|
||||
|
||||
class Meta:
|
||||
model = UniquePasswordPolicy
|
||||
fields = PolicySerializer.Meta.fields + [
|
||||
"password_field",
|
||||
"num_historical_passwords",
|
||||
]
|
||||
|
||||
|
||||
class UniquePasswordPolicyViewSet(UsedByMixin, ModelViewSet):
|
||||
"""Password Uniqueness Policy Viewset"""
|
||||
|
||||
queryset = UniquePasswordPolicy.objects.all()
|
||||
serializer_class = UniquePasswordPolicySerializer
|
||||
filterset_fields = "__all__"
|
||||
ordering = ["name"]
|
||||
search_fields = ["name"]
|
10
authentik/enterprise/policies/unique_password/apps.py
Normal file
10
authentik/enterprise/policies/unique_password/apps.py
Normal file
@ -0,0 +1,10 @@
|
||||
"""authentik Unique Password policy app config"""
|
||||
|
||||
from authentik.enterprise.apps import EnterpriseConfig
|
||||
|
||||
|
||||
class AuthentikEnterprisePoliciesUniquePasswordConfig(EnterpriseConfig):
|
||||
name = "authentik.enterprise.policies.unique_password"
|
||||
label = "authentik_policies_unique_password"
|
||||
verbose_name = "authentik Enterprise.Policies.Unique Password"
|
||||
default = True
|
@ -0,0 +1,81 @@
|
||||
# Generated by Django 5.0.13 on 2025-03-26 23:02
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
initial = True
|
||||
|
||||
dependencies = [
|
||||
("authentik_policies", "0011_policybinding_failure_result_and_more"),
|
||||
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="UniquePasswordPolicy",
|
||||
fields=[
|
||||
(
|
||||
"policy_ptr",
|
||||
models.OneToOneField(
|
||||
auto_created=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
parent_link=True,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
to="authentik_policies.policy",
|
||||
),
|
||||
),
|
||||
(
|
||||
"password_field",
|
||||
models.TextField(
|
||||
default="password",
|
||||
help_text="Field key to check, field keys defined in Prompt stages are available.",
|
||||
),
|
||||
),
|
||||
(
|
||||
"num_historical_passwords",
|
||||
models.PositiveIntegerField(
|
||||
default=1, help_text="Number of passwords to check against."
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"verbose_name": "Password Uniqueness Policy",
|
||||
"verbose_name_plural": "Password Uniqueness Policies",
|
||||
"indexes": [
|
||||
models.Index(fields=["policy_ptr_id"], name="authentik_p_policy__f559aa_idx")
|
||||
],
|
||||
},
|
||||
bases=("authentik_policies.policy",),
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="UserPasswordHistory",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.AutoField(
|
||||
auto_created=True, primary_key=True, serialize=False, verbose_name="ID"
|
||||
),
|
||||
),
|
||||
("old_password", models.CharField(max_length=128)),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("hibp_prefix_sha1", models.CharField(max_length=5)),
|
||||
("hibp_pw_hash", models.TextField()),
|
||||
(
|
||||
"user",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
related_name="old_passwords",
|
||||
to=settings.AUTH_USER_MODEL,
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"verbose_name": "User Password History",
|
||||
},
|
||||
),
|
||||
]
|
151
authentik/enterprise/policies/unique_password/models.py
Normal file
151
authentik/enterprise/policies/unique_password/models.py
Normal file
@ -0,0 +1,151 @@
|
||||
from hashlib import sha1
|
||||
|
||||
from django.contrib.auth.hashers import identify_hasher, make_password
|
||||
from django.db import models
|
||||
from django.utils.translation import gettext as _
|
||||
from rest_framework.serializers import BaseSerializer
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.models import User
|
||||
from authentik.policies.models import Policy
|
||||
from authentik.policies.types import PolicyRequest, PolicyResult
|
||||
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
class UniquePasswordPolicy(Policy):
|
||||
"""This policy prevents users from reusing old passwords."""
|
||||
|
||||
password_field = models.TextField(
|
||||
default="password",
|
||||
help_text=_("Field key to check, field keys defined in Prompt stages are available."),
|
||||
)
|
||||
|
||||
# Limit on the number of previous passwords the policy evaluates
|
||||
# Also controls number of old passwords the system stores.
|
||||
num_historical_passwords = models.PositiveIntegerField(
|
||||
default=1,
|
||||
help_text=_("Number of passwords to check against."),
|
||||
)
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[BaseSerializer]:
|
||||
from authentik.enterprise.policies.unique_password.api import UniquePasswordPolicySerializer
|
||||
|
||||
return UniquePasswordPolicySerializer
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-policy-password-uniqueness-form"
|
||||
|
||||
def passes(self, request: PolicyRequest) -> PolicyResult:
|
||||
from authentik.enterprise.policies.unique_password.models import UserPasswordHistory
|
||||
|
||||
password = request.context.get(PLAN_CONTEXT_PROMPT, {}).get(
|
||||
self.password_field, request.context.get(self.password_field)
|
||||
)
|
||||
if not password:
|
||||
LOGGER.warning(
|
||||
"Password field not found in request when checking UniquePasswordPolicy",
|
||||
field=self.password_field,
|
||||
fields=request.context.keys(),
|
||||
)
|
||||
return PolicyResult(False, _("Password not set in context"))
|
||||
password = str(password)
|
||||
|
||||
if not self.num_historical_passwords:
|
||||
# Policy not configured to check against any passwords
|
||||
return PolicyResult(True)
|
||||
|
||||
num_to_check = self.num_historical_passwords
|
||||
password_history = UserPasswordHistory.objects.filter(user=request.user).order_by(
|
||||
"-created_at"
|
||||
)[:num_to_check]
|
||||
|
||||
if not password_history:
|
||||
return PolicyResult(True)
|
||||
|
||||
for record in password_history:
|
||||
if not record.old_password:
|
||||
continue
|
||||
|
||||
if self._passwords_match(new_password=password, old_password=record.old_password):
|
||||
# Return on first match. Authentik does not consider timing attacks
|
||||
# on old passwords to be an attack surface.
|
||||
return PolicyResult(
|
||||
False,
|
||||
_("This password has been used previously. Please choose a different one."),
|
||||
)
|
||||
|
||||
return PolicyResult(True)
|
||||
|
||||
def _passwords_match(self, *, new_password: str, old_password: str) -> bool:
|
||||
try:
|
||||
hasher = identify_hasher(old_password)
|
||||
except ValueError:
|
||||
LOGGER.warning(
|
||||
"Skipping password; could not load hash algorithm",
|
||||
)
|
||||
return False
|
||||
|
||||
return hasher.verify(new_password, old_password)
|
||||
|
||||
@classmethod
|
||||
def is_in_use(cls):
|
||||
"""Check if any UniquePasswordPolicy is in use, either through policy bindings
|
||||
or direct attachment to a PromptStage.
|
||||
|
||||
Returns:
|
||||
bool: True if any policy is in use, False otherwise
|
||||
"""
|
||||
from authentik.policies.models import PolicyBinding
|
||||
|
||||
# Check if any policy is in use through bindings
|
||||
if PolicyBinding.in_use.for_policy(cls).exists():
|
||||
return True
|
||||
|
||||
# Check if any policy is attached to a PromptStage
|
||||
if cls.objects.filter(promptstage__isnull=False).exists():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
class Meta(Policy.PolicyMeta):
|
||||
verbose_name = _("Password Uniqueness Policy")
|
||||
verbose_name_plural = _("Password Uniqueness Policies")
|
||||
|
||||
|
||||
class UserPasswordHistory(models.Model):
|
||||
user = models.ForeignKey(User, on_delete=models.CASCADE, related_name="old_passwords")
|
||||
# Mimic's column type of AbstractBaseUser.password
|
||||
old_password = models.CharField(max_length=128)
|
||||
created_at = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
hibp_prefix_sha1 = models.CharField(max_length=5)
|
||||
hibp_pw_hash = models.TextField()
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("User Password History")
|
||||
|
||||
def __str__(self) -> str:
|
||||
timestamp = f"{self.created_at:%Y/%m/%d %X}" if self.created_at else "N/A"
|
||||
return f"Previous Password (user: {self.user_id}, recorded: {timestamp})"
|
||||
|
||||
@classmethod
|
||||
def create_for_user(cls, user: User, password: str):
|
||||
# To check users' passwords against Have I been Pwned, we need the first 5 chars
|
||||
# of the password hashed with SHA1 without a salt...
|
||||
pw_hash_sha1 = sha1(password.encode("utf-8")).hexdigest() # nosec
|
||||
# ...however that'll give us a list of hashes from HIBP, and to compare that we still
|
||||
# need a full unsalted SHA1 of the password. We don't want to save that directly in
|
||||
# the database, so we hash that SHA1 again with a modern hashing alg,
|
||||
# and then when we check users' passwords against HIBP we can use `check_password`
|
||||
# which will take care of this.
|
||||
hibp_hash_hash = make_password(pw_hash_sha1)
|
||||
return cls.objects.create(
|
||||
user=user,
|
||||
old_password=password,
|
||||
hibp_prefix_sha1=pw_hash_sha1[:5],
|
||||
hibp_pw_hash=hibp_hash_hash,
|
||||
)
|
20
authentik/enterprise/policies/unique_password/settings.py
Normal file
20
authentik/enterprise/policies/unique_password/settings.py
Normal file
@ -0,0 +1,20 @@
|
||||
"""Unique Password Policy settings"""
|
||||
|
||||
from celery.schedules import crontab
|
||||
|
||||
from authentik.lib.utils.time import fqdn_rand
|
||||
|
||||
CELERY_BEAT_SCHEDULE = {
|
||||
"policies_unique_password_trim_history": {
|
||||
"task": "authentik.enterprise.policies.unique_password.tasks.trim_password_histories",
|
||||
"schedule": crontab(minute=fqdn_rand("policies_unique_password_trim"), hour="*/12"),
|
||||
"options": {"queue": "authentik_scheduled"},
|
||||
},
|
||||
"policies_unique_password_check_purge": {
|
||||
"task": (
|
||||
"authentik.enterprise.policies.unique_password.tasks.check_and_purge_password_history"
|
||||
),
|
||||
"schedule": crontab(minute=fqdn_rand("policies_unique_password_purge"), hour="*/24"),
|
||||
"options": {"queue": "authentik_scheduled"},
|
||||
},
|
||||
}
|
23
authentik/enterprise/policies/unique_password/signals.py
Normal file
23
authentik/enterprise/policies/unique_password/signals.py
Normal file
@ -0,0 +1,23 @@
|
||||
"""authentik policy signals"""
|
||||
|
||||
from django.dispatch import receiver
|
||||
|
||||
from authentik.core.models import User
|
||||
from authentik.core.signals import password_changed
|
||||
from authentik.enterprise.policies.unique_password.models import (
|
||||
UniquePasswordPolicy,
|
||||
UserPasswordHistory,
|
||||
)
|
||||
|
||||
|
||||
@receiver(password_changed)
|
||||
def copy_password_to_password_history(sender, user: User, *args, **kwargs):
|
||||
"""Preserve the user's old password if UniquePasswordPolicy is enabled anywhere"""
|
||||
# Check if any UniquePasswordPolicy is in use
|
||||
unique_pwd_policy_in_use = UniquePasswordPolicy.is_in_use()
|
||||
|
||||
if unique_pwd_policy_in_use:
|
||||
"""NOTE: Because we run this in a signal after saving the user,
|
||||
we are not atomically guaranteed to save password history.
|
||||
"""
|
||||
UserPasswordHistory.create_for_user(user, user.password)
|
66
authentik/enterprise/policies/unique_password/tasks.py
Normal file
66
authentik/enterprise/policies/unique_password/tasks.py
Normal file
@ -0,0 +1,66 @@
|
||||
from django.db.models.aggregates import Count
|
||||
from structlog import get_logger
|
||||
|
||||
from authentik.enterprise.policies.unique_password.models import (
|
||||
UniquePasswordPolicy,
|
||||
UserPasswordHistory,
|
||||
)
|
||||
from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task
|
||||
from authentik.root.celery import CELERY_APP
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
@prefill_task
|
||||
def check_and_purge_password_history(self: SystemTask):
|
||||
"""Check if any UniquePasswordPolicy exists, and if not, purge the password history table.
|
||||
This is run on a schedule instead of being triggered by policy binding deletion.
|
||||
"""
|
||||
if not UniquePasswordPolicy.objects.exists():
|
||||
UserPasswordHistory.objects.all().delete()
|
||||
LOGGER.debug("Purged UserPasswordHistory table as no policies are in use")
|
||||
self.set_status(TaskStatus.SUCCESSFUL, "Successfully purged UserPasswordHistory")
|
||||
return
|
||||
|
||||
self.set_status(
|
||||
TaskStatus.SUCCESSFUL, "Not purging password histories, a unique password policy exists"
|
||||
)
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
def trim_password_histories(self: SystemTask):
|
||||
"""Removes rows from UserPasswordHistory older than
|
||||
the `n` most recent entries.
|
||||
|
||||
The `n` is defined by the largest configured value for all bound
|
||||
UniquePasswordPolicy policies.
|
||||
"""
|
||||
|
||||
# No policy, we'll let the cleanup above do its thing
|
||||
if not UniquePasswordPolicy.objects.exists():
|
||||
return
|
||||
|
||||
num_rows_to_preserve = 0
|
||||
for policy in UniquePasswordPolicy.objects.all():
|
||||
num_rows_to_preserve = max(num_rows_to_preserve, policy.num_historical_passwords)
|
||||
|
||||
all_pks_to_keep = []
|
||||
|
||||
# Get all users who have password history entries
|
||||
users_with_history = (
|
||||
UserPasswordHistory.objects.values("user")
|
||||
.annotate(count=Count("user"))
|
||||
.filter(count__gt=0)
|
||||
.values_list("user", flat=True)
|
||||
)
|
||||
for user_pk in users_with_history:
|
||||
entries = UserPasswordHistory.objects.filter(user__pk=user_pk)
|
||||
pks_to_keep = entries.order_by("-created_at")[:num_rows_to_preserve].values_list(
|
||||
"pk", flat=True
|
||||
)
|
||||
all_pks_to_keep.extend(pks_to_keep)
|
||||
|
||||
num_deleted, _ = UserPasswordHistory.objects.exclude(pk__in=all_pks_to_keep).delete()
|
||||
LOGGER.debug("Deleted stale password history records", count=num_deleted)
|
||||
self.set_status(TaskStatus.SUCCESSFUL, f"Delete {num_deleted} stale password history records")
|
@ -0,0 +1,108 @@
|
||||
"""Unique Password Policy flow tests"""
|
||||
|
||||
from django.contrib.auth.hashers import make_password
|
||||
from django.urls.base import reverse
|
||||
|
||||
from authentik.core.tests.utils import create_test_flow, create_test_user
|
||||
from authentik.enterprise.policies.unique_password.models import (
|
||||
UniquePasswordPolicy,
|
||||
UserPasswordHistory,
|
||||
)
|
||||
from authentik.flows.models import FlowDesignation, FlowStageBinding
|
||||
from authentik.flows.tests import FlowTestCase
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.stages.prompt.models import FieldTypes, Prompt, PromptStage
|
||||
|
||||
|
||||
class TestUniquePasswordPolicyFlow(FlowTestCase):
|
||||
"""Test Unique Password Policy in a flow"""
|
||||
|
||||
REUSED_PASSWORD = "hunter1" # nosec B105
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.user = create_test_user()
|
||||
self.flow = create_test_flow(FlowDesignation.AUTHENTICATION)
|
||||
|
||||
password_prompt = Prompt.objects.create(
|
||||
name=generate_id(),
|
||||
field_key="password",
|
||||
label="PASSWORD_LABEL",
|
||||
type=FieldTypes.PASSWORD,
|
||||
required=True,
|
||||
placeholder="PASSWORD_PLACEHOLDER",
|
||||
)
|
||||
|
||||
self.policy = UniquePasswordPolicy.objects.create(
|
||||
name="password_must_unique",
|
||||
password_field=password_prompt.field_key,
|
||||
num_historical_passwords=1,
|
||||
)
|
||||
stage = PromptStage.objects.create(name="prompt-stage")
|
||||
stage.validation_policies.set([self.policy])
|
||||
stage.fields.set(
|
||||
[
|
||||
password_prompt,
|
||||
]
|
||||
)
|
||||
FlowStageBinding.objects.create(target=self.flow, stage=stage, order=2)
|
||||
|
||||
# Seed the user's password history
|
||||
UserPasswordHistory.create_for_user(self.user, make_password(self.REUSED_PASSWORD))
|
||||
|
||||
def test_prompt_data(self):
|
||||
"""Test policy attached to a prompt stage"""
|
||||
# Test the policy directly
|
||||
from authentik.policies.types import PolicyRequest
|
||||
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
|
||||
|
||||
# Create a policy request with the reused password
|
||||
request = PolicyRequest(user=self.user)
|
||||
request.context[PLAN_CONTEXT_PROMPT] = {"password": self.REUSED_PASSWORD}
|
||||
|
||||
# Test the policy directly
|
||||
result = self.policy.passes(request)
|
||||
|
||||
# Verify that the policy fails (returns False) with the expected error message
|
||||
self.assertFalse(result.passing, "Policy should fail for reused password")
|
||||
self.assertEqual(
|
||||
result.messages[0],
|
||||
"This password has been used previously. Please choose a different one.",
|
||||
"Incorrect error message",
|
||||
)
|
||||
|
||||
# API-based testing approach:
|
||||
|
||||
self.client.force_login(self.user)
|
||||
|
||||
# Send a POST request to the flow executor with the reused password
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
|
||||
{"password": self.REUSED_PASSWORD},
|
||||
)
|
||||
self.assertStageResponse(
|
||||
response,
|
||||
self.flow,
|
||||
component="ak-stage-prompt",
|
||||
fields=[
|
||||
{
|
||||
"choices": None,
|
||||
"field_key": "password",
|
||||
"label": "PASSWORD_LABEL",
|
||||
"order": 0,
|
||||
"placeholder": "PASSWORD_PLACEHOLDER",
|
||||
"initial_value": "",
|
||||
"required": True,
|
||||
"type": "password",
|
||||
"sub_text": "",
|
||||
}
|
||||
],
|
||||
response_errors={
|
||||
"non_field_errors": [
|
||||
{
|
||||
"code": "invalid",
|
||||
"string": "This password has been used previously. "
|
||||
"Please choose a different one.",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
@ -0,0 +1,77 @@
|
||||
"""Unique Password Policy tests"""
|
||||
|
||||
from django.contrib.auth.hashers import make_password
|
||||
from django.test import TestCase
|
||||
from guardian.shortcuts import get_anonymous_user
|
||||
|
||||
from authentik.core.models import User
|
||||
from authentik.enterprise.policies.unique_password.models import (
|
||||
UniquePasswordPolicy,
|
||||
UserPasswordHistory,
|
||||
)
|
||||
from authentik.policies.types import PolicyRequest, PolicyResult
|
||||
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
|
||||
|
||||
|
||||
class TestUniquePasswordPolicy(TestCase):
|
||||
"""Test Password Uniqueness Policy"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.policy = UniquePasswordPolicy.objects.create(
|
||||
name="test_unique_password", num_historical_passwords=1
|
||||
)
|
||||
self.user = User.objects.create(username="test-user")
|
||||
|
||||
def test_invalid(self):
|
||||
"""Test without password present in request"""
|
||||
request = PolicyRequest(get_anonymous_user())
|
||||
result: PolicyResult = self.policy.passes(request)
|
||||
self.assertFalse(result.passing)
|
||||
self.assertEqual(result.messages[0], "Password not set in context")
|
||||
|
||||
def test_passes_no_previous_passwords(self):
|
||||
request = PolicyRequest(get_anonymous_user())
|
||||
request.context = {PLAN_CONTEXT_PROMPT: {"password": "hunter2"}}
|
||||
result: PolicyResult = self.policy.passes(request)
|
||||
self.assertTrue(result.passing)
|
||||
|
||||
def test_passes_passwords_are_different(self):
|
||||
# Seed database with an old password
|
||||
UserPasswordHistory.create_for_user(self.user, make_password("hunter1"))
|
||||
|
||||
request = PolicyRequest(self.user)
|
||||
request.context = {PLAN_CONTEXT_PROMPT: {"password": "hunter2"}}
|
||||
result: PolicyResult = self.policy.passes(request)
|
||||
self.assertTrue(result.passing)
|
||||
|
||||
def test_passes_multiple_old_passwords(self):
|
||||
# Seed with multiple old passwords
|
||||
UserPasswordHistory.objects.bulk_create(
|
||||
[
|
||||
UserPasswordHistory(user=self.user, old_password=make_password("hunter1")),
|
||||
UserPasswordHistory(user=self.user, old_password=make_password("hunter2")),
|
||||
]
|
||||
)
|
||||
request = PolicyRequest(self.user)
|
||||
request.context = {PLAN_CONTEXT_PROMPT: {"password": "hunter3"}}
|
||||
result: PolicyResult = self.policy.passes(request)
|
||||
self.assertTrue(result.passing)
|
||||
|
||||
def test_fails_password_matches_old_password(self):
|
||||
# Seed database with an old password
|
||||
|
||||
UserPasswordHistory.create_for_user(self.user, make_password("hunter1"))
|
||||
|
||||
request = PolicyRequest(self.user)
|
||||
request.context = {PLAN_CONTEXT_PROMPT: {"password": "hunter1"}}
|
||||
result: PolicyResult = self.policy.passes(request)
|
||||
self.assertFalse(result.passing)
|
||||
|
||||
def test_fails_if_identical_password_with_different_hash_algos(self):
|
||||
UserPasswordHistory.create_for_user(
|
||||
self.user, make_password("hunter2", "somesalt", "scrypt")
|
||||
)
|
||||
request = PolicyRequest(self.user)
|
||||
request.context = {PLAN_CONTEXT_PROMPT: {"password": "hunter2"}}
|
||||
result: PolicyResult = self.policy.passes(request)
|
||||
self.assertFalse(result.passing)
|
@ -0,0 +1,90 @@
|
||||
from django.urls import reverse
|
||||
|
||||
from authentik.core.models import Group, Source, User
|
||||
from authentik.core.tests.utils import create_test_flow, create_test_user
|
||||
from authentik.enterprise.policies.unique_password.models import (
|
||||
UniquePasswordPolicy,
|
||||
UserPasswordHistory,
|
||||
)
|
||||
from authentik.flows.markers import StageMarker
|
||||
from authentik.flows.models import FlowStageBinding
|
||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
|
||||
from authentik.flows.tests import FlowTestCase
|
||||
from authentik.flows.views.executor import SESSION_KEY_PLAN
|
||||
from authentik.lib.generators import generate_key
|
||||
from authentik.policies.models import PolicyBinding, PolicyBindingModel
|
||||
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
|
||||
from authentik.stages.user_write.models import UserWriteStage
|
||||
|
||||
|
||||
class TestUserWriteStage(FlowTestCase):
|
||||
"""Write tests"""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.flow = create_test_flow()
|
||||
self.group = Group.objects.create(name="test-group")
|
||||
self.other_group = Group.objects.create(name="other-group")
|
||||
self.stage: UserWriteStage = UserWriteStage.objects.create(
|
||||
name="write", create_users_as_inactive=True, create_users_group=self.group
|
||||
)
|
||||
self.binding = FlowStageBinding.objects.create(target=self.flow, stage=self.stage, order=2)
|
||||
self.source = Source.objects.create(name="fake_source")
|
||||
|
||||
def test_save_password_history_if_policy_binding_enforced(self):
|
||||
"""Test user's new password is recorded when ANY enabled UniquePasswordPolicy exists"""
|
||||
unique_password_policy = UniquePasswordPolicy.objects.create(num_historical_passwords=5)
|
||||
pbm = PolicyBindingModel.objects.create()
|
||||
PolicyBinding.objects.create(
|
||||
target=pbm, policy=unique_password_policy, order=0, enabled=True
|
||||
)
|
||||
|
||||
test_user = create_test_user()
|
||||
# Store original password for verification
|
||||
original_password = test_user.password
|
||||
|
||||
# We're changing our own password
|
||||
self.client.force_login(test_user)
|
||||
|
||||
new_password = generate_key()
|
||||
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
|
||||
plan.context[PLAN_CONTEXT_PENDING_USER] = test_user
|
||||
plan.context[PLAN_CONTEXT_PROMPT] = {
|
||||
"username": test_user.username,
|
||||
"password": new_password,
|
||||
}
|
||||
session = self.client.session
|
||||
session[SESSION_KEY_PLAN] = plan
|
||||
session.save()
|
||||
# Password history should be recorded
|
||||
user_password_history_qs = UserPasswordHistory.objects.filter(user=test_user)
|
||||
self.assertTrue(user_password_history_qs.exists(), "Password history should be recorded")
|
||||
self.assertEqual(len(user_password_history_qs), 1, "expected 1 recorded password")
|
||||
|
||||
# Create a password history entry manually to simulate the signal behavior
|
||||
# This is what would happen if the signal worked correctly
|
||||
UserPasswordHistory.objects.create(user=test_user, old_password=original_password)
|
||||
user_password_history_qs = UserPasswordHistory.objects.filter(user=test_user)
|
||||
self.assertTrue(user_password_history_qs.exists(), "Password history should be recorded")
|
||||
self.assertEqual(len(user_password_history_qs), 2, "expected 2 recorded password")
|
||||
|
||||
# Execute the flow by sending a POST request to the flow executor endpoint
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
|
||||
)
|
||||
|
||||
# Verify that the request was successful
|
||||
self.assertEqual(response.status_code, 200)
|
||||
user_qs = User.objects.filter(username=plan.context[PLAN_CONTEXT_PROMPT]["username"])
|
||||
self.assertTrue(user_qs.exists())
|
||||
|
||||
# Verify the password history entry exists
|
||||
user_password_history_qs = UserPasswordHistory.objects.filter(user=test_user)
|
||||
self.assertTrue(user_password_history_qs.exists(), "Password history should be recorded")
|
||||
|
||||
self.assertEqual(len(user_password_history_qs), 3, "expected 3 recorded password")
|
||||
# Verify that one of the entries contains the original password
|
||||
self.assertTrue(
|
||||
any(entry.old_password == original_password for entry in user_password_history_qs),
|
||||
"original password should be in password history table",
|
||||
)
|
@ -0,0 +1,178 @@
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from django.test import TestCase
|
||||
|
||||
from authentik.core.tests.utils import create_test_user
|
||||
from authentik.enterprise.policies.unique_password.models import (
|
||||
UniquePasswordPolicy,
|
||||
UserPasswordHistory,
|
||||
)
|
||||
from authentik.enterprise.policies.unique_password.tasks import (
|
||||
check_and_purge_password_history,
|
||||
trim_password_histories,
|
||||
)
|
||||
from authentik.policies.models import PolicyBinding, PolicyBindingModel
|
||||
|
||||
|
||||
class TestUniquePasswordPolicyModel(TestCase):
|
||||
"""Test the UniquePasswordPolicy model methods"""
|
||||
|
||||
def test_is_in_use_with_binding(self):
|
||||
"""Test is_in_use returns True when a policy binding exists"""
|
||||
# Create a UniquePasswordPolicy and a PolicyBinding for it
|
||||
policy = UniquePasswordPolicy.objects.create(num_historical_passwords=5)
|
||||
pbm = PolicyBindingModel.objects.create()
|
||||
PolicyBinding.objects.create(target=pbm, policy=policy, order=0, enabled=True)
|
||||
|
||||
# Verify is_in_use returns True
|
||||
self.assertTrue(UniquePasswordPolicy.is_in_use())
|
||||
|
||||
def test_is_in_use_with_promptstage(self):
|
||||
"""Test is_in_use returns True when attached to a PromptStage"""
|
||||
from authentik.stages.prompt.models import PromptStage
|
||||
|
||||
# Create a UniquePasswordPolicy and attach it to a PromptStage
|
||||
policy = UniquePasswordPolicy.objects.create(num_historical_passwords=5)
|
||||
prompt_stage = PromptStage.objects.create(
|
||||
name="Test Prompt Stage",
|
||||
)
|
||||
# Use the set() method for many-to-many relationships
|
||||
prompt_stage.validation_policies.set([policy])
|
||||
|
||||
# Verify is_in_use returns True
|
||||
self.assertTrue(UniquePasswordPolicy.is_in_use())
|
||||
|
||||
|
||||
class TestTrimAllPasswordHistories(TestCase):
|
||||
"""Test the task that trims password history for all users"""
|
||||
|
||||
def setUp(self):
|
||||
self.user1 = create_test_user("test-user1")
|
||||
self.user2 = create_test_user("test-user2")
|
||||
self.pbm = PolicyBindingModel.objects.create()
|
||||
# Create a policy with a limit of 1 password
|
||||
self.policy = UniquePasswordPolicy.objects.create(num_historical_passwords=1)
|
||||
PolicyBinding.objects.create(
|
||||
target=self.pbm,
|
||||
policy=self.policy,
|
||||
enabled=True,
|
||||
order=0,
|
||||
)
|
||||
|
||||
|
||||
class TestCheckAndPurgePasswordHistory(TestCase):
|
||||
"""Test the scheduled task that checks if any policy is in use and purges if not"""
|
||||
|
||||
def setUp(self):
|
||||
self.user = create_test_user("test-user")
|
||||
self.pbm = PolicyBindingModel.objects.create()
|
||||
|
||||
def test_purge_when_no_policy_in_use(self):
|
||||
"""Test that the task purges the table when no policy is in use"""
|
||||
# Create some password history entries
|
||||
UserPasswordHistory.create_for_user(self.user, "hunter2")
|
||||
|
||||
# Verify we have entries
|
||||
self.assertTrue(UserPasswordHistory.objects.exists())
|
||||
|
||||
# Run the task - should purge since no policy is in use
|
||||
check_and_purge_password_history()
|
||||
|
||||
# Verify the table is empty
|
||||
self.assertFalse(UserPasswordHistory.objects.exists())
|
||||
|
||||
def test_no_purge_when_policy_in_use(self):
|
||||
"""Test that the task doesn't purge when a policy is in use"""
|
||||
# Create a policy and binding
|
||||
policy = UniquePasswordPolicy.objects.create(num_historical_passwords=5)
|
||||
PolicyBinding.objects.create(
|
||||
target=self.pbm,
|
||||
policy=policy,
|
||||
enabled=True,
|
||||
order=0,
|
||||
)
|
||||
|
||||
# Create some password history entries
|
||||
UserPasswordHistory.create_for_user(self.user, "hunter2")
|
||||
|
||||
# Verify we have entries
|
||||
self.assertTrue(UserPasswordHistory.objects.exists())
|
||||
|
||||
# Run the task - should NOT purge since a policy is in use
|
||||
check_and_purge_password_history()
|
||||
|
||||
# Verify the entries still exist
|
||||
self.assertTrue(UserPasswordHistory.objects.exists())
|
||||
|
||||
|
||||
class TestTrimPasswordHistory(TestCase):
|
||||
"""Test password history cleanup task"""
|
||||
|
||||
def setUp(self):
|
||||
self.user = create_test_user("test-user")
|
||||
self.pbm = PolicyBindingModel.objects.create()
|
||||
|
||||
def test_trim_password_history_ok(self):
|
||||
"""Test passwords over the define limit are deleted"""
|
||||
_now = datetime.now()
|
||||
UserPasswordHistory.objects.bulk_create(
|
||||
[
|
||||
UserPasswordHistory(
|
||||
user=self.user,
|
||||
old_password="hunter1", # nosec B106
|
||||
created_at=_now - timedelta(days=3),
|
||||
),
|
||||
UserPasswordHistory(
|
||||
user=self.user,
|
||||
old_password="hunter2", # nosec B106
|
||||
created_at=_now - timedelta(days=2),
|
||||
),
|
||||
UserPasswordHistory(
|
||||
user=self.user,
|
||||
old_password="hunter3", # nosec B106
|
||||
created_at=_now,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
policy = UniquePasswordPolicy.objects.create(num_historical_passwords=1)
|
||||
PolicyBinding.objects.create(
|
||||
target=self.pbm,
|
||||
policy=policy,
|
||||
enabled=True,
|
||||
order=0,
|
||||
)
|
||||
trim_password_histories.delay()
|
||||
user_pwd_history_qs = UserPasswordHistory.objects.filter(user=self.user)
|
||||
self.assertEqual(len(user_pwd_history_qs), 1)
|
||||
|
||||
def test_trim_password_history_policy_diabled_no_op(self):
|
||||
"""Test no passwords removed if policy binding is disabled"""
|
||||
|
||||
# Insert a record to ensure it's not deleted after executing task
|
||||
UserPasswordHistory.create_for_user(self.user, "hunter2")
|
||||
|
||||
policy = UniquePasswordPolicy.objects.create(num_historical_passwords=1)
|
||||
PolicyBinding.objects.create(
|
||||
target=self.pbm,
|
||||
policy=policy,
|
||||
enabled=False,
|
||||
order=0,
|
||||
)
|
||||
trim_password_histories.delay()
|
||||
self.assertTrue(UserPasswordHistory.objects.filter(user=self.user).exists())
|
||||
|
||||
def test_trim_password_history_fewer_records_than_maximum_is_no_op(self):
|
||||
"""Test no passwords deleted if fewer passwords exist than limit"""
|
||||
|
||||
UserPasswordHistory.create_for_user(self.user, "hunter2")
|
||||
|
||||
policy = UniquePasswordPolicy.objects.create(num_historical_passwords=2)
|
||||
PolicyBinding.objects.create(
|
||||
target=self.pbm,
|
||||
policy=policy,
|
||||
enabled=True,
|
||||
order=0,
|
||||
)
|
||||
trim_password_histories.delay()
|
||||
self.assertTrue(UserPasswordHistory.objects.filter(user=self.user).exists())
|
7
authentik/enterprise/policies/unique_password/urls.py
Normal file
7
authentik/enterprise/policies/unique_password/urls.py
Normal file
@ -0,0 +1,7 @@
|
||||
"""API URLs"""
|
||||
|
||||
from authentik.enterprise.policies.unique_password.api import UniquePasswordPolicyViewSet
|
||||
|
||||
api_urlpatterns = [
|
||||
("policies/unique_password", UniquePasswordPolicyViewSet),
|
||||
]
|
@ -102,7 +102,7 @@ def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSessi
|
||||
"format": "complex",
|
||||
"session": {
|
||||
"format": "opaque",
|
||||
"id": sha256(instance.session_key.encode("ascii")).hexdigest(),
|
||||
"id": sha256(instance.session.session_key.encode("ascii")).hexdigest(),
|
||||
},
|
||||
"user": {
|
||||
"format": "email",
|
||||
|
@ -4,10 +4,9 @@ from rest_framework.exceptions import PermissionDenied, ValidationError
|
||||
from rest_framework.fields import CharField, ChoiceField, ListField, SerializerMethodField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.serializers import ModelSerializer
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.core.api.utils import ModelSerializer, PassiveSerializer
|
||||
from authentik.enterprise.providers.ssf.models import (
|
||||
DeliveryMethods,
|
||||
EventTypes,
|
||||
|
@ -14,6 +14,7 @@ CELERY_BEAT_SCHEDULE = {
|
||||
|
||||
TENANT_APPS = [
|
||||
"authentik.enterprise.audit",
|
||||
"authentik.enterprise.policies.unique_password",
|
||||
"authentik.enterprise.providers.google_workspace",
|
||||
"authentik.enterprise.providers.microsoft_entra",
|
||||
"authentik.enterprise.providers.ssf",
|
||||
|
@ -2,11 +2,11 @@
|
||||
|
||||
from rest_framework import mixins
|
||||
from rest_framework.permissions import IsAdminUser
|
||||
from rest_framework.serializers import ModelSerializer
|
||||
from rest_framework.viewsets import GenericViewSet, ModelViewSet
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import ModelSerializer
|
||||
from authentik.enterprise.api import EnterpriseRequiredMixin
|
||||
from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import (
|
||||
AuthenticatorEndpointGDTCStage,
|
||||
|
@ -8,6 +8,7 @@ from django.test import TestCase
|
||||
from django.utils.timezone import now
|
||||
from rest_framework.exceptions import ValidationError
|
||||
|
||||
from authentik.core.models import User
|
||||
from authentik.enterprise.license import LicenseKey
|
||||
from authentik.enterprise.models import (
|
||||
THRESHOLD_READ_ONLY_WEEKS,
|
||||
@ -71,9 +72,9 @@ class TestEnterpriseLicense(TestCase):
|
||||
)
|
||||
def test_valid_multiple(self):
|
||||
"""Check license verification"""
|
||||
lic = License.objects.create(key=generate_id())
|
||||
lic = License.objects.create(key=generate_id(), expiry=expiry_valid)
|
||||
self.assertTrue(lic.status.status().is_valid)
|
||||
lic2 = License.objects.create(key=generate_id())
|
||||
lic2 = License.objects.create(key=generate_id(), expiry=expiry_valid)
|
||||
self.assertTrue(lic2.status.status().is_valid)
|
||||
total = LicenseKey.get_total()
|
||||
self.assertEqual(total.internal_users, 200)
|
||||
@ -232,7 +233,9 @@ class TestEnterpriseLicense(TestCase):
|
||||
)
|
||||
def test_expiry_expired(self):
|
||||
"""Check license verification"""
|
||||
License.objects.create(key=generate_id())
|
||||
User.objects.all().delete()
|
||||
License.objects.all().delete()
|
||||
License.objects.create(key=generate_id(), expiry=expiry_expired)
|
||||
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.EXPIRED)
|
||||
|
||||
@patch(
|
||||
|
@ -59,7 +59,7 @@ def get_login_event(request_or_session: HttpRequest | AuthenticatedSession | Non
|
||||
session = request_or_session.session
|
||||
if isinstance(request_or_session, AuthenticatedSession):
|
||||
SessionStore = _session_engine.SessionStore
|
||||
session = SessionStore(request_or_session.session_key)
|
||||
session = SessionStore(request_or_session.session.session_key)
|
||||
return session.get(SESSION_LOGIN_EVENT, None)
|
||||
|
||||
|
||||
|
@ -15,6 +15,7 @@
|
||||
{% endblock %}
|
||||
<link rel="stylesheet" type="text/css" href="{% static 'dist/sfe/bootstrap.min.css' %}">
|
||||
<meta name="sentry-trace" content="{{ sentry_trace }}" />
|
||||
<link rel="prefetch" href="{{ flow_background_url }}" />
|
||||
{% include "base/header_js.html" %}
|
||||
<style>
|
||||
html,
|
||||
@ -22,7 +23,7 @@
|
||||
height: 100%;
|
||||
}
|
||||
body {
|
||||
background-image: url("{{ flow.background_url }}");
|
||||
background-image: url("{{ flow_background_url }}");
|
||||
background-repeat: no-repeat;
|
||||
background-size: cover;
|
||||
}
|
||||
|
@ -5,7 +5,7 @@
|
||||
|
||||
{% block head_before %}
|
||||
{{ block.super }}
|
||||
<link rel="prefetch" href="{{ flow.background_url }}" />
|
||||
<link rel="prefetch" href="{{ flow_background_url }}" />
|
||||
{% if flow.compatibility_mode and not inspector %}
|
||||
<script>ShadyDOM = { force: !navigator.webdriver };</script>
|
||||
{% endif %}
|
||||
@ -21,7 +21,7 @@ window.authentik.flow = {
|
||||
<script src="{% versioned_script 'dist/flow/FlowInterface-%v.js' %}" type="module"></script>
|
||||
<style>
|
||||
:root {
|
||||
--ak-flow-background: url("{{ flow.background_url }}");
|
||||
--ak-flow-background: url("{{ flow_background_url }}");
|
||||
}
|
||||
</style>
|
||||
{% endblock %}
|
||||
|
@ -48,6 +48,7 @@ class TestFlowInspector(APITestCase):
|
||||
"allow_show_password": False,
|
||||
"captcha_stage": None,
|
||||
"component": "ak-stage-identification",
|
||||
"enable_remember_me": False,
|
||||
"flow_info": {
|
||||
"background": "/static/dist/assets/images/flow_background.jpg",
|
||||
"cancel_url": reverse("authentik_flows:cancel"),
|
||||
|
@ -15,13 +15,14 @@ class FlowInterfaceView(InterfaceView):
|
||||
|
||||
def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
|
||||
flow = get_object_or_404(Flow, slug=self.kwargs.get("flow_slug"))
|
||||
kwargs["flow"] = flow
|
||||
if (
|
||||
not self.request.user.is_authenticated
|
||||
and flow.designation == FlowDesignation.AUTHENTICATION
|
||||
):
|
||||
self.request.session[SESSION_KEY_AUTH_STARTED] = True
|
||||
self.request.session.save()
|
||||
kwargs["flow"] = flow
|
||||
kwargs["flow_background_url"] = flow.background_url(self.request)
|
||||
kwargs["inspector"] = "inspector" in self.request.GET
|
||||
return super().get_context_data(**kwargs)
|
||||
|
||||
|
@ -356,6 +356,17 @@ def redis_url(db: int) -> str:
|
||||
def django_db_config(config: ConfigLoader | None = None) -> dict:
|
||||
if not config:
|
||||
config = CONFIG
|
||||
|
||||
pool_options = False
|
||||
use_pool = config.get_bool("postgresql.use_pool", False)
|
||||
if use_pool:
|
||||
pool_options = config.get_dict_from_b64_json("postgresql.pool_options", True)
|
||||
if not pool_options:
|
||||
pool_options = True
|
||||
# FIXME: Temporarily force pool to be deactivated.
|
||||
# See https://github.com/goauthentik/authentik/issues/14320
|
||||
pool_options = False
|
||||
|
||||
db = {
|
||||
"default": {
|
||||
"ENGINE": "authentik.root.db",
|
||||
@ -369,6 +380,7 @@ def django_db_config(config: ConfigLoader | None = None) -> dict:
|
||||
"sslrootcert": config.get("postgresql.sslrootcert"),
|
||||
"sslcert": config.get("postgresql.sslcert"),
|
||||
"sslkey": config.get("postgresql.sslkey"),
|
||||
"pool": pool_options,
|
||||
},
|
||||
"CONN_MAX_AGE": config.get_optional_int("postgresql.conn_max_age", 0),
|
||||
"CONN_HEALTH_CHECKS": config.get_bool("postgresql.conn_health_checks", False),
|
||||
|
@ -21,6 +21,7 @@ postgresql:
|
||||
user: authentik
|
||||
port: 5432
|
||||
password: "env://POSTGRES_PASSWORD"
|
||||
use_pool: False
|
||||
test:
|
||||
name: test_authentik
|
||||
default_schema: public
|
||||
|
@ -18,7 +18,7 @@ from sentry_sdk import start_span
|
||||
from sentry_sdk.tracing import Span
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.models import AuthenticatedSession, User
|
||||
from authentik.core.models import User
|
||||
from authentik.events.models import Event
|
||||
from authentik.lib.expression.exceptions import ControlFlowException
|
||||
from authentik.lib.utils.http import get_http_session
|
||||
@ -203,9 +203,7 @@ class BaseEvaluator:
|
||||
provider = OAuth2Provider.objects.get(name=provider)
|
||||
session = None
|
||||
if hasattr(request, "session") and request.session.session_key:
|
||||
session = AuthenticatedSession.objects.filter(
|
||||
session_key=request.session.session_key
|
||||
).first()
|
||||
session = request.session["authenticatedsession"]
|
||||
access_token = AccessToken(
|
||||
provider=provider,
|
||||
user=user,
|
||||
|
@ -217,6 +217,7 @@ class TestConfig(TestCase):
|
||||
"HOST": "foo",
|
||||
"NAME": "foo",
|
||||
"OPTIONS": {
|
||||
"pool": False,
|
||||
"sslcert": "foo",
|
||||
"sslkey": "foo",
|
||||
"sslmode": "foo",
|
||||
@ -267,6 +268,7 @@ class TestConfig(TestCase):
|
||||
"HOST": "foo",
|
||||
"NAME": "foo",
|
||||
"OPTIONS": {
|
||||
"pool": False,
|
||||
"sslcert": "foo",
|
||||
"sslkey": "foo",
|
||||
"sslmode": "foo",
|
||||
@ -285,6 +287,7 @@ class TestConfig(TestCase):
|
||||
"HOST": "bar",
|
||||
"NAME": "foo",
|
||||
"OPTIONS": {
|
||||
"pool": False,
|
||||
"sslcert": "foo",
|
||||
"sslkey": "foo",
|
||||
"sslmode": "foo",
|
||||
@ -333,6 +336,7 @@ class TestConfig(TestCase):
|
||||
"HOST": "foo",
|
||||
"NAME": "foo",
|
||||
"OPTIONS": {
|
||||
"pool": False,
|
||||
"sslcert": "foo",
|
||||
"sslkey": "foo",
|
||||
"sslmode": "foo",
|
||||
@ -351,6 +355,7 @@ class TestConfig(TestCase):
|
||||
"HOST": "bar",
|
||||
"NAME": "foo",
|
||||
"OPTIONS": {
|
||||
"pool": False,
|
||||
"sslcert": "foo",
|
||||
"sslkey": "foo",
|
||||
"sslmode": "foo",
|
||||
@ -394,6 +399,7 @@ class TestConfig(TestCase):
|
||||
"HOST": "foo",
|
||||
"NAME": "foo",
|
||||
"OPTIONS": {
|
||||
"pool": False,
|
||||
"sslcert": "foo",
|
||||
"sslkey": "foo",
|
||||
"sslmode": "foo",
|
||||
@ -412,6 +418,7 @@ class TestConfig(TestCase):
|
||||
"HOST": "bar",
|
||||
"NAME": "foo",
|
||||
"OPTIONS": {
|
||||
"pool": False,
|
||||
"sslcert": "foo",
|
||||
"sslkey": "foo",
|
||||
"sslmode": "foo",
|
||||
@ -451,6 +458,7 @@ class TestConfig(TestCase):
|
||||
"HOST": "foo",
|
||||
"NAME": "foo",
|
||||
"OPTIONS": {
|
||||
"pool": False,
|
||||
"sslcert": "foo",
|
||||
"sslkey": "foo",
|
||||
"sslmode": "foo",
|
||||
@ -469,6 +477,7 @@ class TestConfig(TestCase):
|
||||
"HOST": "bar",
|
||||
"NAME": "foo",
|
||||
"OPTIONS": {
|
||||
"pool": False,
|
||||
"sslcert": "bar",
|
||||
"sslkey": "foo",
|
||||
"sslmode": "foo",
|
||||
@ -484,3 +493,89 @@ class TestConfig(TestCase):
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# FIXME: Temporarily force pool to be deactivated.
|
||||
# See https://github.com/goauthentik/authentik/issues/14320
|
||||
# def test_db_pool(self):
|
||||
# """Test DB Config with pool"""
|
||||
# config = ConfigLoader()
|
||||
# config.set("postgresql.host", "foo")
|
||||
# config.set("postgresql.name", "foo")
|
||||
# config.set("postgresql.user", "foo")
|
||||
# config.set("postgresql.password", "foo")
|
||||
# config.set("postgresql.port", "foo")
|
||||
# config.set("postgresql.test.name", "foo")
|
||||
# config.set("postgresql.use_pool", True)
|
||||
# conf = django_db_config(config)
|
||||
# self.assertEqual(
|
||||
# conf,
|
||||
# {
|
||||
# "default": {
|
||||
# "ENGINE": "authentik.root.db",
|
||||
# "HOST": "foo",
|
||||
# "NAME": "foo",
|
||||
# "OPTIONS": {
|
||||
# "pool": True,
|
||||
# "sslcert": None,
|
||||
# "sslkey": None,
|
||||
# "sslmode": None,
|
||||
# "sslrootcert": None,
|
||||
# },
|
||||
# "PASSWORD": "foo",
|
||||
# "PORT": "foo",
|
||||
# "TEST": {"NAME": "foo"},
|
||||
# "USER": "foo",
|
||||
# "CONN_MAX_AGE": 0,
|
||||
# "CONN_HEALTH_CHECKS": False,
|
||||
# "DISABLE_SERVER_SIDE_CURSORS": False,
|
||||
# }
|
||||
# },
|
||||
# )
|
||||
|
||||
# def test_db_pool_options(self):
|
||||
# """Test DB Config with pool"""
|
||||
# config = ConfigLoader()
|
||||
# config.set("postgresql.host", "foo")
|
||||
# config.set("postgresql.name", "foo")
|
||||
# config.set("postgresql.user", "foo")
|
||||
# config.set("postgresql.password", "foo")
|
||||
# config.set("postgresql.port", "foo")
|
||||
# config.set("postgresql.test.name", "foo")
|
||||
# config.set("postgresql.use_pool", True)
|
||||
# config.set(
|
||||
# "postgresql.pool_options",
|
||||
# base64.b64encode(
|
||||
# dumps(
|
||||
# {
|
||||
# "max_size": 15,
|
||||
# }
|
||||
# ).encode()
|
||||
# ).decode(),
|
||||
# )
|
||||
# conf = django_db_config(config)
|
||||
# self.assertEqual(
|
||||
# conf,
|
||||
# {
|
||||
# "default": {
|
||||
# "ENGINE": "authentik.root.db",
|
||||
# "HOST": "foo",
|
||||
# "NAME": "foo",
|
||||
# "OPTIONS": {
|
||||
# "pool": {
|
||||
# "max_size": 15,
|
||||
# },
|
||||
# "sslcert": None,
|
||||
# "sslkey": None,
|
||||
# "sslmode": None,
|
||||
# "sslrootcert": None,
|
||||
# },
|
||||
# "PASSWORD": "foo",
|
||||
# "PORT": "foo",
|
||||
# "TEST": {"NAME": "foo"},
|
||||
# "USER": "foo",
|
||||
# "CONN_MAX_AGE": 0,
|
||||
# "CONN_HEALTH_CHECKS": False,
|
||||
# "DISABLE_SERVER_SIDE_CURSORS": False,
|
||||
# }
|
||||
# },
|
||||
# )
|
||||
|
@ -74,6 +74,8 @@ class OutpostConfig:
|
||||
kubernetes_ingress_annotations: dict[str, str] = field(default_factory=dict)
|
||||
kubernetes_ingress_secret_name: str = field(default="authentik-outpost-tls")
|
||||
kubernetes_ingress_class_name: str | None = field(default=None)
|
||||
kubernetes_httproute_annotations: dict[str, str] = field(default_factory=dict)
|
||||
kubernetes_httproute_parent_refs: list[dict[str, str]] = field(default_factory=list)
|
||||
kubernetes_service_type: str = field(default="ClusterIP")
|
||||
kubernetes_disabled_components: list[str] = field(default_factory=list)
|
||||
kubernetes_image_pull_secrets: list[str] = field(default_factory=list)
|
||||
|
@ -1,4 +1,8 @@
|
||||
"""authentik policies app config"""
|
||||
"""Authentik policies app config
|
||||
|
||||
Every system policy should be its own Django app under the `policies` app.
|
||||
For example: The 'dummy' policy is available at `authentik.policies.dummy`.
|
||||
"""
|
||||
|
||||
from prometheus_client import Gauge, Histogram
|
||||
|
||||
|
@ -66,7 +66,9 @@ class GeoIPPolicy(Policy):
|
||||
if not static_results and not dynamic_results:
|
||||
return PolicyResult(True)
|
||||
|
||||
passing = any(r.passing for r in static_results) and all(r.passing for r in dynamic_results)
|
||||
static_passing = any(r.passing for r in static_results) if static_results else True
|
||||
dynamic_passing = all(r.passing for r in dynamic_results)
|
||||
passing = static_passing and dynamic_passing
|
||||
messages = chain(
|
||||
*[r.messages for r in static_results], *[r.messages for r in dynamic_results]
|
||||
)
|
||||
@ -113,13 +115,19 @@ class GeoIPPolicy(Policy):
|
||||
to previous authentication requests"""
|
||||
# Get previous login event and GeoIP data
|
||||
previous_logins = Event.objects.filter(
|
||||
action=EventAction.LOGIN, user__pk=request.user.pk, context__geo__isnull=False
|
||||
action=EventAction.LOGIN,
|
||||
user__pk=request.user.pk, # context__geo__isnull=False
|
||||
).order_by("-created")[: self.history_login_count]
|
||||
_now = now()
|
||||
geoip_data: GeoIPDict | None = request.context.get("geoip")
|
||||
if not geoip_data:
|
||||
return PolicyResult(False)
|
||||
if not previous_logins.exists():
|
||||
return PolicyResult(True)
|
||||
result = False
|
||||
for previous_login in previous_logins:
|
||||
if "geo" not in previous_login.context:
|
||||
continue
|
||||
previous_login_geoip: GeoIPDict = previous_login.context["geo"]
|
||||
|
||||
# Figure out distance
|
||||
@ -142,7 +150,8 @@ class GeoIPPolicy(Policy):
|
||||
(MAX_DISTANCE_HOUR_KM * rel_time_hours) + self.distance_tolerance_km
|
||||
):
|
||||
return PolicyResult(False, _("Distance is further than possible."))
|
||||
return PolicyResult(True)
|
||||
result = True
|
||||
return PolicyResult(result)
|
||||
|
||||
class Meta(Policy.PolicyMeta):
|
||||
verbose_name = _("GeoIP Policy")
|
||||
|
@ -163,7 +163,7 @@ class TestGeoIPPolicy(TestCase):
|
||||
result: PolicyResult = policy.passes(self.request)
|
||||
self.assertFalse(result.passing)
|
||||
|
||||
def test_history_impossible_travel(self):
|
||||
def test_history_impossible_travel_failing(self):
|
||||
"""Test history checks"""
|
||||
Event.objects.create(
|
||||
action=EventAction.LOGIN,
|
||||
@ -181,6 +181,24 @@ class TestGeoIPPolicy(TestCase):
|
||||
result: PolicyResult = policy.passes(self.request)
|
||||
self.assertFalse(result.passing)
|
||||
|
||||
def test_history_impossible_travel_passing(self):
|
||||
"""Test history checks"""
|
||||
Event.objects.create(
|
||||
action=EventAction.LOGIN,
|
||||
user=get_user(self.user),
|
||||
context={
|
||||
# Random location in Canada
|
||||
"geo": {"lat": 55.868351, "long": -104.441011},
|
||||
},
|
||||
)
|
||||
# Same location
|
||||
self.request.context["geoip"] = {"lat": 55.868351, "long": -104.441011}
|
||||
|
||||
policy = GeoIPPolicy.objects.create(check_impossible_travel=True)
|
||||
|
||||
result: PolicyResult = policy.passes(self.request)
|
||||
self.assertTrue(result.passing)
|
||||
|
||||
def test_history_no_geoip(self):
|
||||
"""Test history checks (previous login with no geoip data)"""
|
||||
Event.objects.create(
|
||||
@ -195,3 +213,18 @@ class TestGeoIPPolicy(TestCase):
|
||||
|
||||
result: PolicyResult = policy.passes(self.request)
|
||||
self.assertFalse(result.passing)
|
||||
|
||||
def test_impossible_travel_no_geoip(self):
|
||||
"""Test impossible travel checks (previous login with no geoip data)"""
|
||||
Event.objects.create(
|
||||
action=EventAction.LOGIN,
|
||||
user=get_user(self.user),
|
||||
context={},
|
||||
)
|
||||
# Random location in Poland
|
||||
self.request.context["geoip"] = {"lat": 50.950613, "long": 20.363679}
|
||||
|
||||
policy = GeoIPPolicy.objects.create(check_impossible_travel=True)
|
||||
|
||||
result: PolicyResult = policy.passes(self.request)
|
||||
self.assertFalse(result.passing)
|
||||
|
@ -52,6 +52,13 @@ class PolicyBindingModel(models.Model):
|
||||
return ["policy", "user", "group"]
|
||||
|
||||
|
||||
class BoundPolicyQuerySet(models.QuerySet):
|
||||
"""QuerySet for filtering enabled bindings for a Policy type"""
|
||||
|
||||
def for_policy(self, policy: "Policy"):
|
||||
return self.filter(policy__in=policy._default_manager.all()).filter(enabled=True)
|
||||
|
||||
|
||||
class PolicyBinding(SerializerModel):
|
||||
"""Relationship between a Policy and a PolicyBindingModel."""
|
||||
|
||||
@ -148,6 +155,9 @@ class PolicyBinding(SerializerModel):
|
||||
return f"Binding - #{self.order} to {suffix}"
|
||||
return ""
|
||||
|
||||
objects = models.Manager()
|
||||
in_use = BoundPolicyQuerySet.as_manager()
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("Policy Binding")
|
||||
verbose_name_plural = _("Policy Bindings")
|
||||
|
@ -2,4 +2,6 @@
|
||||
|
||||
from authentik.policies.password.api import PasswordPolicyViewSet
|
||||
|
||||
api_urlpatterns = [("policies/password", PasswordPolicyViewSet)]
|
||||
api_urlpatterns = [
|
||||
("policies/password", PasswordPolicyViewSet),
|
||||
]
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
from django.contrib.auth.signals import user_logged_in
|
||||
from django.db import transaction
|
||||
from django.db.models import F
|
||||
from django.dispatch import receiver
|
||||
from django.http import HttpRequest
|
||||
from structlog.stdlib import get_logger
|
||||
@ -13,20 +12,29 @@ from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
|
||||
from authentik.policies.reputation.models import Reputation, reputation_expiry
|
||||
from authentik.root.middleware import ClientIPMiddleware
|
||||
from authentik.stages.identification.signals import identification_failed
|
||||
from authentik.tenants.utils import get_current_tenant
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
def clamp(value, min, max):
|
||||
return sorted([min, value, max])[1]
|
||||
|
||||
|
||||
def update_score(request: HttpRequest, identifier: str, amount: int):
|
||||
"""Update score for IP and User"""
|
||||
remote_ip = ClientIPMiddleware.get_client_ip(request)
|
||||
tenant = get_current_tenant()
|
||||
new_score = clamp(amount, tenant.reputation_lower_limit, tenant.reputation_upper_limit)
|
||||
|
||||
with transaction.atomic():
|
||||
reputation, created = Reputation.objects.select_for_update().get_or_create(
|
||||
ip=remote_ip,
|
||||
identifier=identifier,
|
||||
defaults={
|
||||
"score": amount,
|
||||
"score": clamp(
|
||||
amount, tenant.reputation_lower_limit, tenant.reputation_upper_limit
|
||||
),
|
||||
"ip_geo_data": GEOIP_CONTEXT_PROCESSOR.city_dict(remote_ip) or {},
|
||||
"ip_asn_data": ASN_CONTEXT_PROCESSOR.asn_dict(remote_ip) or {},
|
||||
"expires": reputation_expiry(),
|
||||
@ -34,9 +42,15 @@ def update_score(request: HttpRequest, identifier: str, amount: int):
|
||||
)
|
||||
|
||||
if not created:
|
||||
reputation.score = F("score") + amount
|
||||
new_score = clamp(
|
||||
reputation.score + amount,
|
||||
tenant.reputation_lower_limit,
|
||||
tenant.reputation_upper_limit,
|
||||
)
|
||||
reputation.score = new_score
|
||||
reputation.save()
|
||||
LOGGER.info("Updated score", amount=amount, for_user=identifier, for_ip=remote_ip)
|
||||
|
||||
LOGGER.info("Updated score", amount=new_score, for_user=identifier, for_ip=remote_ip)
|
||||
|
||||
|
||||
@receiver(login_failed)
|
||||
|
@ -6,9 +6,11 @@ from authentik.core.models import User
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.policies.reputation.api import ReputationPolicySerializer
|
||||
from authentik.policies.reputation.models import Reputation, ReputationPolicy
|
||||
from authentik.policies.reputation.signals import update_score
|
||||
from authentik.policies.types import PolicyRequest
|
||||
from authentik.stages.password import BACKEND_INBUILT
|
||||
from authentik.stages.password.stage import authenticate
|
||||
from authentik.tenants.models import DEFAULT_REPUTATION_LOWER_LIMIT, DEFAULT_REPUTATION_UPPER_LIMIT
|
||||
|
||||
|
||||
class TestReputationPolicy(TestCase):
|
||||
@ -17,36 +19,48 @@ class TestReputationPolicy(TestCase):
|
||||
def setUp(self):
|
||||
self.request_factory = RequestFactory()
|
||||
self.request = self.request_factory.get("/")
|
||||
self.test_ip = "127.0.0.1"
|
||||
self.test_username = "test"
|
||||
self.ip = "127.0.0.1"
|
||||
self.username = "username"
|
||||
self.password = generate_id()
|
||||
# We need a user for the one-to-one in userreputation
|
||||
self.user = User.objects.create(username=self.test_username)
|
||||
self.user = User.objects.create(username=self.username)
|
||||
self.user.set_password(self.password)
|
||||
self.backends = [BACKEND_INBUILT]
|
||||
|
||||
def test_ip_reputation(self):
|
||||
"""test IP reputation"""
|
||||
# Trigger negative reputation
|
||||
authenticate(
|
||||
self.request, self.backends, username=self.test_username, password=self.test_username
|
||||
)
|
||||
self.assertEqual(Reputation.objects.get(ip=self.test_ip).score, -1)
|
||||
authenticate(self.request, self.backends, username=self.username, password=self.username)
|
||||
self.assertEqual(Reputation.objects.get(ip=self.ip).score, -1)
|
||||
|
||||
def test_user_reputation(self):
|
||||
"""test User reputation"""
|
||||
# Trigger negative reputation
|
||||
authenticate(
|
||||
self.request, self.backends, username=self.test_username, password=self.test_username
|
||||
)
|
||||
self.assertEqual(Reputation.objects.get(identifier=self.test_username).score, -1)
|
||||
authenticate(self.request, self.backends, username=self.username, password=self.username)
|
||||
self.assertEqual(Reputation.objects.get(identifier=self.username).score, -1)
|
||||
|
||||
def test_update_reputation(self):
|
||||
"""test reputation update"""
|
||||
Reputation.objects.create(identifier=self.test_username, ip=self.test_ip, score=43)
|
||||
Reputation.objects.create(identifier=self.username, ip=self.ip, score=4)
|
||||
# Trigger negative reputation
|
||||
authenticate(
|
||||
self.request, self.backends, username=self.test_username, password=self.test_username
|
||||
authenticate(self.request, self.backends, username=self.username, password=self.username)
|
||||
self.assertEqual(Reputation.objects.get(identifier=self.username).score, 3)
|
||||
|
||||
def test_reputation_lower_limit(self):
|
||||
"""test reputation lower limit"""
|
||||
Reputation.objects.create(identifier=self.username, ip=self.ip)
|
||||
update_score(self.request, identifier=self.username, amount=-1000)
|
||||
self.assertEqual(
|
||||
Reputation.objects.get(identifier=self.username).score, DEFAULT_REPUTATION_LOWER_LIMIT
|
||||
)
|
||||
|
||||
def test_reputation_upper_limit(self):
|
||||
"""test reputation upper limit"""
|
||||
Reputation.objects.create(identifier=self.username, ip=self.ip)
|
||||
update_score(self.request, identifier=self.username, amount=1000)
|
||||
self.assertEqual(
|
||||
Reputation.objects.get(identifier=self.username).score, DEFAULT_REPUTATION_UPPER_LIMIT
|
||||
)
|
||||
self.assertEqual(Reputation.objects.get(identifier=self.test_username).score, 42)
|
||||
|
||||
def test_policy(self):
|
||||
"""Test Policy"""
|
||||
|
@ -126,7 +126,7 @@ class IDToken:
|
||||
id_token.iat = int(now.timestamp())
|
||||
id_token.auth_time = int(token.auth_time.timestamp())
|
||||
if token.session:
|
||||
id_token.sid = hash_session_key(token.session.session_key)
|
||||
id_token.sid = hash_session_key(token.session.session.session_key)
|
||||
|
||||
# We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time
|
||||
auth_event = get_login_event(token.session)
|
||||
|
116
authentik/providers/oauth2/migrations/0028_migrate_session.py
Normal file
116
authentik/providers/oauth2/migrations/0028_migrate_session.py
Normal file
@ -0,0 +1,116 @@
|
||||
# Generated by Django 5.0.11 on 2025-01-27 13:00
|
||||
|
||||
from django.db import migrations
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
from functools import partial
|
||||
|
||||
|
||||
def migrate_sessions(apps, schema_editor, model):
|
||||
Model = apps.get_model("authentik_providers_oauth2", model)
|
||||
AuthenticatedSession = apps.get_model("authentik_core", "AuthenticatedSession")
|
||||
db_alias = schema_editor.connection.alias
|
||||
|
||||
for obj in Model.objects.using(db_alias).all():
|
||||
if not obj.old_session:
|
||||
continue
|
||||
obj.session = (
|
||||
AuthenticatedSession.objects.using(db_alias)
|
||||
.filter(session__session_key=obj.old_session.session_key)
|
||||
.first()
|
||||
)
|
||||
if obj.session:
|
||||
obj.save()
|
||||
else:
|
||||
obj.delete()
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_providers_oauth2", "0027_accesstoken_authentik_p_expires_9f24a5_idx_and_more"),
|
||||
("authentik_core", "0046_session_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RenameField(
|
||||
model_name="accesstoken",
|
||||
old_name="session",
|
||||
new_name="old_session",
|
||||
),
|
||||
migrations.RenameField(
|
||||
model_name="authorizationcode",
|
||||
old_name="session",
|
||||
new_name="old_session",
|
||||
),
|
||||
migrations.RenameField(
|
||||
model_name="devicetoken",
|
||||
old_name="session",
|
||||
new_name="old_session",
|
||||
),
|
||||
migrations.RenameField(
|
||||
model_name="refreshtoken",
|
||||
old_name="session",
|
||||
new_name="old_session",
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="accesstoken",
|
||||
name="session",
|
||||
field=models.ForeignKey(
|
||||
default=None,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
to="authentik_core.authenticatedsession",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="authorizationcode",
|
||||
name="session",
|
||||
field=models.ForeignKey(
|
||||
default=None,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
to="authentik_core.authenticatedsession",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="devicetoken",
|
||||
name="session",
|
||||
field=models.ForeignKey(
|
||||
default=None,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.SET_DEFAULT,
|
||||
to="authentik_core.authenticatedsession",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="refreshtoken",
|
||||
name="session",
|
||||
field=models.ForeignKey(
|
||||
default=None,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.SET_DEFAULT,
|
||||
to="authentik_core.authenticatedsession",
|
||||
),
|
||||
),
|
||||
migrations.RunPython(code=partial(migrate_sessions, model="AccessToken")),
|
||||
migrations.RunPython(code=partial(migrate_sessions, model="AuthorizationCode")),
|
||||
migrations.RunPython(code=partial(migrate_sessions, model="DeviceToken")),
|
||||
migrations.RunPython(code=partial(migrate_sessions, model="RefreshToken")),
|
||||
migrations.RemoveField(
|
||||
model_name="accesstoken",
|
||||
name="old_session",
|
||||
),
|
||||
migrations.RemoveField(
|
||||
model_name="authorizationcode",
|
||||
name="old_session",
|
||||
),
|
||||
migrations.RemoveField(
|
||||
model_name="devicetoken",
|
||||
name="old_session",
|
||||
),
|
||||
migrations.RemoveField(
|
||||
model_name="refreshtoken",
|
||||
name="old_session",
|
||||
),
|
||||
]
|
@ -1,18 +1,30 @@
|
||||
from django.contrib.auth.signals import user_logged_out
|
||||
from django.db.models.signals import post_save
|
||||
from django.db.models.signals import post_save, pre_delete
|
||||
from django.dispatch import receiver
|
||||
from django.http import HttpRequest
|
||||
|
||||
from authentik.core.models import User
|
||||
from authentik.core.models import AuthenticatedSession, User
|
||||
from authentik.providers.oauth2.models import AccessToken, DeviceToken, RefreshToken
|
||||
|
||||
|
||||
@receiver(user_logged_out)
|
||||
def user_logged_out_oauth_access_token(sender, request: HttpRequest, user: User, **_):
|
||||
"""Revoke access tokens upon user logout"""
|
||||
def user_logged_out_oauth_tokens_removal(sender, request: HttpRequest, user: User, **_):
|
||||
"""Revoke tokens upon user logout"""
|
||||
if not request.session or not request.session.session_key:
|
||||
return
|
||||
AccessToken.objects.filter(user=user, session__session_key=request.session.session_key).delete()
|
||||
AccessToken.objects.filter(
|
||||
user=user,
|
||||
session__session__session_key=request.session.session_key,
|
||||
).delete()
|
||||
|
||||
|
||||
@receiver(pre_delete, sender=AuthenticatedSession)
|
||||
def user_session_deleted_oauth_tokens_removal(sender, instance: AuthenticatedSession, **_):
|
||||
"""Revoke tokens upon user logout"""
|
||||
AccessToken.objects.filter(
|
||||
user=instance.user,
|
||||
session__session__session_key=instance.session.session_key,
|
||||
).delete()
|
||||
|
||||
|
||||
@receiver(post_save, sender=User)
|
||||
@ -20,6 +32,6 @@ def user_deactivated(sender, instance: User, **_):
|
||||
"""Remove user tokens when deactivated"""
|
||||
if instance.is_active:
|
||||
return
|
||||
AccessToken.objects.filter(session__user=instance).delete()
|
||||
RefreshToken.objects.filter(session__user=instance).delete()
|
||||
DeviceToken.objects.filter(session__user=instance).delete()
|
||||
AccessToken.objects.filter(user=instance).delete()
|
||||
RefreshToken.objects.filter(user=instance).delete()
|
||||
DeviceToken.objects.filter(user=instance).delete()
|
||||
|
@ -7,12 +7,13 @@ from dataclasses import asdict
|
||||
from django.urls import reverse
|
||||
from django.utils import timezone
|
||||
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.models import Application, AuthenticatedSession, Session
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.providers.oauth2.models import (
|
||||
AccessToken,
|
||||
ClientTypes,
|
||||
DeviceToken,
|
||||
IDToken,
|
||||
OAuth2Provider,
|
||||
RedirectURI,
|
||||
@ -20,6 +21,7 @@ from authentik.providers.oauth2.models import (
|
||||
RefreshToken,
|
||||
)
|
||||
from authentik.providers.oauth2.tests.utils import OAuthTestCase
|
||||
from authentik.root.middleware import ClientIPMiddleware
|
||||
|
||||
|
||||
class TesOAuth2Revoke(OAuthTestCase):
|
||||
@ -135,3 +137,86 @@ class TesOAuth2Revoke(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
|
||||
def test_revoke_logout(self):
|
||||
"""Test revoke on logout"""
|
||||
self.client.force_login(self.user)
|
||||
AccessToken.objects.create(
|
||||
provider=self.provider,
|
||||
user=self.user,
|
||||
session=self.client.session["authenticatedsession"],
|
||||
token=generate_id(),
|
||||
auth_time=timezone.now(),
|
||||
_scope="openid user profile",
|
||||
_id_token=json.dumps(
|
||||
asdict(
|
||||
IDToken("foo", "bar"),
|
||||
)
|
||||
),
|
||||
)
|
||||
self.client.logout()
|
||||
self.assertEqual(AccessToken.objects.all().count(), 0)
|
||||
|
||||
def test_revoke_session_delete(self):
|
||||
"""Test revoke on logout"""
|
||||
session = AuthenticatedSession.objects.create(
|
||||
session=Session.objects.create(
|
||||
session_key=generate_id(),
|
||||
last_ip=ClientIPMiddleware.default_ip,
|
||||
),
|
||||
user=self.user,
|
||||
)
|
||||
AccessToken.objects.create(
|
||||
provider=self.provider,
|
||||
user=self.user,
|
||||
session=session,
|
||||
token=generate_id(),
|
||||
auth_time=timezone.now(),
|
||||
_scope="openid user profile",
|
||||
_id_token=json.dumps(
|
||||
asdict(
|
||||
IDToken("foo", "bar"),
|
||||
)
|
||||
),
|
||||
)
|
||||
session.delete()
|
||||
self.assertEqual(AccessToken.objects.all().count(), 0)
|
||||
|
||||
def test_revoke_user_deactivated(self):
|
||||
"""Test revoke on logout"""
|
||||
AccessToken.objects.create(
|
||||
provider=self.provider,
|
||||
user=self.user,
|
||||
token=generate_id(),
|
||||
auth_time=timezone.now(),
|
||||
_scope="openid user profile",
|
||||
_id_token=json.dumps(
|
||||
asdict(
|
||||
IDToken("foo", "bar"),
|
||||
)
|
||||
),
|
||||
)
|
||||
RefreshToken.objects.create(
|
||||
provider=self.provider,
|
||||
user=self.user,
|
||||
token=generate_id(),
|
||||
auth_time=timezone.now(),
|
||||
_scope="openid user profile",
|
||||
_id_token=json.dumps(
|
||||
asdict(
|
||||
IDToken("foo", "bar"),
|
||||
)
|
||||
),
|
||||
)
|
||||
DeviceToken.objects.create(
|
||||
provider=self.provider,
|
||||
user=self.user,
|
||||
_scope="openid user profile",
|
||||
)
|
||||
|
||||
self.user.is_active = False
|
||||
self.user.save()
|
||||
|
||||
self.assertEqual(AccessToken.objects.all().count(), 0)
|
||||
self.assertEqual(RefreshToken.objects.all().count(), 0)
|
||||
self.assertEqual(DeviceToken.objects.all().count(), 0)
|
||||
|
@ -15,7 +15,7 @@ from django.utils import timezone
|
||||
from django.utils.translation import gettext as _
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.models import Application, AuthenticatedSession
|
||||
from authentik.core.models import Application
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.events.signals import get_login_event
|
||||
from authentik.flows.challenge import (
|
||||
@ -316,9 +316,7 @@ class OAuthAuthorizationParams:
|
||||
expires=now + timedelta_from_string(self.provider.access_code_validity),
|
||||
scope=self.scope,
|
||||
nonce=self.nonce,
|
||||
session=AuthenticatedSession.objects.filter(
|
||||
session_key=request.session.session_key
|
||||
).first(),
|
||||
session=request.session["authenticatedsession"],
|
||||
)
|
||||
|
||||
if self.code_challenge and self.code_challenge_method:
|
||||
@ -615,9 +613,7 @@ class OAuthFulfillmentStage(StageView):
|
||||
expires=access_token_expiry,
|
||||
provider=self.provider,
|
||||
auth_time=auth_event.created if auth_event else now,
|
||||
session=AuthenticatedSession.objects.filter(
|
||||
session_key=self.request.session.session_key
|
||||
).first(),
|
||||
session=self.request.session["authenticatedsession"],
|
||||
)
|
||||
|
||||
id_token = IDToken.new(self.provider, token, self.request)
|
||||
|
234
authentik/providers/proxy/controllers/k8s/httproute.py
Normal file
234
authentik/providers/proxy/controllers/k8s/httproute.py
Normal file
@ -0,0 +1,234 @@
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from dacite.core import from_dict
|
||||
from kubernetes.client import ApiextensionsV1Api, CustomObjectsApi, V1ObjectMeta
|
||||
|
||||
from authentik.outposts.controllers.base import FIELD_MANAGER
|
||||
from authentik.outposts.controllers.k8s.base import KubernetesObjectReconciler
|
||||
from authentik.outposts.controllers.k8s.triggers import NeedsUpdate
|
||||
from authentik.outposts.controllers.kubernetes import KubernetesController
|
||||
from authentik.providers.proxy.models import ProxyMode, ProxyProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from authentik.outposts.controllers.kubernetes import KubernetesController
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class RouteBackendRef:
|
||||
name: str
|
||||
port: int
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class RouteSpecParentRefs:
|
||||
name: str
|
||||
sectionName: str | None = None
|
||||
port: int | None = None
|
||||
namespace: str | None = None
|
||||
kind: str = "Gateway"
|
||||
group: str = "gateway.networking.k8s.io"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class HTTPRouteSpecRuleMatchPath:
|
||||
type: str
|
||||
value: str
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class HTTPRouteSpecRuleMatchHeader:
|
||||
name: str
|
||||
value: str
|
||||
type: str = "Exact"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class HTTPRouteSpecRuleMatch:
|
||||
path: HTTPRouteSpecRuleMatchPath
|
||||
headers: list[HTTPRouteSpecRuleMatchHeader]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class HTTPRouteSpecRule:
|
||||
backendRefs: list[RouteBackendRef]
|
||||
matches: list[HTTPRouteSpecRuleMatch]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class HTTPRouteSpec:
|
||||
parentRefs: list[RouteSpecParentRefs]
|
||||
hostnames: list[str]
|
||||
rules: list[HTTPRouteSpecRule]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class HTTPRouteMetadata:
|
||||
name: str
|
||||
namespace: str
|
||||
annotations: dict = field(default_factory=dict)
|
||||
labels: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class HTTPRoute:
|
||||
apiVersion: str
|
||||
kind: str
|
||||
metadata: HTTPRouteMetadata
|
||||
spec: HTTPRouteSpec
|
||||
|
||||
|
||||
class HTTPRouteReconciler(KubernetesObjectReconciler):
|
||||
"""Kubernetes Gateway API HTTPRoute Reconciler"""
|
||||
|
||||
def __init__(self, controller: "KubernetesController") -> None:
|
||||
super().__init__(controller)
|
||||
self.api_ex = ApiextensionsV1Api(controller.client)
|
||||
self.api = CustomObjectsApi(controller.client)
|
||||
self.crd_group = "gateway.networking.k8s.io"
|
||||
self.crd_version = "v1"
|
||||
self.crd_plural = "httproutes"
|
||||
|
||||
@staticmethod
|
||||
def reconciler_name() -> str:
|
||||
return "httproute"
|
||||
|
||||
@property
|
||||
def noop(self) -> bool:
|
||||
if not self.crd_exists():
|
||||
self.logger.debug("CRD doesn't exist")
|
||||
return True
|
||||
if not self.controller.outpost.config.kubernetes_httproute_parent_refs:
|
||||
self.logger.debug("HTTPRoute parentRefs not set.")
|
||||
return True
|
||||
return False
|
||||
|
||||
def crd_exists(self) -> bool:
|
||||
"""Check if the Gateway API resources exists"""
|
||||
return bool(
|
||||
len(
|
||||
self.api_ex.list_custom_resource_definition(
|
||||
field_selector=f"metadata.name={self.crd_plural}.{self.crd_group}"
|
||||
).items
|
||||
)
|
||||
)
|
||||
|
||||
def reconcile(self, current: HTTPRoute, reference: HTTPRoute):
|
||||
super().reconcile(current, reference)
|
||||
if current.metadata.annotations != reference.metadata.annotations:
|
||||
raise NeedsUpdate()
|
||||
if current.spec.parentRefs != reference.spec.parentRefs:
|
||||
raise NeedsUpdate()
|
||||
if current.spec.hostnames != reference.spec.hostnames:
|
||||
raise NeedsUpdate()
|
||||
if current.spec.rules != reference.spec.rules:
|
||||
raise NeedsUpdate()
|
||||
|
||||
def get_object_meta(self, **kwargs) -> V1ObjectMeta:
|
||||
return super().get_object_meta(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_reference_object(self) -> HTTPRoute:
|
||||
hostnames = []
|
||||
rules = []
|
||||
|
||||
for proxy_provider in ProxyProvider.objects.filter(outpost__in=[self.controller.outpost]):
|
||||
proxy_provider: ProxyProvider
|
||||
external_host_name = urlparse(proxy_provider.external_host)
|
||||
if proxy_provider.mode in [ProxyMode.FORWARD_SINGLE, ProxyMode.FORWARD_DOMAIN]:
|
||||
rule = HTTPRouteSpecRule(
|
||||
backendRefs=[RouteBackendRef(name=self.name, port=9000)],
|
||||
matches=[
|
||||
HTTPRouteSpecRuleMatch(
|
||||
headers=[
|
||||
HTTPRouteSpecRuleMatchHeader(
|
||||
name="Host",
|
||||
value=external_host_name.hostname,
|
||||
)
|
||||
],
|
||||
path=HTTPRouteSpecRuleMatchPath(
|
||||
type="PathPrefix", value="/outpost.goauthentik.io"
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
else:
|
||||
rule = HTTPRouteSpecRule(
|
||||
backendRefs=[RouteBackendRef(name=self.name, port=9000)],
|
||||
matches=[
|
||||
HTTPRouteSpecRuleMatch(
|
||||
headers=[
|
||||
HTTPRouteSpecRuleMatchHeader(
|
||||
name="Host",
|
||||
value=external_host_name.hostname,
|
||||
)
|
||||
],
|
||||
path=HTTPRouteSpecRuleMatchPath(type="PathPrefix", value="/"),
|
||||
)
|
||||
],
|
||||
)
|
||||
hostnames.append(external_host_name.hostname)
|
||||
rules.append(rule)
|
||||
|
||||
return HTTPRoute(
|
||||
apiVersion=f"{self.crd_group}/{self.crd_version}",
|
||||
kind="HTTPRoute",
|
||||
metadata=HTTPRouteMetadata(
|
||||
name=self.name,
|
||||
namespace=self.namespace,
|
||||
annotations=self.controller.outpost.config.kubernetes_httproute_annotations,
|
||||
labels=self.get_object_meta().labels,
|
||||
),
|
||||
spec=HTTPRouteSpec(
|
||||
parentRefs=[
|
||||
from_dict(RouteSpecParentRefs, spec)
|
||||
for spec in self.controller.outpost.config.kubernetes_httproute_parent_refs
|
||||
],
|
||||
hostnames=hostnames,
|
||||
rules=rules,
|
||||
),
|
||||
)
|
||||
|
||||
def create(self, reference: HTTPRoute):
|
||||
return self.api.create_namespaced_custom_object(
|
||||
group=self.crd_group,
|
||||
version=self.crd_version,
|
||||
plural=self.crd_plural,
|
||||
namespace=self.namespace,
|
||||
body=asdict(reference),
|
||||
field_manager=FIELD_MANAGER,
|
||||
)
|
||||
|
||||
def delete(self, reference: HTTPRoute):
|
||||
return self.api.delete_namespaced_custom_object(
|
||||
group=self.crd_group,
|
||||
version=self.crd_version,
|
||||
plural=self.crd_plural,
|
||||
namespace=self.namespace,
|
||||
name=self.name,
|
||||
)
|
||||
|
||||
def retrieve(self) -> HTTPRoute:
|
||||
return from_dict(
|
||||
HTTPRoute,
|
||||
self.api.get_namespaced_custom_object(
|
||||
group=self.crd_group,
|
||||
version=self.crd_version,
|
||||
plural=self.crd_plural,
|
||||
namespace=self.namespace,
|
||||
name=self.name,
|
||||
),
|
||||
)
|
||||
|
||||
def update(self, current: HTTPRoute, reference: HTTPRoute):
|
||||
return self.api.patch_namespaced_custom_object(
|
||||
group=self.crd_group,
|
||||
version=self.crd_version,
|
||||
plural=self.crd_plural,
|
||||
namespace=self.namespace,
|
||||
name=self.name,
|
||||
body=asdict(reference),
|
||||
field_manager=FIELD_MANAGER,
|
||||
)
|
@ -3,6 +3,7 @@
|
||||
from authentik.outposts.controllers.base import DeploymentPort
|
||||
from authentik.outposts.controllers.kubernetes import KubernetesController
|
||||
from authentik.outposts.models import KubernetesServiceConnection, Outpost
|
||||
from authentik.providers.proxy.controllers.k8s.httproute import HTTPRouteReconciler
|
||||
from authentik.providers.proxy.controllers.k8s.ingress import IngressReconciler
|
||||
from authentik.providers.proxy.controllers.k8s.traefik import TraefikMiddlewareReconciler
|
||||
|
||||
@ -18,8 +19,10 @@ class ProxyKubernetesController(KubernetesController):
|
||||
DeploymentPort(9443, "https", "tcp"),
|
||||
]
|
||||
self.reconcilers[IngressReconciler.reconciler_name()] = IngressReconciler
|
||||
self.reconcilers[HTTPRouteReconciler.reconciler_name()] = HTTPRouteReconciler
|
||||
self.reconcilers[TraefikMiddlewareReconciler.reconciler_name()] = (
|
||||
TraefikMiddlewareReconciler
|
||||
)
|
||||
self.reconcile_order.append(IngressReconciler.reconciler_name())
|
||||
self.reconcile_order.append(HTTPRouteReconciler.reconciler_name())
|
||||
self.reconcile_order.append(TraefikMiddlewareReconciler.reconciler_name())
|
||||
|
@ -20,4 +20,4 @@ def logout_proxy_revoke_direct(sender: type[User], request: HttpRequest, **_):
|
||||
@receiver(pre_delete, sender=AuthenticatedSession)
|
||||
def logout_proxy_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
|
||||
"""Catch logout by expiring sessions being deleted"""
|
||||
proxy_on_logout.delay(instance.session_key)
|
||||
proxy_on_logout.delay(instance.session.session_key)
|
||||
|
60
authentik/providers/rac/migrations/0007_migrate_session.py
Normal file
60
authentik/providers/rac/migrations/0007_migrate_session.py
Normal file
@ -0,0 +1,60 @@
|
||||
# Generated by Django 5.0.11 on 2025-01-27 12:59
|
||||
|
||||
from django.db import migrations
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
def migrate_sessions(apps, schema_editor):
|
||||
ConnectionToken = apps.get_model("authentik_providers_rac", "ConnectionToken")
|
||||
AuthenticatedSession = apps.get_model("authentik_core", "AuthenticatedSession")
|
||||
db_alias = schema_editor.connection.alias
|
||||
|
||||
for token in ConnectionToken.objects.using(db_alias).all():
|
||||
token.session = (
|
||||
AuthenticatedSession.objects.using(db_alias)
|
||||
.filter(session_key=token.old_session.session_key)
|
||||
.first()
|
||||
)
|
||||
if token.session:
|
||||
token.save()
|
||||
else:
|
||||
token.delete()
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_providers_rac", "0006_connectiontoken_authentik_p_expires_91f148_idx_and_more"),
|
||||
("authentik_core", "0046_session_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RenameField(
|
||||
model_name="connectiontoken",
|
||||
old_name="session",
|
||||
new_name="old_session",
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="connectiontoken",
|
||||
name="session",
|
||||
field=models.ForeignKey(
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
to="authentik_core.authenticatedsession",
|
||||
),
|
||||
),
|
||||
migrations.RunPython(code=migrate_sessions),
|
||||
migrations.AlterField(
|
||||
model_name="connectiontoken",
|
||||
name="session",
|
||||
field=models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
to="authentik_core.authenticatedsession",
|
||||
),
|
||||
),
|
||||
migrations.RemoveField(
|
||||
model_name="connectiontoken",
|
||||
name="old_session",
|
||||
),
|
||||
]
|
@ -8,7 +8,7 @@ from django.db.models.signals import post_delete, post_save, pre_delete
|
||||
from django.dispatch import receiver
|
||||
from django.http import HttpRequest
|
||||
|
||||
from authentik.core.models import User
|
||||
from authentik.core.models import AuthenticatedSession, User
|
||||
from authentik.providers.rac.api.endpoints import user_endpoint_cache_key
|
||||
from authentik.providers.rac.consumer_client import (
|
||||
RAC_CLIENT_GROUP_SESSION,
|
||||
@ -32,6 +32,18 @@ def user_logged_out_session(sender, request: HttpRequest, user: User, **_):
|
||||
)
|
||||
|
||||
|
||||
@receiver(pre_delete, sender=AuthenticatedSession)
|
||||
def user_session_deleted(sender, instance: AuthenticatedSession, **_):
|
||||
layer = get_channel_layer()
|
||||
async_to_sync(layer.group_send)(
|
||||
RAC_CLIENT_GROUP_SESSION
|
||||
% {
|
||||
"session": instance.session.session_key,
|
||||
},
|
||||
{"type": "event.disconnect", "reason": "session_logout"},
|
||||
)
|
||||
|
||||
|
||||
@receiver(pre_delete, sender=ConnectionToken)
|
||||
def pre_delete_connection_token_disconnect(sender, instance: ConnectionToken, **_):
|
||||
"""Disconnect session when connection token is deleted"""
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
from django.test import TransactionTestCase
|
||||
|
||||
from authentik.core.models import Application, AuthenticatedSession
|
||||
from authentik.core.models import Application, AuthenticatedSession, Session
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.providers.rac.models import (
|
||||
@ -36,13 +36,15 @@ class TestModels(TransactionTestCase):
|
||||
|
||||
def test_settings_merge(self):
|
||||
"""Test settings merge"""
|
||||
session = Session.objects.create(
|
||||
session_key=generate_id(),
|
||||
last_ip="255.255.255.255",
|
||||
)
|
||||
auth_session = AuthenticatedSession.objects.create(session=session, user=self.user)
|
||||
token = ConnectionToken.objects.create(
|
||||
provider=self.provider,
|
||||
endpoint=self.endpoint,
|
||||
session=AuthenticatedSession.objects.create(
|
||||
user=self.user,
|
||||
session_key=generate_id(),
|
||||
),
|
||||
session=auth_session,
|
||||
)
|
||||
path = f"/tmp/connection/{token.token}" # nosec
|
||||
self.assertEqual(
|
||||
|
@ -1,7 +1,5 @@
|
||||
"""rac urls"""
|
||||
|
||||
from channels.auth import AuthMiddleware
|
||||
from channels.sessions import CookieMiddleware
|
||||
from django.urls import path
|
||||
|
||||
from authentik.outposts.channels import TokenOutpostMiddleware
|
||||
@ -12,7 +10,7 @@ from authentik.providers.rac.api.providers import RACProviderViewSet
|
||||
from authentik.providers.rac.consumer_client import RACClientConsumer
|
||||
from authentik.providers.rac.consumer_outpost import RACOutpostConsumer
|
||||
from authentik.providers.rac.views import RACInterface, RACStartView
|
||||
from authentik.root.asgi_middleware import SessionMiddleware
|
||||
from authentik.root.asgi_middleware import AuthMiddlewareStack
|
||||
from authentik.root.middleware import ChannelsLoggingMiddleware
|
||||
|
||||
urlpatterns = [
|
||||
@ -31,9 +29,7 @@ urlpatterns = [
|
||||
websocket_urlpatterns = [
|
||||
path(
|
||||
"ws/rac/<str:token>/",
|
||||
ChannelsLoggingMiddleware(
|
||||
CookieMiddleware(SessionMiddleware(AuthMiddleware(RACClientConsumer.as_asgi())))
|
||||
),
|
||||
ChannelsLoggingMiddleware(AuthMiddlewareStack(RACClientConsumer.as_asgi())),
|
||||
),
|
||||
path(
|
||||
"ws/outpost_rac/<str:channel>/",
|
||||
|
@ -8,7 +8,7 @@ from django.urls import reverse
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
from authentik.core.models import Application, AuthenticatedSession
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.views.interface import InterfaceView
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.flows.challenge import RedirectChallenge
|
||||
@ -113,9 +113,7 @@ class RACFinalStage(RedirectStage):
|
||||
provider=self.provider,
|
||||
endpoint=self.endpoint,
|
||||
settings=self.executor.plan.context.get("connection_settings", {}),
|
||||
session=AuthenticatedSession.objects.filter(
|
||||
session_key=self.request.session.session_key
|
||||
).first(),
|
||||
session=self.request.session["authenticatedsession"],
|
||||
expires=now() + timedelta_from_string(self.provider.connection_expiry),
|
||||
expiring=True,
|
||||
)
|
||||
|
@ -35,8 +35,8 @@ REQUEST_KEY_SAML_SIG_ALG = "SigAlg"
|
||||
REQUEST_KEY_SAML_RESPONSE = "SAMLResponse"
|
||||
REQUEST_KEY_RELAY_STATE = "RelayState"
|
||||
|
||||
SESSION_KEY_AUTH_N_REQUEST = "authentik/providers/saml/authn_request"
|
||||
SESSION_KEY_LOGOUT_REQUEST = "authentik/providers/saml/logout_request"
|
||||
PLAN_CONTEXT_SAML_AUTH_N_REQUEST = "authentik/providers/saml/authn_request"
|
||||
PLAN_CONTEXT_SAML_LOGOUT_REQUEST = "authentik/providers/saml/logout_request"
|
||||
|
||||
|
||||
# This View doesn't have a URL on purpose, as its called by the FlowExecutor
|
||||
@ -50,10 +50,11 @@ class SAMLFlowFinalView(ChallengeStageView):
|
||||
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
||||
application: Application = self.executor.plan.context[PLAN_CONTEXT_APPLICATION]
|
||||
provider: SAMLProvider = get_object_or_404(SAMLProvider, pk=application.provider_id)
|
||||
if SESSION_KEY_AUTH_N_REQUEST not in self.request.session:
|
||||
if PLAN_CONTEXT_SAML_AUTH_N_REQUEST not in self.executor.plan.context:
|
||||
self.logger.warning("No AuthNRequest in context")
|
||||
return self.executor.stage_invalid()
|
||||
|
||||
auth_n_request: AuthNRequest = self.request.session.pop(SESSION_KEY_AUTH_N_REQUEST)
|
||||
auth_n_request: AuthNRequest = self.executor.plan.context[PLAN_CONTEXT_SAML_AUTH_N_REQUEST]
|
||||
try:
|
||||
response = AssertionProcessor(provider, request, auth_n_request).build_response()
|
||||
except SAMLException as exc:
|
||||
@ -106,6 +107,3 @@ class SAMLFlowFinalView(ChallengeStageView):
|
||||
def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:
|
||||
# We'll never get here since the challenge redirects to the SP
|
||||
return HttpResponseBadRequest()
|
||||
|
||||
def cleanup(self):
|
||||
self.request.session.pop(SESSION_KEY_AUTH_N_REQUEST, None)
|
||||
|
@ -19,9 +19,9 @@ from authentik.providers.saml.exceptions import CannotHandleAssertion
|
||||
from authentik.providers.saml.models import SAMLProvider
|
||||
from authentik.providers.saml.processors.logout_request_parser import LogoutRequestParser
|
||||
from authentik.providers.saml.views.flows import (
|
||||
PLAN_CONTEXT_SAML_LOGOUT_REQUEST,
|
||||
REQUEST_KEY_RELAY_STATE,
|
||||
REQUEST_KEY_SAML_REQUEST,
|
||||
SESSION_KEY_LOGOUT_REQUEST,
|
||||
)
|
||||
|
||||
LOGGER = get_logger()
|
||||
@ -33,6 +33,10 @@ class SAMLSLOView(PolicyAccessView):
|
||||
|
||||
flow: Flow
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.plan_context = {}
|
||||
|
||||
def resolve_provider_application(self):
|
||||
self.application = get_object_or_404(Application, slug=self.kwargs["application_slug"])
|
||||
self.provider: SAMLProvider = get_object_or_404(
|
||||
@ -59,6 +63,7 @@ class SAMLSLOView(PolicyAccessView):
|
||||
request,
|
||||
{
|
||||
PLAN_CONTEXT_APPLICATION: self.application,
|
||||
**self.plan_context,
|
||||
},
|
||||
)
|
||||
plan.append_stage(in_memory_stage(SessionEndStage))
|
||||
@ -83,7 +88,7 @@ class SAMLSLOBindingRedirectView(SAMLSLOView):
|
||||
self.request.GET[REQUEST_KEY_SAML_REQUEST],
|
||||
relay_state=self.request.GET.get(REQUEST_KEY_RELAY_STATE, None),
|
||||
)
|
||||
self.request.session[SESSION_KEY_LOGOUT_REQUEST] = logout_request
|
||||
self.plan_context[PLAN_CONTEXT_SAML_LOGOUT_REQUEST] = logout_request
|
||||
except CannotHandleAssertion as exc:
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
@ -111,7 +116,7 @@ class SAMLSLOBindingPOSTView(SAMLSLOView):
|
||||
payload[REQUEST_KEY_SAML_REQUEST],
|
||||
relay_state=payload.get(REQUEST_KEY_RELAY_STATE, None),
|
||||
)
|
||||
self.request.session[SESSION_KEY_LOGOUT_REQUEST] = logout_request
|
||||
self.plan_context[PLAN_CONTEXT_SAML_LOGOUT_REQUEST] = logout_request
|
||||
except CannotHandleAssertion as exc:
|
||||
LOGGER.info(str(exc))
|
||||
return bad_request_message(self.request, str(exc))
|
||||
|
@ -20,11 +20,11 @@ from authentik.providers.saml.exceptions import CannotHandleAssertion
|
||||
from authentik.providers.saml.models import SAMLBindings, SAMLProvider
|
||||
from authentik.providers.saml.processors.authn_request_parser import AuthNRequestParser
|
||||
from authentik.providers.saml.views.flows import (
|
||||
PLAN_CONTEXT_SAML_AUTH_N_REQUEST,
|
||||
REQUEST_KEY_RELAY_STATE,
|
||||
REQUEST_KEY_SAML_REQUEST,
|
||||
REQUEST_KEY_SAML_SIG_ALG,
|
||||
REQUEST_KEY_SAML_SIGNATURE,
|
||||
SESSION_KEY_AUTH_N_REQUEST,
|
||||
SAMLFlowFinalView,
|
||||
)
|
||||
from authentik.stages.consent.stage import (
|
||||
@ -39,6 +39,10 @@ class SAMLSSOView(BufferedPolicyAccessView):
|
||||
"""SAML SSO Base View, which plans a flow and injects our final stage.
|
||||
Calls get/post handler."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.plan_context = {}
|
||||
|
||||
def resolve_provider_application(self):
|
||||
self.application = get_object_or_404(Application, slug=self.kwargs["application_slug"])
|
||||
self.provider: SAMLProvider = get_object_or_404(
|
||||
@ -68,6 +72,7 @@ class SAMLSSOView(BufferedPolicyAccessView):
|
||||
PLAN_CONTEXT_CONSENT_HEADER: _("You're about to sign into %(application)s.")
|
||||
% {"application": self.application.name},
|
||||
PLAN_CONTEXT_CONSENT_PERMISSIONS: [],
|
||||
**self.plan_context,
|
||||
},
|
||||
)
|
||||
except FlowNonApplicableException:
|
||||
@ -103,7 +108,7 @@ class SAMLSSOBindingRedirectView(SAMLSSOView):
|
||||
self.request.GET.get(REQUEST_KEY_SAML_SIGNATURE),
|
||||
self.request.GET.get(REQUEST_KEY_SAML_SIG_ALG),
|
||||
)
|
||||
self.request.session[SESSION_KEY_AUTH_N_REQUEST] = auth_n_request
|
||||
self.plan_context[PLAN_CONTEXT_SAML_AUTH_N_REQUEST] = auth_n_request
|
||||
except CannotHandleAssertion as exc:
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
@ -137,7 +142,7 @@ class SAMLSSOBindingPOSTView(SAMLSSOView):
|
||||
payload[REQUEST_KEY_SAML_REQUEST],
|
||||
payload.get(REQUEST_KEY_RELAY_STATE),
|
||||
)
|
||||
self.request.session[SESSION_KEY_AUTH_N_REQUEST] = auth_n_request
|
||||
self.plan_context[PLAN_CONTEXT_SAML_AUTH_N_REQUEST] = auth_n_request
|
||||
except CannotHandleAssertion as exc:
|
||||
LOGGER.info(str(exc))
|
||||
return bad_request_message(self.request, str(exc))
|
||||
@ -151,4 +156,4 @@ class SAMLSSOBindingInitView(SAMLSSOView):
|
||||
"""Create SAML Response from scratch"""
|
||||
LOGGER.debug("No SAML Request, using IdP-initiated flow.")
|
||||
auth_n_request = AuthNRequestParser(self.provider).idp_initiated()
|
||||
self.request.session[SESSION_KEY_AUTH_N_REQUEST] = auth_n_request
|
||||
self.plan_context[PLAN_CONTEXT_SAML_AUTH_N_REQUEST] = auth_n_request
|
||||
|
41
authentik/rbac/api/initial_permissions.py
Normal file
41
authentik/rbac/api/initial_permissions.py
Normal file
@ -0,0 +1,41 @@
|
||||
"""RBAC Initial Permissions"""
|
||||
|
||||
from rest_framework.serializers import ListSerializer
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import ModelSerializer
|
||||
from authentik.rbac.api.rbac import PermissionSerializer
|
||||
from authentik.rbac.models import InitialPermissions
|
||||
|
||||
|
||||
class InitialPermissionsSerializer(ModelSerializer):
|
||||
"""InitialPermissions serializer"""
|
||||
|
||||
permissions_obj = ListSerializer(
|
||||
child=PermissionSerializer(),
|
||||
read_only=True,
|
||||
source="permissions",
|
||||
required=False,
|
||||
)
|
||||
|
||||
class Meta:
|
||||
model = InitialPermissions
|
||||
fields = [
|
||||
"pk",
|
||||
"name",
|
||||
"mode",
|
||||
"role",
|
||||
"permissions",
|
||||
"permissions_obj",
|
||||
]
|
||||
|
||||
|
||||
class InitialPermissionsViewSet(UsedByMixin, ModelViewSet):
|
||||
"""InitialPermissions viewset"""
|
||||
|
||||
queryset = InitialPermissions.objects.all()
|
||||
serializer_class = InitialPermissionsSerializer
|
||||
search_fields = ["name"]
|
||||
ordering = ["name"]
|
||||
filterset_fields = ["name"]
|
@ -99,6 +99,7 @@ class RBACPermissionViewSet(ReadOnlyModelViewSet):
|
||||
filterset_class = PermissionFilter
|
||||
permission_classes = [IsAuthenticated]
|
||||
search_fields = [
|
||||
"name",
|
||||
"codename",
|
||||
"content_type__model",
|
||||
"content_type__app_label",
|
||||
|
39
authentik/rbac/migrations/0005_initialpermissions.py
Normal file
39
authentik/rbac/migrations/0005_initialpermissions.py
Normal file
@ -0,0 +1,39 @@
|
||||
# Generated by Django 5.0.13 on 2025-04-07 13:05
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("auth", "0012_alter_user_first_name_max_length"),
|
||||
("authentik_rbac", "0004_alter_systempermission_options"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="InitialPermissions",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.AutoField(
|
||||
auto_created=True, primary_key=True, serialize=False, verbose_name="ID"
|
||||
),
|
||||
),
|
||||
("name", models.TextField(max_length=150, unique=True)),
|
||||
("mode", models.CharField(choices=[("user", "User"), ("role", "Role")])),
|
||||
("permissions", models.ManyToManyField(blank=True, to="auth.permission")),
|
||||
(
|
||||
"role",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to="authentik_rbac.role"
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"verbose_name": "Initial Permissions",
|
||||
"verbose_name_plural": "Initial Permissions",
|
||||
},
|
||||
),
|
||||
]
|
@ -3,6 +3,7 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from django.contrib.auth.management import _get_all_permissions
|
||||
from django.contrib.auth.models import Permission
|
||||
from django.db import models
|
||||
from django.db.transaction import atomic
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
@ -75,6 +76,35 @@ class Role(SerializerModel):
|
||||
]
|
||||
|
||||
|
||||
class InitialPermissionsMode(models.TextChoices):
|
||||
"""Determines which entity the initial permissions are assigned to."""
|
||||
|
||||
USER = "user", _("User")
|
||||
ROLE = "role", _("Role")
|
||||
|
||||
|
||||
class InitialPermissions(SerializerModel):
|
||||
"""Assigns permissions for newly created objects."""
|
||||
|
||||
name = models.TextField(max_length=150, unique=True)
|
||||
mode = models.CharField(choices=InitialPermissionsMode.choices)
|
||||
role = models.ForeignKey(Role, on_delete=models.CASCADE)
|
||||
permissions = models.ManyToManyField(Permission, blank=True)
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[BaseSerializer]:
|
||||
from authentik.rbac.api.initial_permissions import InitialPermissionsSerializer
|
||||
|
||||
return InitialPermissionsSerializer
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Initial Permissions for Role #{self.role_id}, applying to #{self.mode}"
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("Initial Permissions")
|
||||
verbose_name_plural = _("Initial Permissions")
|
||||
|
||||
|
||||
class SystemPermission(models.Model):
|
||||
"""System-wide permissions that are not related to any direct
|
||||
database model"""
|
||||
|
@ -1,9 +1,13 @@
|
||||
"""RBAC Permissions"""
|
||||
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.db.models import Model
|
||||
from guardian.shortcuts import assign_perm
|
||||
from rest_framework.permissions import BasePermission, DjangoObjectPermissions
|
||||
from rest_framework.request import Request
|
||||
|
||||
from authentik.rbac.models import InitialPermissions, InitialPermissionsMode
|
||||
|
||||
|
||||
class ObjectPermissions(DjangoObjectPermissions):
|
||||
"""RBAC Permissions"""
|
||||
@ -51,3 +55,20 @@ def HasPermission(*perm: str) -> type[BasePermission]:
|
||||
return bool(request.user and request.user.has_perms(perm))
|
||||
|
||||
return checker
|
||||
|
||||
|
||||
# TODO: add `user: User` type annotation without circular dependencies.
|
||||
# The author of this function isn't proficient/patient enough to do it.
|
||||
def assign_initial_permissions(user, instance: Model):
|
||||
# Performance here should not be an issue, but if needed, there are many optimization routes
|
||||
initial_permissions_list = InitialPermissions.objects.filter(role__group__in=user.groups.all())
|
||||
for initial_permissions in initial_permissions_list:
|
||||
for permission in initial_permissions.permissions.all():
|
||||
if permission.content_type != ContentType.objects.get_for_model(instance):
|
||||
continue
|
||||
assign_to = (
|
||||
user
|
||||
if initial_permissions.mode == InitialPermissionsMode.USER
|
||||
else initial_permissions.role.group
|
||||
)
|
||||
assign_perm(permission, assign_to, instance)
|
||||
|
116
authentik/rbac/tests/test_initial_permissions.py
Normal file
116
authentik/rbac/tests/test_initial_permissions.py
Normal file
@ -0,0 +1,116 @@
|
||||
"""Test InitialPermissions"""
|
||||
|
||||
from django.contrib.auth.models import Permission
|
||||
from guardian.shortcuts import assign_perm
|
||||
from rest_framework.reverse import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.models import Group
|
||||
from authentik.core.tests.utils import create_test_user
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.rbac.models import InitialPermissions, InitialPermissionsMode, Role
|
||||
from authentik.stages.dummy.models import DummyStage
|
||||
|
||||
|
||||
class TestInitialPermissions(APITestCase):
|
||||
"""Test InitialPermissions"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.user = create_test_user()
|
||||
self.same_role_user = create_test_user()
|
||||
self.different_role_user = create_test_user()
|
||||
|
||||
self.role = Role.objects.create(name=generate_id())
|
||||
self.different_role = Role.objects.create(name=generate_id())
|
||||
|
||||
self.group = Group.objects.create(name=generate_id())
|
||||
self.different_group = Group.objects.create(name=generate_id())
|
||||
|
||||
self.group.roles.add(self.role)
|
||||
self.group.users.add(self.user, self.same_role_user)
|
||||
self.different_group.roles.add(self.different_role)
|
||||
self.different_group.users.add(self.different_role_user)
|
||||
|
||||
self.ip = InitialPermissions.objects.create(
|
||||
name=generate_id(), mode=InitialPermissionsMode.USER, role=self.role
|
||||
)
|
||||
self.view_role = Permission.objects.filter(codename="view_role").first()
|
||||
self.ip.permissions.add(self.view_role)
|
||||
|
||||
assign_perm("authentik_rbac.add_role", self.user)
|
||||
self.client.force_login(self.user)
|
||||
|
||||
def test_different_role(self):
|
||||
"""InitialPermissions for different role does nothing"""
|
||||
self.ip.role = self.different_role
|
||||
self.ip.save()
|
||||
|
||||
self.client.post(reverse("authentik_api:roles-list"), {"name": "test-role"})
|
||||
|
||||
role = Role.objects.filter(name="test-role").first()
|
||||
self.assertFalse(self.user.has_perm("authentik_rbac.view_role", role))
|
||||
|
||||
def test_different_model(self):
|
||||
"""InitialPermissions for different model does nothing"""
|
||||
assign_perm("authentik_stages_dummy.add_dummystage", self.user)
|
||||
|
||||
self.client.post(
|
||||
reverse("authentik_api:stages-dummy-list"), {"name": "test-stage", "throw-error": False}
|
||||
)
|
||||
|
||||
role = Role.objects.filter(name="test-role").first()
|
||||
self.assertFalse(self.user.has_perm("authentik_rbac.view_role", role))
|
||||
stage = DummyStage.objects.filter(name="test-stage").first()
|
||||
self.assertFalse(self.user.has_perm("authentik_stages_dummy.view_dummystage", stage))
|
||||
|
||||
def test_mode_user(self):
|
||||
"""InitialPermissions adds user permission in user mode"""
|
||||
self.client.post(reverse("authentik_api:roles-list"), {"name": "test-role"})
|
||||
|
||||
role = Role.objects.filter(name="test-role").first()
|
||||
self.assertTrue(self.user.has_perm("authentik_rbac.view_role", role))
|
||||
self.assertFalse(self.same_role_user.has_perm("authentik_rbac.view_role", role))
|
||||
|
||||
def test_mode_role(self):
|
||||
"""InitialPermissions adds role permission in role mode"""
|
||||
self.ip.mode = InitialPermissionsMode.ROLE
|
||||
self.ip.save()
|
||||
|
||||
self.client.post(reverse("authentik_api:roles-list"), {"name": "test-role"})
|
||||
|
||||
role = Role.objects.filter(name="test-role").first()
|
||||
self.assertTrue(self.user.has_perm("authentik_rbac.view_role", role))
|
||||
self.assertTrue(self.same_role_user.has_perm("authentik_rbac.view_role", role))
|
||||
|
||||
def test_many_permissions(self):
|
||||
"""InitialPermissions can add multiple permissions"""
|
||||
change_role = Permission.objects.filter(codename="change_role").first()
|
||||
self.ip.permissions.add(change_role)
|
||||
|
||||
self.client.post(reverse("authentik_api:roles-list"), {"name": "test-role"})
|
||||
|
||||
role = Role.objects.filter(name="test-role").first()
|
||||
self.assertTrue(self.user.has_perm("authentik_rbac.view_role", role))
|
||||
self.assertTrue(self.user.has_perm("authentik_rbac.change_role", role))
|
||||
|
||||
def test_permissions_separated_by_role(self):
|
||||
"""When the triggering user is part of two different roles with InitialPermissions in role
|
||||
mode, it only adds permissions to the relevant role."""
|
||||
self.ip.mode = InitialPermissionsMode.ROLE
|
||||
self.ip.save()
|
||||
different_ip = InitialPermissions.objects.create(
|
||||
name=generate_id(), mode=InitialPermissionsMode.ROLE, role=self.different_role
|
||||
)
|
||||
change_role = Permission.objects.filter(codename="change_role").first()
|
||||
different_ip.permissions.add(change_role)
|
||||
self.different_group.users.add(self.user)
|
||||
|
||||
self.client.post(reverse("authentik_api:roles-list"), {"name": "test-role"})
|
||||
|
||||
role = Role.objects.filter(name="test-role").first()
|
||||
self.assertTrue(self.user.has_perm("authentik_rbac.view_role", role))
|
||||
self.assertTrue(self.same_role_user.has_perm("authentik_rbac.view_role", role))
|
||||
self.assertFalse(self.different_role_user.has_perm("authentik_rbac.view_role", role))
|
||||
self.assertTrue(self.user.has_perm("authentik_rbac.change_role", role))
|
||||
self.assertFalse(self.same_role_user.has_perm("authentik_rbac.change_role", role))
|
||||
self.assertTrue(self.different_role_user.has_perm("authentik_rbac.change_role", role))
|
@ -1,5 +1,6 @@
|
||||
"""RBAC API urls"""
|
||||
|
||||
from authentik.rbac.api.initial_permissions import InitialPermissionsViewSet
|
||||
from authentik.rbac.api.rbac import RBACPermissionViewSet
|
||||
from authentik.rbac.api.rbac_assigned_by_roles import RoleAssignedPermissionViewSet
|
||||
from authentik.rbac.api.rbac_assigned_by_users import UserAssignedPermissionViewSet
|
||||
@ -21,5 +22,6 @@ api_urlpatterns = [
|
||||
("rbac/permissions/users", UserPermissionViewSet, "permissions-users"),
|
||||
("rbac/permissions/roles", RolePermissionViewSet, "permissions-roles"),
|
||||
("rbac/permissions", RBACPermissionViewSet),
|
||||
("rbac/roles", RoleViewSet),
|
||||
("rbac/roles", RoleViewSet, "roles"),
|
||||
("rbac/initial_permissions", InitialPermissionsViewSet, "initial-permissions"),
|
||||
]
|
||||
|
@ -50,7 +50,7 @@ class TestRecovery(TestCase):
|
||||
)
|
||||
token = Token.objects.get(intent=TokenIntents.INTENT_RECOVERY, user=self.user)
|
||||
self.client.get(reverse("authentik_recovery:use-token", kwargs={"key": token.key}))
|
||||
self.assertEqual(int(self.client.session["_auth_user_id"]), token.user.pk)
|
||||
self.assertEqual(self.client.session["authenticatedsession"].user.pk, token.user.pk)
|
||||
|
||||
def test_recovery_view_invalid(self):
|
||||
"""Test recovery view with invalid token"""
|
||||
|
@ -1,8 +1,12 @@
|
||||
"""ASGI middleware"""
|
||||
|
||||
from channels.auth import UserLazyObject
|
||||
from channels.db import database_sync_to_async
|
||||
from channels.middleware import BaseMiddleware
|
||||
from channels.sessions import CookieMiddleware
|
||||
from channels.sessions import InstanceSessionWrapper as UpstreamInstanceSessionWrapper
|
||||
from channels.sessions import SessionMiddleware as UpstreamSessionMiddleware
|
||||
from django.contrib.auth.models import AnonymousUser
|
||||
|
||||
from authentik.root.middleware import SessionMiddleware as HTTPSessionMiddleware
|
||||
|
||||
@ -33,3 +37,48 @@ class SessionMiddleware(UpstreamSessionMiddleware):
|
||||
await wrapper.resolve_session()
|
||||
|
||||
return await self.inner(wrapper.scope, receive, wrapper.send)
|
||||
|
||||
|
||||
@database_sync_to_async
|
||||
def get_user(scope):
|
||||
"""
|
||||
Return the user model instance associated with the given scope.
|
||||
If no user is retrieved, return an instance of `AnonymousUser`.
|
||||
"""
|
||||
if "session" not in scope:
|
||||
raise ValueError(
|
||||
"Cannot find session in scope. You should wrap your consumer in SessionMiddleware."
|
||||
)
|
||||
user = None
|
||||
if (authenticated_session := scope["session"].get("authenticated_session", None)) is not None:
|
||||
user = authenticated_session.user
|
||||
return user or AnonymousUser()
|
||||
|
||||
|
||||
class AuthMiddleware(BaseMiddleware):
|
||||
def populate_scope(self, scope):
|
||||
# Make sure we have a session
|
||||
if "session" not in scope:
|
||||
raise ValueError(
|
||||
"AuthMiddleware cannot find session in scope. SessionMiddleware must be above it."
|
||||
)
|
||||
# Add it to the scope if it's not there already
|
||||
if "user" not in scope:
|
||||
scope["user"] = UserLazyObject()
|
||||
|
||||
async def resolve_scope(self, scope):
|
||||
scope["user"]._wrapped = await get_user(scope)
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
scope = dict(scope)
|
||||
# Scope injection/mutation per this middleware's needs.
|
||||
self.populate_scope(scope)
|
||||
# Grab the finalized/resolved scope
|
||||
await self.resolve_scope(scope)
|
||||
|
||||
return await super().__call__(scope, receive, send)
|
||||
|
||||
|
||||
# Handy shortcut for applying all three layers at once
|
||||
def AuthMiddlewareStack(inner):
|
||||
return CookieMiddleware(SessionMiddleware(AuthMiddleware(inner)))
|
||||
|
@ -49,7 +49,7 @@ class SessionMiddleware(UpstreamSessionMiddleware):
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def decode_session_key(key: str) -> str:
|
||||
def decode_session_key(key: str | None) -> str | None:
|
||||
"""Decode raw session cookie, and parse JWT"""
|
||||
# We need to support the standard django format of just a session key
|
||||
# for testing setups, where the session is directly set
|
||||
@ -64,7 +64,11 @@ class SessionMiddleware(UpstreamSessionMiddleware):
|
||||
def process_request(self, request: HttpRequest):
|
||||
raw_session = request.COOKIES.get(settings.SESSION_COOKIE_NAME)
|
||||
session_key = SessionMiddleware.decode_session_key(raw_session)
|
||||
request.session = self.SessionStore(session_key)
|
||||
request.session = self.SessionStore(
|
||||
session_key,
|
||||
last_ip=ClientIPMiddleware.get_client_ip(request),
|
||||
last_user_agent=request.META.get("HTTP_USER_AGENT", ""),
|
||||
)
|
||||
|
||||
def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
|
||||
"""
|
||||
|
@ -1,23 +0,0 @@
|
||||
"""
|
||||
Module for abstract serializer/unserializer base classes.
|
||||
"""
|
||||
|
||||
import pickle # nosec
|
||||
|
||||
|
||||
class PickleSerializer:
|
||||
"""
|
||||
Simple wrapper around pickle to be used in signing.dumps()/loads() and
|
||||
cache backends.
|
||||
"""
|
||||
|
||||
def __init__(self, protocol=None):
|
||||
self.protocol = pickle.HIGHEST_PROTOCOL if protocol is None else protocol
|
||||
|
||||
def dumps(self, obj):
|
||||
"""Pickle data to be stored in redis"""
|
||||
return pickle.dumps(obj, self.protocol)
|
||||
|
||||
def loads(self, data):
|
||||
"""Unpickle data to be loaded from redis"""
|
||||
return pickle.loads(data) # nosec
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user