Compare commits
3 Commits
version-20
...
version/20
Author | SHA1 | Date | |
---|---|---|---|
63cfbb721c | |||
2b74a1f03b | |||
093573f89a |
@ -1,5 +1,5 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 2023.3.1
|
current_version = 2023.2.3
|
||||||
tag = True
|
tag = True
|
||||||
commit = True
|
commit = True
|
||||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
||||||
|
@ -7,11 +7,8 @@ charset = utf-8
|
|||||||
trim_trailing_whitespace = true
|
trim_trailing_whitespace = true
|
||||||
insert_final_newline = true
|
insert_final_newline = true
|
||||||
|
|
||||||
[*.html]
|
[html]
|
||||||
indent_size = 2
|
indent_size = 2
|
||||||
|
|
||||||
[*.{yaml,yml}]
|
[yaml]
|
||||||
indent_size = 2
|
indent_size = 2
|
||||||
|
|
||||||
[*.go]
|
|
||||||
indent_style = tab
|
|
||||||
|
2
.github/workflows/ghcr-retention.yml
vendored
2
.github/workflows/ghcr-retention.yml
vendored
@ -11,7 +11,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Delete 'dev' containers older than a week
|
- name: Delete 'dev' containers older than a week
|
||||||
uses: snok/container-retention-policy@v2
|
uses: snok/container-retention-policy@v1
|
||||||
with:
|
with:
|
||||||
image-names: dev-server,dev-ldap,dev-proxy
|
image-names: dev-server,dev-ldap,dev-proxy
|
||||||
cut-off: One week ago UTC
|
cut-off: One week ago UTC
|
||||||
|
3
.gitignore
vendored
3
.gitignore
vendored
@ -200,6 +200,3 @@ media/
|
|||||||
.idea/
|
.idea/
|
||||||
/gen-*/
|
/gen-*/
|
||||||
data/
|
data/
|
||||||
|
|
||||||
# Local Netlify folder
|
|
||||||
.netlify
|
|
||||||
|
20
.vscode/extensions.json
vendored
20
.vscode/extensions.json
vendored
@ -1,20 +0,0 @@
|
|||||||
{
|
|
||||||
"recommendations": [
|
|
||||||
"EditorConfig.EditorConfig",
|
|
||||||
"bashmish.es6-string-css",
|
|
||||||
"bpruitt-goddard.mermaid-markdown-syntax-highlighting",
|
|
||||||
"dbaeumer.vscode-eslint",
|
|
||||||
"esbenp.prettier-vscode",
|
|
||||||
"golang.go",
|
|
||||||
"Gruntfuggly.todo-tree",
|
|
||||||
"mechatroner.rainbow-csv",
|
|
||||||
"ms-python.black-formatter",
|
|
||||||
"ms-python.isort",
|
|
||||||
"ms-python.pylint",
|
|
||||||
"ms-python.python",
|
|
||||||
"ms-python.vscode-pylance",
|
|
||||||
"redhat.vscode-yaml",
|
|
||||||
"Tobermory.es6-string-html",
|
|
||||||
"unifiedjs.vscode-mdx"
|
|
||||||
]
|
|
||||||
}
|
|
3
.vscode/settings.json
vendored
3
.vscode/settings.json
vendored
@ -16,8 +16,7 @@
|
|||||||
"passwordless",
|
"passwordless",
|
||||||
"kubernetes",
|
"kubernetes",
|
||||||
"sso",
|
"sso",
|
||||||
"slo",
|
"slo"
|
||||||
"scim",
|
|
||||||
],
|
],
|
||||||
"python.linting.pylintEnabled": true,
|
"python.linting.pylintEnabled": true,
|
||||||
"todo-tree.tree.showCountsInTree": true,
|
"todo-tree.tree.showCountsInTree": true,
|
||||||
|
@ -154,19 +154,12 @@ While the prerequisites above must be satisfied prior to having your pull reques
|
|||||||
|
|
||||||
## Styleguides
|
## Styleguides
|
||||||
|
|
||||||
### PR naming
|
|
||||||
|
|
||||||
- Use the format of `<package>: <verb> <description>`
|
|
||||||
- See [here](#authentik-packages) for `package`
|
|
||||||
- Example: `providers/saml2: fix parsing of requests`
|
|
||||||
|
|
||||||
### Git Commit Messages
|
### Git Commit Messages
|
||||||
|
|
||||||
- Use the format of `<package>: <verb> <description>`
|
- Use the format of `<package>: <verb> <description>`
|
||||||
- See [here](#authentik-packages) for `package`
|
- See [here](#authentik-packages) for `package`
|
||||||
- Example: `providers/saml2: fix parsing of requests`
|
- Example: `providers/saml2: fix parsing of requests`
|
||||||
- Reference issues and pull requests liberally after the first line
|
- Reference issues and pull requests liberally after the first line
|
||||||
- Naming of commits within a PR does not need to adhere to the guidelines as we squash merge PRs
|
|
||||||
|
|
||||||
### Python Styleguide
|
### Python Styleguide
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ RUN pip install --no-cache-dir poetry && \
|
|||||||
poetry export -f requirements.txt --dev --output requirements-dev.txt
|
poetry export -f requirements.txt --dev --output requirements-dev.txt
|
||||||
|
|
||||||
# Stage 4: Build go proxy
|
# Stage 4: Build go proxy
|
||||||
FROM docker.io/golang:1.20.2-bullseye AS go-builder
|
FROM docker.io/golang:1.20.1-bullseye AS go-builder
|
||||||
|
|
||||||
WORKDIR /work
|
WORKDIR /work
|
||||||
|
|
||||||
@ -96,7 +96,7 @@ RUN apt-get update && \
|
|||||||
|
|
||||||
COPY ./authentik/ /authentik
|
COPY ./authentik/ /authentik
|
||||||
COPY ./pyproject.toml /
|
COPY ./pyproject.toml /
|
||||||
COPY ./schemas /schemas
|
COPY ./xml /xml
|
||||||
COPY ./locale /locale
|
COPY ./locale /locale
|
||||||
COPY ./tests /tests
|
COPY ./tests /tests
|
||||||
COPY ./manage.py /
|
COPY ./manage.py /
|
||||||
|
@ -6,8 +6,8 @@ Authentik takes security very seriously. We follow the rules of [responsible dis
|
|||||||
|
|
||||||
| Version | Supported |
|
| Version | Supported |
|
||||||
| --------- | ------------------ |
|
| --------- | ------------------ |
|
||||||
| 2023.2.x | :white_check_mark: |
|
| 2022.12.x | :white_check_mark: |
|
||||||
| 2023.3.x | :white_check_mark: |
|
| 2023.1.x | :white_check_mark: |
|
||||||
|
|
||||||
## Reporting a Vulnerability
|
## Reporting a Vulnerability
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
from os import environ
|
from os import environ
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
__version__ = "2023.3.1"
|
__version__ = "2023.2.3"
|
||||||
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
|
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
|
||||||
|
|
||||||
|
|
||||||
|
@ -9,7 +9,6 @@ from authentik.blueprints.tests import reconcile_app
|
|||||||
from authentik.core.models import Group, User
|
from authentik.core.models import Group, User
|
||||||
from authentik.core.tasks import clean_expired_models
|
from authentik.core.tasks import clean_expired_models
|
||||||
from authentik.events.monitored_tasks import TaskResultStatus
|
from authentik.events.monitored_tasks import TaskResultStatus
|
||||||
from authentik.lib.generators import generate_id
|
|
||||||
|
|
||||||
|
|
||||||
class TestAdminAPI(TestCase):
|
class TestAdminAPI(TestCase):
|
||||||
@ -17,8 +16,8 @@ class TestAdminAPI(TestCase):
|
|||||||
|
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.user = User.objects.create(username=generate_id())
|
self.user = User.objects.create(username="test-user")
|
||||||
self.group = Group.objects.create(name=generate_id(), is_superuser=True)
|
self.group = Group.objects.create(name="superusers", is_superuser=True)
|
||||||
self.group.users.add(self.user)
|
self.group.users.add(self.user)
|
||||||
self.group.save()
|
self.group.save()
|
||||||
self.client.force_login(self.user)
|
self.client.force_login(self.user)
|
||||||
|
@ -4,7 +4,6 @@ from base64 import b64encode
|
|||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from django.utils import timezone
|
|
||||||
from rest_framework.exceptions import AuthenticationFailed
|
from rest_framework.exceptions import AuthenticationFailed
|
||||||
|
|
||||||
from authentik.api.authentication import bearer_auth
|
from authentik.api.authentication import bearer_auth
|
||||||
@ -69,7 +68,6 @@ class TestAPIAuth(TestCase):
|
|||||||
user=create_test_admin_user(),
|
user=create_test_admin_user(),
|
||||||
provider=provider,
|
provider=provider,
|
||||||
token=generate_id(),
|
token=generate_id(),
|
||||||
auth_time=timezone.now(),
|
|
||||||
_scope=SCOPE_AUTHENTIK_API,
|
_scope=SCOPE_AUTHENTIK_API,
|
||||||
_id_token=json.dumps({}),
|
_id_token=json.dumps({}),
|
||||||
)
|
)
|
||||||
@ -84,7 +82,6 @@ class TestAPIAuth(TestCase):
|
|||||||
user=create_test_admin_user(),
|
user=create_test_admin_user(),
|
||||||
provider=provider,
|
provider=provider,
|
||||||
token=generate_id(),
|
token=generate_id(),
|
||||||
auth_time=timezone.now(),
|
|
||||||
_scope="",
|
_scope="",
|
||||||
_id_token=json.dumps({}),
|
_id_token=json.dumps({}),
|
||||||
)
|
)
|
||||||
|
@ -4,7 +4,6 @@ from guardian.shortcuts import assign_perm
|
|||||||
from rest_framework.test import APITestCase
|
from rest_framework.test import APITestCase
|
||||||
|
|
||||||
from authentik.core.models import Application, User
|
from authentik.core.models import Application, User
|
||||||
from authentik.lib.generators import generate_id
|
|
||||||
|
|
||||||
|
|
||||||
class TestAPIDecorators(APITestCase):
|
class TestAPIDecorators(APITestCase):
|
||||||
@ -17,7 +16,7 @@ class TestAPIDecorators(APITestCase):
|
|||||||
def test_obj_perm_denied(self):
|
def test_obj_perm_denied(self):
|
||||||
"""Test object perm denied"""
|
"""Test object perm denied"""
|
||||||
self.client.force_login(self.user)
|
self.client.force_login(self.user)
|
||||||
app = Application.objects.create(name=generate_id(), slug=generate_id())
|
app = Application.objects.create(name="denied", slug="denied")
|
||||||
response = self.client.get(
|
response = self.client.get(
|
||||||
reverse("authentik_api:application-metrics", kwargs={"slug": app.slug})
|
reverse("authentik_api:application-metrics", kwargs={"slug": app.slug})
|
||||||
)
|
)
|
||||||
@ -26,7 +25,7 @@ class TestAPIDecorators(APITestCase):
|
|||||||
def test_other_perm_denied(self):
|
def test_other_perm_denied(self):
|
||||||
"""Test other perm denied"""
|
"""Test other perm denied"""
|
||||||
self.client.force_login(self.user)
|
self.client.force_login(self.user)
|
||||||
app = Application.objects.create(name=generate_id(), slug=generate_id())
|
app = Application.objects.create(name="denied", slug="denied")
|
||||||
assign_perm("authentik_core.view_application", self.user, app)
|
assign_perm("authentik_core.view_application", self.user, app)
|
||||||
response = self.client.get(
|
response = self.client.get(
|
||||||
reverse("authentik_api:application-metrics", kwargs={"slug": app.slug})
|
reverse("authentik_api:application-metrics", kwargs={"slug": app.slug})
|
||||||
|
@ -58,8 +58,6 @@ from authentik.providers.oauth2.api.tokens import (
|
|||||||
from authentik.providers.proxy.api import ProxyOutpostConfigViewSet, ProxyProviderViewSet
|
from authentik.providers.proxy.api import ProxyOutpostConfigViewSet, ProxyProviderViewSet
|
||||||
from authentik.providers.saml.api.property_mapping import SAMLPropertyMappingViewSet
|
from authentik.providers.saml.api.property_mapping import SAMLPropertyMappingViewSet
|
||||||
from authentik.providers.saml.api.providers import SAMLProviderViewSet
|
from authentik.providers.saml.api.providers import SAMLProviderViewSet
|
||||||
from authentik.providers.scim.api.property_mapping import SCIMMappingViewSet
|
|
||||||
from authentik.providers.scim.api.providers import SCIMProviderViewSet
|
|
||||||
from authentik.sources.ldap.api import LDAPPropertyMappingViewSet, LDAPSourceViewSet
|
from authentik.sources.ldap.api import LDAPPropertyMappingViewSet, LDAPSourceViewSet
|
||||||
from authentik.sources.oauth.api.source import OAuthSourceViewSet
|
from authentik.sources.oauth.api.source import OAuthSourceViewSet
|
||||||
from authentik.sources.oauth.api.source_connection import UserOAuthSourceConnectionViewSet
|
from authentik.sources.oauth.api.source_connection import UserOAuthSourceConnectionViewSet
|
||||||
@ -165,7 +163,6 @@ router.register("providers/ldap", LDAPProviderViewSet)
|
|||||||
router.register("providers/proxy", ProxyProviderViewSet)
|
router.register("providers/proxy", ProxyProviderViewSet)
|
||||||
router.register("providers/oauth2", OAuth2ProviderViewSet)
|
router.register("providers/oauth2", OAuth2ProviderViewSet)
|
||||||
router.register("providers/saml", SAMLProviderViewSet)
|
router.register("providers/saml", SAMLProviderViewSet)
|
||||||
router.register("providers/scim", SCIMProviderViewSet)
|
|
||||||
|
|
||||||
router.register("oauth2/authorization_codes", AuthorizationCodeViewSet)
|
router.register("oauth2/authorization_codes", AuthorizationCodeViewSet)
|
||||||
router.register("oauth2/refresh_tokens", RefreshTokenViewSet)
|
router.register("oauth2/refresh_tokens", RefreshTokenViewSet)
|
||||||
@ -176,7 +173,6 @@ router.register("propertymappings/ldap", LDAPPropertyMappingViewSet)
|
|||||||
router.register("propertymappings/saml", SAMLPropertyMappingViewSet)
|
router.register("propertymappings/saml", SAMLPropertyMappingViewSet)
|
||||||
router.register("propertymappings/scope", ScopeMappingViewSet)
|
router.register("propertymappings/scope", ScopeMappingViewSet)
|
||||||
router.register("propertymappings/notification", NotificationWebhookMappingViewSet)
|
router.register("propertymappings/notification", NotificationWebhookMappingViewSet)
|
||||||
router.register("propertymappings/scim", SCIMMappingViewSet)
|
|
||||||
|
|
||||||
router.register("authenticators/all", DeviceViewSet, basename="device")
|
router.register("authenticators/all", DeviceViewSet, basename="device")
|
||||||
router.register("authenticators/duo", DuoDeviceViewSet)
|
router.register("authenticators/duo", DuoDeviceViewSet)
|
||||||
|
@ -55,11 +55,11 @@ class AuthentikBlueprintsConfig(ManagedAppConfig):
|
|||||||
"""Load v1 tasks"""
|
"""Load v1 tasks"""
|
||||||
self.import_module("authentik.blueprints.v1.tasks")
|
self.import_module("authentik.blueprints.v1.tasks")
|
||||||
|
|
||||||
def reconcile_blueprints_discovery(self):
|
def reconcile_blueprints_discover(self):
|
||||||
"""Run blueprint discovery"""
|
"""Run blueprint discovery"""
|
||||||
from authentik.blueprints.v1.tasks import blueprints_discovery, clear_failed_blueprints
|
from authentik.blueprints.v1.tasks import blueprints_discover, clear_failed_blueprints
|
||||||
|
|
||||||
blueprints_discovery.delay()
|
blueprints_discover.delay()
|
||||||
clear_failed_blueprints.delay()
|
clear_failed_blueprints.delay()
|
||||||
|
|
||||||
def import_models(self):
|
def import_models(self):
|
||||||
|
@ -19,8 +19,10 @@ class Command(BaseCommand):
|
|||||||
for blueprint_path in options.get("blueprints", []):
|
for blueprint_path in options.get("blueprints", []):
|
||||||
content = BlueprintInstance(path=blueprint_path).retrieve()
|
content = BlueprintInstance(path=blueprint_path).retrieve()
|
||||||
importer = Importer(content)
|
importer = Importer(content)
|
||||||
valid, _ = importer.validate()
|
valid, logs = importer.validate()
|
||||||
if not valid:
|
if not valid:
|
||||||
|
for log in logs:
|
||||||
|
getattr(LOGGER, log.pop("log_level"))(**log)
|
||||||
self.stderr.write("blueprint invalid")
|
self.stderr.write("blueprint invalid")
|
||||||
sys_exit(1)
|
sys_exit(1)
|
||||||
importer.apply()
|
importer.apply()
|
||||||
|
@ -5,7 +5,7 @@ from authentik.lib.utils.time import fqdn_rand
|
|||||||
|
|
||||||
CELERY_BEAT_SCHEDULE = {
|
CELERY_BEAT_SCHEDULE = {
|
||||||
"blueprints_v1_discover": {
|
"blueprints_v1_discover": {
|
||||||
"task": "authentik.blueprints.v1.tasks.blueprints_discovery",
|
"task": "authentik.blueprints.v1.tasks.blueprints_discover",
|
||||||
"schedule": crontab(minute=fqdn_rand("blueprints_v1_discover"), hour="*"),
|
"schedule": crontab(minute=fqdn_rand("blueprints_v1_discover"), hour="*"),
|
||||||
"options": {"queue": "authentik_scheduled"},
|
"options": {"queue": "authentik_scheduled"},
|
||||||
},
|
},
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""Blueprint helpers"""
|
"""Blueprint helpers"""
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from pathlib import Path
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
from django.apps import apps
|
from django.apps import apps
|
||||||
@ -44,3 +45,13 @@ def reconcile_app(app_name: str):
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
return wrapper_outer
|
return wrapper_outer
|
||||||
|
|
||||||
|
|
||||||
|
def load_yaml_fixture(path: str, **kwargs) -> str:
|
||||||
|
"""Load yaml fixture, optionally formatting it with kwargs"""
|
||||||
|
with open(Path(__file__).resolve().parent / Path(path), "r", encoding="utf-8") as _fixture:
|
||||||
|
fixture = _fixture.read()
|
||||||
|
try:
|
||||||
|
return fixture % kwargs
|
||||||
|
except TypeError:
|
||||||
|
return fixture
|
||||||
|
@ -3,12 +3,12 @@ from os import environ
|
|||||||
|
|
||||||
from django.test import TransactionTestCase
|
from django.test import TransactionTestCase
|
||||||
|
|
||||||
|
from authentik.blueprints.tests import load_yaml_fixture
|
||||||
from authentik.blueprints.v1.exporter import FlowExporter
|
from authentik.blueprints.v1.exporter import FlowExporter
|
||||||
from authentik.blueprints.v1.importer import Importer, transaction_rollback
|
from authentik.blueprints.v1.importer import Importer, transaction_rollback
|
||||||
from authentik.core.models import Group
|
from authentik.core.models import Group
|
||||||
from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding
|
from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding
|
||||||
from authentik.lib.generators import generate_id
|
from authentik.lib.generators import generate_id
|
||||||
from authentik.lib.tests.utils import load_fixture
|
|
||||||
from authentik.policies.expression.models import ExpressionPolicy
|
from authentik.policies.expression.models import ExpressionPolicy
|
||||||
from authentik.policies.models import PolicyBinding
|
from authentik.policies.models import PolicyBinding
|
||||||
from authentik.sources.oauth.models import OAuthSource
|
from authentik.sources.oauth.models import OAuthSource
|
||||||
@ -113,14 +113,14 @@ class TestBlueprintsV1(TransactionTestCase):
|
|||||||
"""Test export and import it twice"""
|
"""Test export and import it twice"""
|
||||||
count_initial = Prompt.objects.filter(field_key="username").count()
|
count_initial = Prompt.objects.filter(field_key="username").count()
|
||||||
|
|
||||||
importer = Importer(load_fixture("fixtures/static_prompt_export.yaml"))
|
importer = Importer(load_yaml_fixture("fixtures/static_prompt_export.yaml"))
|
||||||
self.assertTrue(importer.validate()[0])
|
self.assertTrue(importer.validate()[0])
|
||||||
self.assertTrue(importer.apply())
|
self.assertTrue(importer.apply())
|
||||||
|
|
||||||
count_before = Prompt.objects.filter(field_key="username").count()
|
count_before = Prompt.objects.filter(field_key="username").count()
|
||||||
self.assertEqual(count_initial + 1, count_before)
|
self.assertEqual(count_initial + 1, count_before)
|
||||||
|
|
||||||
importer = Importer(load_fixture("fixtures/static_prompt_export.yaml"))
|
importer = Importer(load_yaml_fixture("fixtures/static_prompt_export.yaml"))
|
||||||
self.assertTrue(importer.apply())
|
self.assertTrue(importer.apply())
|
||||||
|
|
||||||
self.assertEqual(Prompt.objects.filter(field_key="username").count(), count_before)
|
self.assertEqual(Prompt.objects.filter(field_key="username").count(), count_before)
|
||||||
@ -130,7 +130,7 @@ class TestBlueprintsV1(TransactionTestCase):
|
|||||||
ExpressionPolicy.objects.filter(name="foo-bar-baz-qux").delete()
|
ExpressionPolicy.objects.filter(name="foo-bar-baz-qux").delete()
|
||||||
Group.objects.filter(name="test").delete()
|
Group.objects.filter(name="test").delete()
|
||||||
environ["foo"] = generate_id()
|
environ["foo"] = generate_id()
|
||||||
importer = Importer(load_fixture("fixtures/tags.yaml"), {"bar": "baz"})
|
importer = Importer(load_yaml_fixture("fixtures/tags.yaml"), {"bar": "baz"})
|
||||||
self.assertTrue(importer.validate()[0])
|
self.assertTrue(importer.validate()[0])
|
||||||
self.assertTrue(importer.apply())
|
self.assertTrue(importer.apply())
|
||||||
policy = ExpressionPolicy.objects.filter(name="foo-bar-baz-qux").first()
|
policy = ExpressionPolicy.objects.filter(name="foo-bar-baz-qux").first()
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
"""Test blueprints v1"""
|
"""Test blueprints v1"""
|
||||||
from django.test import TransactionTestCase
|
from django.test import TransactionTestCase
|
||||||
|
|
||||||
|
from authentik.blueprints.tests import load_yaml_fixture
|
||||||
from authentik.blueprints.v1.importer import Importer
|
from authentik.blueprints.v1.importer import Importer
|
||||||
from authentik.flows.models import Flow
|
from authentik.flows.models import Flow
|
||||||
from authentik.lib.generators import generate_id
|
from authentik.lib.generators import generate_id
|
||||||
from authentik.lib.tests.utils import load_fixture
|
|
||||||
|
|
||||||
|
|
||||||
class TestBlueprintsV1Conditions(TransactionTestCase):
|
class TestBlueprintsV1Conditions(TransactionTestCase):
|
||||||
@ -14,7 +14,7 @@ class TestBlueprintsV1Conditions(TransactionTestCase):
|
|||||||
"""Test conditions fulfilled"""
|
"""Test conditions fulfilled"""
|
||||||
flow_slug1 = generate_id()
|
flow_slug1 = generate_id()
|
||||||
flow_slug2 = generate_id()
|
flow_slug2 = generate_id()
|
||||||
import_yaml = load_fixture(
|
import_yaml = load_yaml_fixture(
|
||||||
"fixtures/conditions_fulfilled.yaml", id1=flow_slug1, id2=flow_slug2
|
"fixtures/conditions_fulfilled.yaml", id1=flow_slug1, id2=flow_slug2
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -31,7 +31,7 @@ class TestBlueprintsV1Conditions(TransactionTestCase):
|
|||||||
"""Test conditions not fulfilled"""
|
"""Test conditions not fulfilled"""
|
||||||
flow_slug1 = generate_id()
|
flow_slug1 = generate_id()
|
||||||
flow_slug2 = generate_id()
|
flow_slug2 = generate_id()
|
||||||
import_yaml = load_fixture(
|
import_yaml = load_yaml_fixture(
|
||||||
"fixtures/conditions_not_fulfilled.yaml", id1=flow_slug1, id2=flow_slug2
|
"fixtures/conditions_not_fulfilled.yaml", id1=flow_slug1, id2=flow_slug2
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
"""Test blueprints v1"""
|
"""Test blueprints v1"""
|
||||||
from django.test import TransactionTestCase
|
from django.test import TransactionTestCase
|
||||||
|
|
||||||
|
from authentik.blueprints.tests import load_yaml_fixture
|
||||||
from authentik.blueprints.v1.importer import Importer
|
from authentik.blueprints.v1.importer import Importer
|
||||||
from authentik.flows.models import Flow
|
from authentik.flows.models import Flow
|
||||||
from authentik.lib.generators import generate_id
|
from authentik.lib.generators import generate_id
|
||||||
from authentik.lib.tests.utils import load_fixture
|
|
||||||
|
|
||||||
|
|
||||||
class TestBlueprintsV1State(TransactionTestCase):
|
class TestBlueprintsV1State(TransactionTestCase):
|
||||||
@ -13,7 +13,7 @@ class TestBlueprintsV1State(TransactionTestCase):
|
|||||||
def test_state_present(self):
|
def test_state_present(self):
|
||||||
"""Test state present"""
|
"""Test state present"""
|
||||||
flow_slug = generate_id()
|
flow_slug = generate_id()
|
||||||
import_yaml = load_fixture("fixtures/state_present.yaml", id=flow_slug)
|
import_yaml = load_yaml_fixture("fixtures/state_present.yaml", id=flow_slug)
|
||||||
|
|
||||||
importer = Importer(import_yaml)
|
importer = Importer(import_yaml)
|
||||||
self.assertTrue(importer.validate()[0])
|
self.assertTrue(importer.validate()[0])
|
||||||
@ -39,7 +39,7 @@ class TestBlueprintsV1State(TransactionTestCase):
|
|||||||
def test_state_created(self):
|
def test_state_created(self):
|
||||||
"""Test state created"""
|
"""Test state created"""
|
||||||
flow_slug = generate_id()
|
flow_slug = generate_id()
|
||||||
import_yaml = load_fixture("fixtures/state_created.yaml", id=flow_slug)
|
import_yaml = load_yaml_fixture("fixtures/state_created.yaml", id=flow_slug)
|
||||||
|
|
||||||
importer = Importer(import_yaml)
|
importer = Importer(import_yaml)
|
||||||
self.assertTrue(importer.validate()[0])
|
self.assertTrue(importer.validate()[0])
|
||||||
@ -65,7 +65,7 @@ class TestBlueprintsV1State(TransactionTestCase):
|
|||||||
def test_state_absent(self):
|
def test_state_absent(self):
|
||||||
"""Test state absent"""
|
"""Test state absent"""
|
||||||
flow_slug = generate_id()
|
flow_slug = generate_id()
|
||||||
import_yaml = load_fixture("fixtures/state_created.yaml", id=flow_slug)
|
import_yaml = load_yaml_fixture("fixtures/state_created.yaml", id=flow_slug)
|
||||||
|
|
||||||
importer = Importer(import_yaml)
|
importer = Importer(import_yaml)
|
||||||
self.assertTrue(importer.validate()[0])
|
self.assertTrue(importer.validate()[0])
|
||||||
@ -74,7 +74,7 @@ class TestBlueprintsV1State(TransactionTestCase):
|
|||||||
flow: Flow = Flow.objects.filter(slug=flow_slug).first()
|
flow: Flow = Flow.objects.filter(slug=flow_slug).first()
|
||||||
self.assertEqual(flow.slug, flow_slug)
|
self.assertEqual(flow.slug, flow_slug)
|
||||||
|
|
||||||
import_yaml = load_fixture("fixtures/state_absent.yaml", id=flow_slug)
|
import_yaml = load_yaml_fixture("fixtures/state_absent.yaml", id=flow_slug)
|
||||||
importer = Importer(import_yaml)
|
importer = Importer(import_yaml)
|
||||||
self.assertTrue(importer.validate()[0])
|
self.assertTrue(importer.validate()[0])
|
||||||
self.assertTrue(importer.apply())
|
self.assertTrue(importer.apply())
|
||||||
|
@ -6,7 +6,7 @@ from django.test import TransactionTestCase
|
|||||||
from yaml import dump
|
from yaml import dump
|
||||||
|
|
||||||
from authentik.blueprints.models import BlueprintInstance, BlueprintInstanceStatus
|
from authentik.blueprints.models import BlueprintInstance, BlueprintInstanceStatus
|
||||||
from authentik.blueprints.v1.tasks import apply_blueprint, blueprints_discovery, blueprints_find
|
from authentik.blueprints.v1.tasks import apply_blueprint, blueprints_discover, blueprints_find
|
||||||
from authentik.lib.config import CONFIG
|
from authentik.lib.config import CONFIG
|
||||||
from authentik.lib.generators import generate_id
|
from authentik.lib.generators import generate_id
|
||||||
|
|
||||||
@ -53,7 +53,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase):
|
|||||||
file.seek(0)
|
file.seek(0)
|
||||||
file_hash = sha512(file.read().encode()).hexdigest()
|
file_hash = sha512(file.read().encode()).hexdigest()
|
||||||
file.flush()
|
file.flush()
|
||||||
blueprints_discovery() # pylint: disable=no-value-for-parameter
|
blueprints_discover() # pylint: disable=no-value-for-parameter
|
||||||
instance = BlueprintInstance.objects.filter(name=blueprint_id).first()
|
instance = BlueprintInstance.objects.filter(name=blueprint_id).first()
|
||||||
self.assertEqual(instance.last_applied_hash, file_hash)
|
self.assertEqual(instance.last_applied_hash, file_hash)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@ -81,7 +81,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
file.flush()
|
file.flush()
|
||||||
blueprints_discovery() # pylint: disable=no-value-for-parameter
|
blueprints_discover() # pylint: disable=no-value-for-parameter
|
||||||
blueprint = BlueprintInstance.objects.filter(name="foo").first()
|
blueprint = BlueprintInstance.objects.filter(name="foo").first()
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
blueprint.last_applied_hash,
|
blueprint.last_applied_hash,
|
||||||
@ -106,7 +106,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
file.flush()
|
file.flush()
|
||||||
blueprints_discovery() # pylint: disable=no-value-for-parameter
|
blueprints_discover() # pylint: disable=no-value-for-parameter
|
||||||
blueprint.refresh_from_db()
|
blueprint.refresh_from_db()
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
blueprint.last_applied_hash,
|
blueprint.last_applied_hash,
|
||||||
|
@ -40,10 +40,6 @@ from authentik.lib.models import SerializerModel
|
|||||||
from authentik.outposts.models import OutpostServiceConnection
|
from authentik.outposts.models import OutpostServiceConnection
|
||||||
from authentik.policies.models import Policy, PolicyBindingModel
|
from authentik.policies.models import Policy, PolicyBindingModel
|
||||||
|
|
||||||
# Context set when the serializer is created in a blueprint context
|
|
||||||
# Update website/developer-docs/blueprints/v1/models.md when used
|
|
||||||
SERIALIZER_CONTEXT_BLUEPRINT = "blueprint_entry"
|
|
||||||
|
|
||||||
|
|
||||||
def is_model_allowed(model: type[Model]) -> bool:
|
def is_model_allowed(model: type[Model]) -> bool:
|
||||||
"""Check if model is allowed"""
|
"""Check if model is allowed"""
|
||||||
@ -162,12 +158,7 @@ class Importer:
|
|||||||
raise EntryInvalidError(f"Model {model} not allowed")
|
raise EntryInvalidError(f"Model {model} not allowed")
|
||||||
if issubclass(model, BaseMetaModel):
|
if issubclass(model, BaseMetaModel):
|
||||||
serializer_class: type[Serializer] = model.serializer()
|
serializer_class: type[Serializer] = model.serializer()
|
||||||
serializer = serializer_class(
|
serializer = serializer_class(data=entry.get_attrs(self.__import))
|
||||||
data=entry.get_attrs(self.__import),
|
|
||||||
context={
|
|
||||||
SERIALIZER_CONTEXT_BLUEPRINT: entry,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
serializer.is_valid(raise_exception=True)
|
serializer.is_valid(raise_exception=True)
|
||||||
except ValidationError as exc:
|
except ValidationError as exc:
|
||||||
@ -226,12 +217,7 @@ class Importer:
|
|||||||
always_merger.merge(full_data, updated_identifiers)
|
always_merger.merge(full_data, updated_identifiers)
|
||||||
serializer_kwargs["data"] = full_data
|
serializer_kwargs["data"] = full_data
|
||||||
|
|
||||||
serializer: Serializer = model().serializer(
|
serializer: Serializer = model().serializer(**serializer_kwargs)
|
||||||
context={
|
|
||||||
SERIALIZER_CONTEXT_BLUEPRINT: entry,
|
|
||||||
},
|
|
||||||
**serializer_kwargs,
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
serializer.is_valid(raise_exception=True)
|
serializer.is_valid(raise_exception=True)
|
||||||
except ValidationError as exc:
|
except ValidationError as exc:
|
||||||
|
@ -76,7 +76,7 @@ class BlueprintEventHandler(FileSystemEventHandler):
|
|||||||
return
|
return
|
||||||
if isinstance(event, FileCreatedEvent):
|
if isinstance(event, FileCreatedEvent):
|
||||||
LOGGER.debug("new blueprint file created, starting discovery")
|
LOGGER.debug("new blueprint file created, starting discovery")
|
||||||
blueprints_discovery.delay()
|
blueprints_discover.delay()
|
||||||
if isinstance(event, FileModifiedEvent):
|
if isinstance(event, FileModifiedEvent):
|
||||||
path = Path(event.src_path)
|
path = Path(event.src_path)
|
||||||
root = Path(CONFIG.y("blueprints_dir")).absolute()
|
root = Path(CONFIG.y("blueprints_dir")).absolute()
|
||||||
@ -134,7 +134,7 @@ def blueprints_find():
|
|||||||
throws=(DatabaseError, ProgrammingError, InternalError), base=MonitoredTask, bind=True
|
throws=(DatabaseError, ProgrammingError, InternalError), base=MonitoredTask, bind=True
|
||||||
)
|
)
|
||||||
@prefill_task
|
@prefill_task
|
||||||
def blueprints_discovery(self: MonitoredTask):
|
def blueprints_discover(self: MonitoredTask):
|
||||||
"""Find blueprints and check if they need to be created in the database"""
|
"""Find blueprints and check if they need to be created in the database"""
|
||||||
count = 0
|
count = 0
|
||||||
for blueprint in blueprints_find():
|
for blueprint in blueprints_find():
|
||||||
|
@ -37,6 +37,7 @@ from authentik.lib.utils.file import (
|
|||||||
from authentik.policies.api.exec import PolicyTestResultSerializer
|
from authentik.policies.api.exec import PolicyTestResultSerializer
|
||||||
from authentik.policies.engine import PolicyEngine
|
from authentik.policies.engine import PolicyEngine
|
||||||
from authentik.policies.types import PolicyResult
|
from authentik.policies.types import PolicyResult
|
||||||
|
from authentik.stages.user_login.stage import USER_LOGIN_AUTHENTICATED
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
|
||||||
@ -185,6 +186,10 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
|
|||||||
if superuser_full_list and request.user.is_superuser:
|
if superuser_full_list and request.user.is_superuser:
|
||||||
return super().list(request)
|
return super().list(request)
|
||||||
|
|
||||||
|
# To prevent the user from having to double login when prompt is set to login
|
||||||
|
# and the user has just signed it. This session variable is set in the UserLoginStage
|
||||||
|
# and is (quite hackily) removed from the session in applications's API's List method
|
||||||
|
self.request.session.pop(USER_LOGIN_AUTHENTICATED, None)
|
||||||
queryset = self._filter_queryset_for_list(self.get_queryset())
|
queryset = self._filter_queryset_for_list(self.get_queryset())
|
||||||
self.paginate_queryset(queryset)
|
self.paginate_queryset(queryset)
|
||||||
|
|
||||||
|
@ -24,6 +24,7 @@ from authentik.core.models import Group, User
|
|||||||
class GroupMemberSerializer(ModelSerializer):
|
class GroupMemberSerializer(ModelSerializer):
|
||||||
"""Stripped down user serializer to show relevant users for groups"""
|
"""Stripped down user serializer to show relevant users for groups"""
|
||||||
|
|
||||||
|
avatar = CharField(read_only=True)
|
||||||
attributes = JSONField(validators=[is_dict], required=False)
|
attributes = JSONField(validators=[is_dict], required=False)
|
||||||
uid = CharField(read_only=True)
|
uid = CharField(read_only=True)
|
||||||
|
|
||||||
@ -36,6 +37,7 @@ class GroupMemberSerializer(ModelSerializer):
|
|||||||
"is_active",
|
"is_active",
|
||||||
"last_login",
|
"last_login",
|
||||||
"email",
|
"email",
|
||||||
|
"avatar",
|
||||||
"attributes",
|
"attributes",
|
||||||
"uid",
|
"uid",
|
||||||
]
|
]
|
||||||
|
@ -44,9 +44,6 @@ class ProviderSerializer(ModelSerializer, MetaNameSerializer):
|
|||||||
"verbose_name_plural",
|
"verbose_name_plural",
|
||||||
"meta_model_name",
|
"meta_model_name",
|
||||||
]
|
]
|
||||||
extra_kwargs = {
|
|
||||||
"authorization_flow": {"required": True, "allow_null": False},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ProviderViewSet(
|
class ProviderViewSet(
|
||||||
|
@ -206,6 +206,5 @@ class UserSourceConnectionViewSet(
|
|||||||
queryset = UserSourceConnection.objects.all()
|
queryset = UserSourceConnection.objects.all()
|
||||||
serializer_class = UserSourceConnectionSerializer
|
serializer_class = UserSourceConnectionSerializer
|
||||||
permission_classes = [OwnerSuperuserPermissions]
|
permission_classes = [OwnerSuperuserPermissions]
|
||||||
filterset_fields = ["user"]
|
|
||||||
filter_backends = [OwnerFilter, DjangoFilterBackend, OrderingFilter, SearchFilter]
|
filter_backends = [OwnerFilter, DjangoFilterBackend, OrderingFilter, SearchFilter]
|
||||||
ordering = ["pk"]
|
ordering = ["pk"]
|
||||||
|
@ -16,7 +16,6 @@ from rest_framework.viewsets import ModelViewSet
|
|||||||
from authentik.api.authorization import OwnerSuperuserPermissions
|
from authentik.api.authorization import OwnerSuperuserPermissions
|
||||||
from authentik.api.decorators import permission_required
|
from authentik.api.decorators import permission_required
|
||||||
from authentik.blueprints.api import ManagedSerializer
|
from authentik.blueprints.api import ManagedSerializer
|
||||||
from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT
|
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.core.api.users import UserSerializer
|
from authentik.core.api.users import UserSerializer
|
||||||
from authentik.core.api.utils import PassiveSerializer
|
from authentik.core.api.utils import PassiveSerializer
|
||||||
@ -30,20 +29,9 @@ class TokenSerializer(ManagedSerializer, ModelSerializer):
|
|||||||
|
|
||||||
user_obj = UserSerializer(required=False, source="user", read_only=True)
|
user_obj = UserSerializer(required=False, source="user", read_only=True)
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs) -> None:
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
if SERIALIZER_CONTEXT_BLUEPRINT in self.context:
|
|
||||||
self.fields["key"] = CharField()
|
|
||||||
|
|
||||||
def validate(self, attrs: dict[Any, str]) -> dict[Any, str]:
|
def validate(self, attrs: dict[Any, str]) -> dict[Any, str]:
|
||||||
"""Ensure only API or App password tokens are created."""
|
"""Ensure only API or App password tokens are created."""
|
||||||
request: Request = self.context.get("request")
|
request: Request = self.context["request"]
|
||||||
if not request:
|
|
||||||
if "user" not in attrs:
|
|
||||||
raise ValidationError("Missing user")
|
|
||||||
if "intent" not in attrs:
|
|
||||||
raise ValidationError("Missing intent")
|
|
||||||
else:
|
|
||||||
attrs.setdefault("user", request.user)
|
attrs.setdefault("user", request.user)
|
||||||
attrs.setdefault("intent", TokenIntents.INTENT_API)
|
attrs.setdefault("intent", TokenIntents.INTENT_API)
|
||||||
if attrs.get("intent") not in [TokenIntents.INTENT_API, TokenIntents.INTENT_APP_PASSWORD]:
|
if attrs.get("intent") not in [TokenIntents.INTENT_API, TokenIntents.INTENT_APP_PASSWORD]:
|
||||||
|
@ -38,7 +38,6 @@ from rest_framework.request import Request
|
|||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.serializers import (
|
from rest_framework.serializers import (
|
||||||
BooleanField,
|
BooleanField,
|
||||||
DateTimeField,
|
|
||||||
ListSerializer,
|
ListSerializer,
|
||||||
ModelSerializer,
|
ModelSerializer,
|
||||||
PrimaryKeyRelatedField,
|
PrimaryKeyRelatedField,
|
||||||
@ -68,7 +67,6 @@ from authentik.core.models import (
|
|||||||
User,
|
User,
|
||||||
)
|
)
|
||||||
from authentik.events.models import EventAction
|
from authentik.events.models import EventAction
|
||||||
from authentik.flows.exceptions import FlowNonApplicableException
|
|
||||||
from authentik.flows.models import FlowToken
|
from authentik.flows.models import FlowToken
|
||||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner
|
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner
|
||||||
from authentik.flows.views.executor import QS_KEY_TOKEN
|
from authentik.flows.views.executor import QS_KEY_TOKEN
|
||||||
@ -327,16 +325,12 @@ class UserViewSet(UsedByMixin, ModelViewSet):
|
|||||||
user: User = self.get_object()
|
user: User = self.get_object()
|
||||||
planner = FlowPlanner(flow)
|
planner = FlowPlanner(flow)
|
||||||
planner.allow_empty_flows = True
|
planner.allow_empty_flows = True
|
||||||
try:
|
|
||||||
plan = planner.plan(
|
plan = planner.plan(
|
||||||
self.request._request,
|
self.request._request,
|
||||||
{
|
{
|
||||||
PLAN_CONTEXT_PENDING_USER: user,
|
PLAN_CONTEXT_PENDING_USER: user,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
except FlowNonApplicableException:
|
|
||||||
LOGGER.warning("Recovery flow not applicable to user")
|
|
||||||
return None, None
|
|
||||||
token, __ = FlowToken.objects.update_or_create(
|
token, __ = FlowToken.objects.update_or_create(
|
||||||
identifier=f"{user.uid}-password-reset",
|
identifier=f"{user.uid}-password-reset",
|
||||||
defaults={
|
defaults={
|
||||||
@ -359,11 +353,6 @@ class UserViewSet(UsedByMixin, ModelViewSet):
|
|||||||
{
|
{
|
||||||
"name": CharField(required=True),
|
"name": CharField(required=True),
|
||||||
"create_group": BooleanField(default=False),
|
"create_group": BooleanField(default=False),
|
||||||
"expiring": BooleanField(default=True),
|
|
||||||
"expires": DateTimeField(
|
|
||||||
required=False,
|
|
||||||
help_text="If not provided, valid for 360 days",
|
|
||||||
),
|
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
responses={
|
responses={
|
||||||
@ -384,20 +373,14 @@ class UserViewSet(UsedByMixin, ModelViewSet):
|
|||||||
"""Create a new user account that is marked as a service account"""
|
"""Create a new user account that is marked as a service account"""
|
||||||
username = request.data.get("name")
|
username = request.data.get("name")
|
||||||
create_group = request.data.get("create_group", False)
|
create_group = request.data.get("create_group", False)
|
||||||
expiring = request.data.get("expiring", True)
|
|
||||||
expires = request.data.get("expires", now() + timedelta(days=360))
|
|
||||||
|
|
||||||
with atomic():
|
with atomic():
|
||||||
try:
|
try:
|
||||||
user: User = User.objects.create(
|
user = User.objects.create(
|
||||||
username=username,
|
username=username,
|
||||||
name=username,
|
name=username,
|
||||||
attributes={USER_ATTRIBUTE_SA: True, USER_ATTRIBUTE_TOKEN_EXPIRING: expiring},
|
attributes={USER_ATTRIBUTE_SA: True, USER_ATTRIBUTE_TOKEN_EXPIRING: False},
|
||||||
path=USER_PATH_SERVICE_ACCOUNT,
|
path=USER_PATH_SERVICE_ACCOUNT,
|
||||||
)
|
)
|
||||||
user.set_unusable_password()
|
|
||||||
user.save()
|
|
||||||
|
|
||||||
response = {
|
response = {
|
||||||
"username": user.username,
|
"username": user.username,
|
||||||
"user_uid": user.uid,
|
"user_uid": user.uid,
|
||||||
@ -413,8 +396,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
|
|||||||
identifier=slugify(f"service-account-{username}-password"),
|
identifier=slugify(f"service-account-{username}-password"),
|
||||||
intent=TokenIntents.INTENT_APP_PASSWORD,
|
intent=TokenIntents.INTENT_APP_PASSWORD,
|
||||||
user=user,
|
user=user,
|
||||||
expires=expires,
|
expires=now() + timedelta(days=360),
|
||||||
expiring=expiring,
|
|
||||||
)
|
)
|
||||||
response["token"] = token.key
|
response["token"] = token.key
|
||||||
return Response(response)
|
return Response(response)
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
"""Property Mapping Evaluator"""
|
"""Property Mapping Evaluator"""
|
||||||
from typing import Any, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from django.db.models import Model
|
from django.db.models import Model
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
from prometheus_client import Histogram
|
|
||||||
|
|
||||||
from authentik.core.models import User
|
from authentik.core.models import User
|
||||||
from authentik.events.models import Event, EventAction
|
from authentik.events.models import Event, EventAction
|
||||||
@ -11,12 +10,6 @@ from authentik.lib.expression.evaluator import BaseEvaluator
|
|||||||
from authentik.lib.utils.errors import exception_to_string
|
from authentik.lib.utils.errors import exception_to_string
|
||||||
from authentik.policies.types import PolicyRequest
|
from authentik.policies.types import PolicyRequest
|
||||||
|
|
||||||
PROPERTY_MAPPING_TIME = Histogram(
|
|
||||||
"authentik_property_mapping_execution_time",
|
|
||||||
"Evaluation time of property mappings",
|
|
||||||
["mapping_name"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PropertyMappingEvaluator(BaseEvaluator):
|
class PropertyMappingEvaluator(BaseEvaluator):
|
||||||
"""Custom Evaluator that adds some different context variables."""
|
"""Custom Evaluator that adds some different context variables."""
|
||||||
@ -56,7 +49,3 @@ class PropertyMappingEvaluator(BaseEvaluator):
|
|||||||
event.from_http(req.http_request, req.user)
|
event.from_http(req.http_request, req.user)
|
||||||
return
|
return
|
||||||
event.save()
|
event.save()
|
||||||
|
|
||||||
def evaluate(self, *args, **kwargs) -> Any:
|
|
||||||
with PROPERTY_MAPPING_TIME.labels(mapping_name=self._filename).time():
|
|
||||||
return super().evaluate(*args, **kwargs)
|
|
||||||
|
@ -18,13 +18,13 @@ def create_default_user(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
|||||||
db_alias = schema_editor.connection.alias
|
db_alias = schema_editor.connection.alias
|
||||||
|
|
||||||
akadmin, _ = User.objects.using(db_alias).get_or_create(
|
akadmin, _ = User.objects.using(db_alias).get_or_create(
|
||||||
username="akadmin",
|
username="akadmin", email="root@localhost", name="authentik Default Admin"
|
||||||
email=environ.get("AUTHENTIK_BOOTSTRAP_EMAIL", "root@localhost"),
|
|
||||||
name="authentik Default Admin",
|
|
||||||
)
|
)
|
||||||
password = None
|
password = None
|
||||||
if "TF_BUILD" in environ or settings.TEST:
|
if "TF_BUILD" in environ or settings.TEST:
|
||||||
password = "akadmin" # noqa # nosec
|
password = "akadmin" # noqa # nosec
|
||||||
|
if "AK_ADMIN_PASS" in environ:
|
||||||
|
password = environ["AK_ADMIN_PASS"]
|
||||||
if "AUTHENTIK_BOOTSTRAP_PASSWORD" in environ:
|
if "AUTHENTIK_BOOTSTRAP_PASSWORD" in environ:
|
||||||
password = environ["AUTHENTIK_BOOTSTRAP_PASSWORD"]
|
password = environ["AUTHENTIK_BOOTSTRAP_PASSWORD"]
|
||||||
if password:
|
if password:
|
||||||
|
@ -46,9 +46,13 @@ def create_default_user_token(apps: Apps, schema_editor: BaseDatabaseSchemaEdito
|
|||||||
akadmin = User.objects.using(db_alias).filter(username="akadmin")
|
akadmin = User.objects.using(db_alias).filter(username="akadmin")
|
||||||
if not akadmin.exists():
|
if not akadmin.exists():
|
||||||
return
|
return
|
||||||
if "AUTHENTIK_BOOTSTRAP_TOKEN" not in environ:
|
key = None
|
||||||
return
|
if "AK_ADMIN_TOKEN" in environ:
|
||||||
|
key = environ["AK_ADMIN_TOKEN"]
|
||||||
|
if "AUTHENTIK_BOOTSTRAP_TOKEN" in environ:
|
||||||
key = environ["AUTHENTIK_BOOTSTRAP_TOKEN"]
|
key = environ["AUTHENTIK_BOOTSTRAP_TOKEN"]
|
||||||
|
if not key:
|
||||||
|
return
|
||||||
Token.objects.using(db_alias).create(
|
Token.objects.using(db_alias).create(
|
||||||
identifier="authentik-bootstrap-token",
|
identifier="authentik-bootstrap-token",
|
||||||
user=akadmin.first(),
|
user=akadmin.first(),
|
||||||
@ -182,9 +186,7 @@ class Migration(migrations.Migration):
|
|||||||
model_name="application",
|
model_name="application",
|
||||||
name="meta_launch_url",
|
name="meta_launch_url",
|
||||||
field=models.TextField(
|
field=models.TextField(
|
||||||
blank=True,
|
blank=True, default="", validators=[authentik.lib.models.DomainlessURLValidator()]
|
||||||
default="",
|
|
||||||
validators=[authentik.lib.models.DomainlessFormattedURLValidator()],
|
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
migrations.RunPython(
|
migrations.RunPython(
|
||||||
|
@ -1,25 +0,0 @@
|
|||||||
# Generated by Django 4.1.7 on 2023-03-02 21:32
|
|
||||||
|
|
||||||
import django.db.models.deletion
|
|
||||||
from django.db import migrations, models
|
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
|
||||||
dependencies = [
|
|
||||||
("authentik_flows", "0025_alter_flowstagebinding_evaluate_on_plan_and_more"),
|
|
||||||
("authentik_core", "0024_source_icon"),
|
|
||||||
]
|
|
||||||
|
|
||||||
operations = [
|
|
||||||
migrations.AlterField(
|
|
||||||
model_name="provider",
|
|
||||||
name="authorization_flow",
|
|
||||||
field=models.ForeignKey(
|
|
||||||
help_text="Flow used when authorizing this provider.",
|
|
||||||
null=True,
|
|
||||||
on_delete=django.db.models.deletion.CASCADE,
|
|
||||||
related_name="provider_authorization",
|
|
||||||
to="authentik_flows.flow",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
]
|
|
@ -1,26 +0,0 @@
|
|||||||
# Generated by Django 4.1.7 on 2023-03-07 13:41
|
|
||||||
|
|
||||||
from django.db import migrations, models
|
|
||||||
|
|
||||||
from authentik.lib.migrations import fallback_names
|
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
|
||||||
dependencies = [
|
|
||||||
("authentik_core", "0025_alter_provider_authorization_flow"),
|
|
||||||
]
|
|
||||||
|
|
||||||
operations = [
|
|
||||||
migrations.RunPython(fallback_names("authentik_core", "propertymapping", "name")),
|
|
||||||
migrations.RunPython(fallback_names("authentik_core", "provider", "name")),
|
|
||||||
migrations.AlterField(
|
|
||||||
model_name="propertymapping",
|
|
||||||
name="name",
|
|
||||||
field=models.TextField(unique=True),
|
|
||||||
),
|
|
||||||
migrations.AlterField(
|
|
||||||
model_name="provider",
|
|
||||||
name="name",
|
|
||||||
field=models.TextField(unique=True),
|
|
||||||
),
|
|
||||||
]
|
|
@ -22,15 +22,12 @@ from structlog.stdlib import get_logger
|
|||||||
|
|
||||||
from authentik.blueprints.models import ManagedModel
|
from authentik.blueprints.models import ManagedModel
|
||||||
from authentik.core.exceptions import PropertyMappingExpressionException
|
from authentik.core.exceptions import PropertyMappingExpressionException
|
||||||
|
from authentik.core.signals import password_changed
|
||||||
from authentik.core.types import UILoginButton, UserSettingSerializer
|
from authentik.core.types import UILoginButton, UserSettingSerializer
|
||||||
from authentik.lib.avatars import get_avatar
|
from authentik.lib.avatars import get_avatar
|
||||||
from authentik.lib.config import CONFIG
|
from authentik.lib.config import CONFIG
|
||||||
from authentik.lib.generators import generate_id
|
from authentik.lib.generators import generate_id
|
||||||
from authentik.lib.models import (
|
from authentik.lib.models import CreatedUpdatedModel, DomainlessURLValidator, SerializerModel
|
||||||
CreatedUpdatedModel,
|
|
||||||
DomainlessFormattedURLValidator,
|
|
||||||
SerializerModel,
|
|
||||||
)
|
|
||||||
from authentik.lib.utils.http import get_client_ip
|
from authentik.lib.utils.http import get_client_ip
|
||||||
from authentik.policies.models import PolicyBindingModel
|
from authentik.policies.models import PolicyBindingModel
|
||||||
|
|
||||||
@ -192,8 +189,6 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser):
|
|||||||
|
|
||||||
def set_password(self, raw_password, signal=True):
|
def set_password(self, raw_password, signal=True):
|
||||||
if self.pk and signal:
|
if self.pk and signal:
|
||||||
from authentik.core.signals import password_changed
|
|
||||||
|
|
||||||
password_changed.send(sender=self, user=self, password=raw_password)
|
password_changed.send(sender=self, user=self, password=raw_password)
|
||||||
self.password_change_date = now()
|
self.password_change_date = now()
|
||||||
return super().set_password(raw_password)
|
return super().set_password(raw_password)
|
||||||
@ -247,12 +242,11 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser):
|
|||||||
class Provider(SerializerModel):
|
class Provider(SerializerModel):
|
||||||
"""Application-independent Provider instance. For example SAML2 Remote, OAuth2 Application"""
|
"""Application-independent Provider instance. For example SAML2 Remote, OAuth2 Application"""
|
||||||
|
|
||||||
name = models.TextField(unique=True)
|
name = models.TextField()
|
||||||
|
|
||||||
authorization_flow = models.ForeignKey(
|
authorization_flow = models.ForeignKey(
|
||||||
"authentik_flows.Flow",
|
"authentik_flows.Flow",
|
||||||
on_delete=models.CASCADE,
|
on_delete=models.CASCADE,
|
||||||
null=True,
|
|
||||||
help_text=_("Flow used when authorizing this provider."),
|
help_text=_("Flow used when authorizing this provider."),
|
||||||
related_name="provider_authorization",
|
related_name="provider_authorization",
|
||||||
)
|
)
|
||||||
@ -295,7 +289,7 @@ class Application(SerializerModel, PolicyBindingModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
meta_launch_url = models.TextField(
|
meta_launch_url = models.TextField(
|
||||||
default="", blank=True, validators=[DomainlessFormattedURLValidator()]
|
default="", blank=True, validators=[DomainlessURLValidator()]
|
||||||
)
|
)
|
||||||
|
|
||||||
open_in_new_tab = models.BooleanField(
|
open_in_new_tab = models.BooleanField(
|
||||||
@ -612,7 +606,7 @@ class PropertyMapping(SerializerModel, ManagedModel):
|
|||||||
"""User-defined key -> x mapping which can be used by providers to expose extra data."""
|
"""User-defined key -> x mapping which can be used by providers to expose extra data."""
|
||||||
|
|
||||||
pm_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
pm_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
||||||
name = models.TextField(unique=True)
|
name = models.TextField()
|
||||||
expression = models.TextField()
|
expression = models.TextField()
|
||||||
|
|
||||||
objects = InheritanceManager()
|
objects = InheritanceManager()
|
||||||
@ -635,7 +629,7 @@ class PropertyMapping(SerializerModel, ManagedModel):
|
|||||||
try:
|
try:
|
||||||
return evaluator.evaluate(self.expression)
|
return evaluator.evaluate(self.expression)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
raise PropertyMappingExpressionException(exc) from exc
|
raise PropertyMappingExpressionException(str(exc)) from exc
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"Property Mapping {self.name}"
|
return f"Property Mapping {self.name}"
|
||||||
|
@ -10,25 +10,25 @@ from django.db.models.signals import post_save, pre_delete
|
|||||||
from django.dispatch import receiver
|
from django.dispatch import receiver
|
||||||
from django.http.request import HttpRequest
|
from django.http.request import HttpRequest
|
||||||
|
|
||||||
from authentik.core.models import Application, AuthenticatedSession
|
|
||||||
|
|
||||||
# Arguments: user: User, password: str
|
# Arguments: user: User, password: str
|
||||||
password_changed = Signal()
|
password_changed = Signal()
|
||||||
# Arguments: credentials: dict[str, any], request: HttpRequest, stage: Stage
|
# Arguments: credentials: dict[str, any], request: HttpRequest, stage: Stage
|
||||||
login_failed = Signal()
|
login_failed = Signal()
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from authentik.core.models import User
|
from authentik.core.models import AuthenticatedSession, User
|
||||||
|
|
||||||
|
|
||||||
@receiver(post_save, sender=Application)
|
@receiver(post_save)
|
||||||
def post_save_application(sender: type[Model], instance, created: bool, **_):
|
def post_save_application(sender: type[Model], instance, created: bool, **_):
|
||||||
"""Clear user's application cache upon application creation"""
|
"""Clear user's application cache upon application creation"""
|
||||||
from authentik.core.api.applications import user_app_cache_key
|
from authentik.core.api.applications import user_app_cache_key
|
||||||
|
from authentik.core.models import Application
|
||||||
|
|
||||||
|
if sender != Application:
|
||||||
|
return
|
||||||
if not created: # pragma: no cover
|
if not created: # pragma: no cover
|
||||||
return
|
return
|
||||||
|
|
||||||
# Also delete user application cache
|
# Also delete user application cache
|
||||||
keys = cache.keys(user_app_cache_key("*"))
|
keys = cache.keys(user_app_cache_key("*"))
|
||||||
cache.delete_many(keys)
|
cache.delete_many(keys)
|
||||||
@ -37,6 +37,7 @@ def post_save_application(sender: type[Model], instance, created: bool, **_):
|
|||||||
@receiver(user_logged_in)
|
@receiver(user_logged_in)
|
||||||
def user_logged_in_session(sender, request: HttpRequest, user: "User", **_):
|
def user_logged_in_session(sender, request: HttpRequest, user: "User", **_):
|
||||||
"""Create an AuthenticatedSession from request"""
|
"""Create an AuthenticatedSession from request"""
|
||||||
|
from authentik.core.models import AuthenticatedSession
|
||||||
|
|
||||||
session = AuthenticatedSession.from_request(request, user)
|
session = AuthenticatedSession.from_request(request, user)
|
||||||
if session:
|
if session:
|
||||||
@ -46,11 +47,18 @@ def user_logged_in_session(sender, request: HttpRequest, user: "User", **_):
|
|||||||
@receiver(user_logged_out)
|
@receiver(user_logged_out)
|
||||||
def user_logged_out_session(sender, request: HttpRequest, user: "User", **_):
|
def user_logged_out_session(sender, request: HttpRequest, user: "User", **_):
|
||||||
"""Delete AuthenticatedSession if it exists"""
|
"""Delete AuthenticatedSession if it exists"""
|
||||||
|
from authentik.core.models import AuthenticatedSession
|
||||||
|
|
||||||
AuthenticatedSession.objects.filter(session_key=request.session.session_key).delete()
|
AuthenticatedSession.objects.filter(session_key=request.session.session_key).delete()
|
||||||
|
|
||||||
|
|
||||||
@receiver(pre_delete, sender=AuthenticatedSession)
|
@receiver(pre_delete)
|
||||||
def authenticated_session_delete(sender: type[Model], instance: "AuthenticatedSession", **_):
|
def authenticated_session_delete(sender: type[Model], instance: "AuthenticatedSession", **_):
|
||||||
"""Delete session when authenticated session is deleted"""
|
"""Delete session when authenticated session is deleted"""
|
||||||
|
from authentik.core.models import AuthenticatedSession
|
||||||
|
|
||||||
|
if sender != AuthenticatedSession:
|
||||||
|
return
|
||||||
|
|
||||||
cache_key = f"{KEY_PREFIX}{instance.session_key}"
|
cache_key = f"{KEY_PREFIX}{instance.session_key}"
|
||||||
cache.delete(cache_key)
|
cache.delete(cache_key)
|
||||||
|
@ -16,8 +16,7 @@
|
|||||||
{% block head_before %}
|
{% block head_before %}
|
||||||
{% endblock %}
|
{% endblock %}
|
||||||
<link rel="stylesheet" type="text/css" href="{% static 'dist/authentik.css' %}">
|
<link rel="stylesheet" type="text/css" href="{% static 'dist/authentik.css' %}">
|
||||||
<link rel="stylesheet" type="text/css" href="{% static 'dist/theme-dark.css' %}" media="(prefers-color-scheme: dark)">
|
<link rel="stylesheet" type="text/css" href="{% static 'dist/custom.css' %}">
|
||||||
<link rel="stylesheet" type="text/css" href="{% static 'dist/custom.css' %}" data-inject>
|
|
||||||
<script src="{% static 'dist/poly.js' %}" type="module"></script>
|
<script src="{% static 'dist/poly.js' %}" type="module"></script>
|
||||||
{% block head %}
|
{% block head %}
|
||||||
{% endblock %}
|
{% endblock %}
|
||||||
|
@ -37,22 +37,6 @@ class TestApplicationsAPI(APITestCase):
|
|||||||
order=0,
|
order=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_formatted_launch_url(self):
|
|
||||||
"""Test formatted launch URL"""
|
|
||||||
self.client.force_login(self.user)
|
|
||||||
self.assertEqual(
|
|
||||||
self.client.patch(
|
|
||||||
reverse("authentik_api:application-detail", kwargs={"slug": self.allowed.slug}),
|
|
||||||
{"meta_launch_url": "https://%(username)s-test.test.goauthentik.io/%(username)s"},
|
|
||||||
).status_code,
|
|
||||||
200,
|
|
||||||
)
|
|
||||||
self.allowed.refresh_from_db()
|
|
||||||
self.assertEqual(
|
|
||||||
self.allowed.get_launch_url(self.user),
|
|
||||||
f"https://{self.user.username}-test.test.goauthentik.io/{self.user.username}",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_set_icon(self):
|
def test_set_icon(self):
|
||||||
"""Test set_icon"""
|
"""Test set_icon"""
|
||||||
file = ContentFile(b"text", "name")
|
file = ContentFile(b"text", "name")
|
||||||
|
@ -5,7 +5,6 @@ from django.urls.base import reverse
|
|||||||
from guardian.shortcuts import get_anonymous_user
|
from guardian.shortcuts import get_anonymous_user
|
||||||
from rest_framework.test import APITestCase
|
from rest_framework.test import APITestCase
|
||||||
|
|
||||||
from authentik.core.api.tokens import TokenSerializer
|
|
||||||
from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents, User
|
from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents, User
|
||||||
from authentik.core.tests.utils import create_test_admin_user
|
from authentik.core.tests.utils import create_test_admin_user
|
||||||
from authentik.lib.generators import generate_id
|
from authentik.lib.generators import generate_id
|
||||||
@ -100,16 +99,3 @@ class TestTokenAPI(APITestCase):
|
|||||||
self.assertEqual(len(body["results"]), 2)
|
self.assertEqual(len(body["results"]), 2)
|
||||||
self.assertEqual(body["results"][0]["identifier"], token_should.identifier)
|
self.assertEqual(body["results"][0]["identifier"], token_should.identifier)
|
||||||
self.assertEqual(body["results"][1]["identifier"], token_should_not.identifier)
|
self.assertEqual(body["results"][1]["identifier"], token_should_not.identifier)
|
||||||
|
|
||||||
def test_serializer_no_request(self):
|
|
||||||
"""Test serializer without request"""
|
|
||||||
self.assertTrue(
|
|
||||||
TokenSerializer(
|
|
||||||
data={
|
|
||||||
"identifier": generate_id(),
|
|
||||||
"intent": TokenIntents.INTENT_APP_PASSWORD,
|
|
||||||
"key": generate_id(),
|
|
||||||
"user": self.user.pk,
|
|
||||||
}
|
|
||||||
).is_valid(raise_exception=True)
|
|
||||||
)
|
|
||||||
|
@ -1,19 +1,11 @@
|
|||||||
"""Test Users API"""
|
"""Test Users API"""
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from django.contrib.sessions.backends.cache import KEY_PREFIX
|
from django.contrib.sessions.backends.cache import KEY_PREFIX
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
from django.urls.base import reverse
|
from django.urls.base import reverse
|
||||||
from rest_framework.test import APITestCase
|
from rest_framework.test import APITestCase
|
||||||
|
|
||||||
from authentik.core.models import (
|
from authentik.core.models import AuthenticatedSession, User
|
||||||
USER_ATTRIBUTE_SA,
|
|
||||||
USER_ATTRIBUTE_TOKEN_EXPIRING,
|
|
||||||
AuthenticatedSession,
|
|
||||||
Token,
|
|
||||||
User,
|
|
||||||
)
|
|
||||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow, create_test_tenant
|
from authentik.core.tests.utils import create_test_admin_user, create_test_flow, create_test_tenant
|
||||||
from authentik.flows.models import FlowDesignation
|
from authentik.flows.models import FlowDesignation
|
||||||
from authentik.lib.generators import generate_id, generate_key
|
from authentik.lib.generators import generate_id, generate_key
|
||||||
@ -138,71 +130,7 @@ class TestUsersAPI(APITestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
|
self.assertTrue(User.objects.filter(username="test-sa").exists())
|
||||||
user_filter = User.objects.filter(
|
|
||||||
username="test-sa",
|
|
||||||
attributes={USER_ATTRIBUTE_TOKEN_EXPIRING: True, USER_ATTRIBUTE_SA: True},
|
|
||||||
)
|
|
||||||
self.assertTrue(user_filter.exists())
|
|
||||||
user: User = user_filter.first()
|
|
||||||
self.assertFalse(user.has_usable_password())
|
|
||||||
|
|
||||||
token_filter = Token.objects.filter(user=user)
|
|
||||||
self.assertTrue(token_filter.exists())
|
|
||||||
self.assertTrue(token_filter.first().expiring)
|
|
||||||
|
|
||||||
def test_service_account_no_expire(self):
|
|
||||||
"""Service account creation without token expiration"""
|
|
||||||
self.client.force_login(self.admin)
|
|
||||||
response = self.client.post(
|
|
||||||
reverse("authentik_api:user-service-account"),
|
|
||||||
data={
|
|
||||||
"name": "test-sa",
|
|
||||||
"create_group": True,
|
|
||||||
"expiring": False,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
|
|
||||||
user_filter = User.objects.filter(
|
|
||||||
username="test-sa",
|
|
||||||
attributes={USER_ATTRIBUTE_TOKEN_EXPIRING: False, USER_ATTRIBUTE_SA: True},
|
|
||||||
)
|
|
||||||
self.assertTrue(user_filter.exists())
|
|
||||||
user: User = user_filter.first()
|
|
||||||
self.assertFalse(user.has_usable_password())
|
|
||||||
|
|
||||||
token_filter = Token.objects.filter(user=user)
|
|
||||||
self.assertTrue(token_filter.exists())
|
|
||||||
self.assertFalse(token_filter.first().expiring)
|
|
||||||
|
|
||||||
def test_service_account_with_custom_expire(self):
|
|
||||||
"""Service account creation with custom token expiration date"""
|
|
||||||
self.client.force_login(self.admin)
|
|
||||||
expire_on = datetime(2050, 11, 11, 11, 11, 11).astimezone()
|
|
||||||
response = self.client.post(
|
|
||||||
reverse("authentik_api:user-service-account"),
|
|
||||||
data={
|
|
||||||
"name": "test-sa",
|
|
||||||
"create_group": True,
|
|
||||||
"expires": expire_on.isoformat(),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
|
|
||||||
user_filter = User.objects.filter(
|
|
||||||
username="test-sa",
|
|
||||||
attributes={USER_ATTRIBUTE_TOKEN_EXPIRING: True, USER_ATTRIBUTE_SA: True},
|
|
||||||
)
|
|
||||||
self.assertTrue(user_filter.exists())
|
|
||||||
user: User = user_filter.first()
|
|
||||||
self.assertFalse(user.has_usable_password())
|
|
||||||
|
|
||||||
token_filter = Token.objects.filter(user=user)
|
|
||||||
self.assertTrue(token_filter.exists())
|
|
||||||
token = token_filter.first()
|
|
||||||
self.assertTrue(token.expiring)
|
|
||||||
self.assertEqual(token.expires, expire_on)
|
|
||||||
|
|
||||||
def test_service_account_invalid(self):
|
def test_service_account_invalid(self):
|
||||||
"""Service account creation (twice with same name, expect error)"""
|
"""Service account creation (twice with same name, expect error)"""
|
||||||
@ -215,19 +143,7 @@ class TestUsersAPI(APITestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
|
self.assertTrue(User.objects.filter(username="test-sa").exists())
|
||||||
user_filter = User.objects.filter(
|
|
||||||
username="test-sa",
|
|
||||||
attributes={USER_ATTRIBUTE_TOKEN_EXPIRING: True, USER_ATTRIBUTE_SA: True},
|
|
||||||
)
|
|
||||||
self.assertTrue(user_filter.exists())
|
|
||||||
user: User = user_filter.first()
|
|
||||||
self.assertFalse(user.has_usable_password())
|
|
||||||
|
|
||||||
token_filter = Token.objects.filter(user=user)
|
|
||||||
self.assertTrue(token_filter.exists())
|
|
||||||
self.assertTrue(token_filter.first().expiring)
|
|
||||||
|
|
||||||
response = self.client.post(
|
response = self.client.post(
|
||||||
reverse("authentik_api:user-service-account"),
|
reverse("authentik_api:user-service-account"),
|
||||||
data={
|
data={
|
||||||
|
@ -11,7 +11,6 @@ from authentik.flows.challenge import (
|
|||||||
HttpChallengeResponse,
|
HttpChallengeResponse,
|
||||||
RedirectChallenge,
|
RedirectChallenge,
|
||||||
)
|
)
|
||||||
from authentik.flows.exceptions import FlowNonApplicableException
|
|
||||||
from authentik.flows.models import in_memory_stage
|
from authentik.flows.models import in_memory_stage
|
||||||
from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, FlowPlanner
|
from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, FlowPlanner
|
||||||
from authentik.flows.stage import ChallengeStageView
|
from authentik.flows.stage import ChallengeStageView
|
||||||
@ -42,7 +41,6 @@ class RedirectToAppLaunch(View):
|
|||||||
flow = tenant.flow_authentication
|
flow = tenant.flow_authentication
|
||||||
planner = FlowPlanner(flow)
|
planner = FlowPlanner(flow)
|
||||||
planner.allow_empty_flows = True
|
planner.allow_empty_flows = True
|
||||||
try:
|
|
||||||
plan = planner.plan(
|
plan = planner.plan(
|
||||||
request,
|
request,
|
||||||
{
|
{
|
||||||
@ -52,8 +50,6 @@ class RedirectToAppLaunch(View):
|
|||||||
PLAN_CONTEXT_CONSENT_PERMISSIONS: [],
|
PLAN_CONTEXT_CONSENT_PERMISSIONS: [],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
except FlowNonApplicableException:
|
|
||||||
raise Http404
|
|
||||||
plan.insert_stage(in_memory_stage(RedirectToAppStage))
|
plan.insert_stage(in_memory_stage(RedirectToAppStage))
|
||||||
request.session[SESSION_KEY_PLAN] = plan
|
request.session[SESSION_KEY_PLAN] = plan
|
||||||
return redirect_with_qs("authentik_core:if-flow", request.GET, flow_slug=flow.slug)
|
return redirect_with_qs("authentik_core:if-flow", request.GET, flow_slug=flow.slug)
|
||||||
|
@ -7,14 +7,13 @@ from django.conf import settings
|
|||||||
from django.contrib.sessions.models import Session
|
from django.contrib.sessions.models import Session
|
||||||
from django.core.exceptions import SuspiciousOperation
|
from django.core.exceptions import SuspiciousOperation
|
||||||
from django.db.models import Model
|
from django.db.models import Model
|
||||||
from django.db.models.signals import m2m_changed, post_save, pre_delete
|
from django.db.models.signals import post_save, pre_delete
|
||||||
from django.http import HttpRequest, HttpResponse
|
from django.http import HttpRequest, HttpResponse
|
||||||
from django_otp.plugins.otp_static.models import StaticToken
|
from django_otp.plugins.otp_static.models import StaticToken
|
||||||
from guardian.models import UserObjectPermission
|
from guardian.models import UserObjectPermission
|
||||||
|
|
||||||
from authentik.core.models import (
|
from authentik.core.models import (
|
||||||
AuthenticatedSession,
|
AuthenticatedSession,
|
||||||
Group,
|
|
||||||
PropertyMapping,
|
PropertyMapping,
|
||||||
Provider,
|
Provider,
|
||||||
Source,
|
Source,
|
||||||
@ -29,7 +28,6 @@ from authentik.lib.utils.errors import exception_to_string
|
|||||||
from authentik.outposts.models import OutpostServiceConnection
|
from authentik.outposts.models import OutpostServiceConnection
|
||||||
from authentik.policies.models import Policy, PolicyBindingModel
|
from authentik.policies.models import Policy, PolicyBindingModel
|
||||||
from authentik.providers.oauth2.models import AccessToken, AuthorizationCode, RefreshToken
|
from authentik.providers.oauth2.models import AccessToken, AuthorizationCode, RefreshToken
|
||||||
from authentik.providers.scim.models import SCIMGroup, SCIMUser
|
|
||||||
|
|
||||||
IGNORED_MODELS = (
|
IGNORED_MODELS = (
|
||||||
Event,
|
Event,
|
||||||
@ -50,8 +48,6 @@ IGNORED_MODELS = (
|
|||||||
AuthorizationCode,
|
AuthorizationCode,
|
||||||
AccessToken,
|
AccessToken,
|
||||||
RefreshToken,
|
RefreshToken,
|
||||||
SCIMUser,
|
|
||||||
SCIMGroup,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -62,13 +58,6 @@ def should_log_model(model: Model) -> bool:
|
|||||||
return model.__class__ not in IGNORED_MODELS
|
return model.__class__ not in IGNORED_MODELS
|
||||||
|
|
||||||
|
|
||||||
def should_log_m2m(model: Model) -> bool:
|
|
||||||
"""Return true if m2m operation should be logged"""
|
|
||||||
if model.__class__ in [User, Group]:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class EventNewThread(Thread):
|
class EventNewThread(Thread):
|
||||||
"""Create Event in background thread"""
|
"""Create Event in background thread"""
|
||||||
|
|
||||||
@ -107,7 +96,6 @@ class AuditMiddleware:
|
|||||||
return
|
return
|
||||||
post_save_handler = partial(self.post_save_handler, user=request.user, request=request)
|
post_save_handler = partial(self.post_save_handler, user=request.user, request=request)
|
||||||
pre_delete_handler = partial(self.pre_delete_handler, user=request.user, request=request)
|
pre_delete_handler = partial(self.pre_delete_handler, user=request.user, request=request)
|
||||||
m2m_changed_handler = partial(self.m2m_changed_handler, user=request.user, request=request)
|
|
||||||
post_save.connect(
|
post_save.connect(
|
||||||
post_save_handler,
|
post_save_handler,
|
||||||
dispatch_uid=request.request_id,
|
dispatch_uid=request.request_id,
|
||||||
@ -118,11 +106,6 @@ class AuditMiddleware:
|
|||||||
dispatch_uid=request.request_id,
|
dispatch_uid=request.request_id,
|
||||||
weak=False,
|
weak=False,
|
||||||
)
|
)
|
||||||
m2m_changed.connect(
|
|
||||||
m2m_changed_handler,
|
|
||||||
dispatch_uid=request.request_id,
|
|
||||||
weak=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
def disconnect(self, request: HttpRequest):
|
def disconnect(self, request: HttpRequest):
|
||||||
"""Disconnect signals"""
|
"""Disconnect signals"""
|
||||||
@ -130,7 +113,6 @@ class AuditMiddleware:
|
|||||||
return
|
return
|
||||||
post_save.disconnect(dispatch_uid=request.request_id)
|
post_save.disconnect(dispatch_uid=request.request_id)
|
||||||
pre_delete.disconnect(dispatch_uid=request.request_id)
|
pre_delete.disconnect(dispatch_uid=request.request_id)
|
||||||
m2m_changed.disconnect(dispatch_uid=request.request_id)
|
|
||||||
|
|
||||||
def __call__(self, request: HttpRequest) -> HttpResponse:
|
def __call__(self, request: HttpRequest) -> HttpResponse:
|
||||||
self.connect(request)
|
self.connect(request)
|
||||||
@ -185,20 +167,3 @@ class AuditMiddleware:
|
|||||||
user=user,
|
user=user,
|
||||||
model=model_to_dict(instance),
|
model=model_to_dict(instance),
|
||||||
).run()
|
).run()
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def m2m_changed_handler(
|
|
||||||
user: User, request: HttpRequest, sender, instance: Model, action: str, **_
|
|
||||||
):
|
|
||||||
"""Signal handler for all object's m2m_changed"""
|
|
||||||
if action not in ["pre_add", "pre_remove", "post_clear"]:
|
|
||||||
return
|
|
||||||
if not should_log_m2m(instance):
|
|
||||||
return
|
|
||||||
|
|
||||||
EventNewThread(
|
|
||||||
EventAction.MODEL_UPDATED,
|
|
||||||
request,
|
|
||||||
user=user,
|
|
||||||
model=model_to_dict(instance),
|
|
||||||
).run()
|
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
import django.db.models.deletion
|
import django.db.models.deletion
|
||||||
from django.apps.registry import Apps
|
from django.apps.registry import Apps
|
||||||
@ -12,7 +13,6 @@ from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
|||||||
import authentik.events.models
|
import authentik.events.models
|
||||||
import authentik.lib.models
|
import authentik.lib.models
|
||||||
from authentik.events.models import EventAction, NotificationSeverity, TransportMode
|
from authentik.events.models import EventAction, NotificationSeverity, TransportMode
|
||||||
from authentik.lib.migrations import progress_bar
|
|
||||||
|
|
||||||
|
|
||||||
def convert_user_to_json(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
def convert_user_to_json(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||||
@ -43,6 +43,49 @@ def token_view_to_secret_view(apps: Apps, schema_editor: BaseDatabaseSchemaEdito
|
|||||||
Event.objects.using(db_alias).bulk_update(events, ["context", "action"])
|
Event.objects.using(db_alias).bulk_update(events, ["context", "action"])
|
||||||
|
|
||||||
|
|
||||||
|
# Taken from https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console
|
||||||
|
def progress_bar(
|
||||||
|
iterable: Iterable,
|
||||||
|
prefix="Writing: ",
|
||||||
|
suffix=" finished",
|
||||||
|
decimals=1,
|
||||||
|
length=100,
|
||||||
|
fill="█",
|
||||||
|
print_end="\r",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Call in a loop to create terminal progress bar
|
||||||
|
@params:
|
||||||
|
iteration - Required : current iteration (Int)
|
||||||
|
total - Required : total iterations (Int)
|
||||||
|
prefix - Optional : prefix string (Str)
|
||||||
|
suffix - Optional : suffix string (Str)
|
||||||
|
decimals - Optional : positive number of decimals in percent complete (Int)
|
||||||
|
length - Optional : character length of bar (Int)
|
||||||
|
fill - Optional : bar fill character (Str)
|
||||||
|
print_end - Optional : end character (e.g. "\r", "\r\n") (Str)
|
||||||
|
"""
|
||||||
|
total = len(iterable)
|
||||||
|
if total < 1:
|
||||||
|
return
|
||||||
|
|
||||||
|
def print_progress_bar(iteration):
|
||||||
|
"""Progress Bar Printing Function"""
|
||||||
|
percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total)))
|
||||||
|
filledLength = int(length * iteration // total)
|
||||||
|
bar = fill * filledLength + "-" * (length - filledLength)
|
||||||
|
print(f"\r{prefix} |{bar}| {percent}% {suffix}", end=print_end)
|
||||||
|
|
||||||
|
# Initial Call
|
||||||
|
print_progress_bar(0)
|
||||||
|
# Update Progress Bar
|
||||||
|
for i, item in enumerate(iterable):
|
||||||
|
yield item
|
||||||
|
print_progress_bar(i + 1)
|
||||||
|
# Print New Line on Complete
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
def update_expires(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
def update_expires(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||||
db_alias = schema_editor.connection.alias
|
db_alias = schema_editor.connection.alias
|
||||||
Event = apps.get_model("authentik_events", "event")
|
Event = apps.get_model("authentik_events", "event")
|
||||||
|
@ -41,7 +41,7 @@ class TaskResult:
|
|||||||
|
|
||||||
def with_error(self, exc: Exception) -> "TaskResult":
|
def with_error(self, exc: Exception) -> "TaskResult":
|
||||||
"""Since errors might not always be pickle-able, set the traceback"""
|
"""Since errors might not always be pickle-able, set the traceback"""
|
||||||
self.messages.append(exception_to_string(exc))
|
self.messages.append(str(exc))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
@ -111,7 +111,6 @@ class MonitoredTask(Task):
|
|||||||
_result: Optional[TaskResult]
|
_result: Optional[TaskResult]
|
||||||
|
|
||||||
_uid: Optional[str]
|
_uid: Optional[str]
|
||||||
start: Optional[float] = None
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs) -> None:
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@ -119,6 +118,7 @@ class MonitoredTask(Task):
|
|||||||
self._uid = None
|
self._uid = None
|
||||||
self._result = None
|
self._result = None
|
||||||
self.result_timeout_hours = 6
|
self.result_timeout_hours = 6
|
||||||
|
self.start = default_timer()
|
||||||
|
|
||||||
def set_uid(self, uid: str):
|
def set_uid(self, uid: str):
|
||||||
"""Set UID, so in the case of an unexpected error its saved correctly"""
|
"""Set UID, so in the case of an unexpected error its saved correctly"""
|
||||||
@ -128,10 +128,6 @@ class MonitoredTask(Task):
|
|||||||
"""Set result for current run, will overwrite previous result."""
|
"""Set result for current run, will overwrite previous result."""
|
||||||
self._result = result
|
self._result = result
|
||||||
|
|
||||||
def before_start(self, task_id, args, kwargs):
|
|
||||||
self.start = default_timer()
|
|
||||||
return super().before_start(task_id, args, kwargs)
|
|
||||||
|
|
||||||
# pylint: disable=too-many-arguments
|
# pylint: disable=too-many-arguments
|
||||||
def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo):
|
def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo):
|
||||||
super().after_return(status, retval, task_id, args, kwargs, einfo=einfo)
|
super().after_return(status, retval, task_id, args, kwargs, einfo=einfo)
|
||||||
@ -142,7 +138,7 @@ class MonitoredTask(Task):
|
|||||||
info = TaskInfo(
|
info = TaskInfo(
|
||||||
task_name=self.__name__,
|
task_name=self.__name__,
|
||||||
task_description=self.__doc__,
|
task_description=self.__doc__,
|
||||||
start_timestamp=self.start or default_timer(),
|
start_timestamp=self.start,
|
||||||
finish_timestamp=default_timer(),
|
finish_timestamp=default_timer(),
|
||||||
finish_time=datetime.now(),
|
finish_time=datetime.now(),
|
||||||
result=self._result,
|
result=self._result,
|
||||||
@ -166,7 +162,7 @@ class MonitoredTask(Task):
|
|||||||
TaskInfo(
|
TaskInfo(
|
||||||
task_name=self.__name__,
|
task_name=self.__name__,
|
||||||
task_description=self.__doc__,
|
task_description=self.__doc__,
|
||||||
start_timestamp=self.start or default_timer(),
|
start_timestamp=self.start,
|
||||||
finish_timestamp=default_timer(),
|
finish_timestamp=default_timer(),
|
||||||
finish_time=datetime.now(),
|
finish_time=datetime.now(),
|
||||||
result=self._result,
|
result=self._result,
|
||||||
|
@ -1,7 +1,4 @@
|
|||||||
"""Flow Binding API Views"""
|
"""Flow Binding API Views"""
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from rest_framework.exceptions import ValidationError
|
|
||||||
from rest_framework.serializers import ModelSerializer
|
from rest_framework.serializers import ModelSerializer
|
||||||
from rest_framework.viewsets import ModelViewSet
|
from rest_framework.viewsets import ModelViewSet
|
||||||
|
|
||||||
@ -15,13 +12,6 @@ class FlowStageBindingSerializer(ModelSerializer):
|
|||||||
|
|
||||||
stage_obj = StageSerializer(read_only=True, source="stage")
|
stage_obj = StageSerializer(read_only=True, source="stage")
|
||||||
|
|
||||||
def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
evaluate_on_plan = attrs.get("evaluate_on_plan", False)
|
|
||||||
re_evaluate_policies = attrs.get("re_evaluate_policies", True)
|
|
||||||
if not evaluate_on_plan and not re_evaluate_policies:
|
|
||||||
raise ValidationError("Either evaluation on plan or evaluation on run must be enabled")
|
|
||||||
return super().validate(attrs)
|
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model = FlowStageBinding
|
model = FlowStageBinding
|
||||||
fields = [
|
fields = [
|
||||||
|
@ -1,26 +0,0 @@
|
|||||||
# Generated by Django 4.1.7 on 2023-02-25 15:51
|
|
||||||
|
|
||||||
from django.db import migrations, models
|
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
|
||||||
dependencies = [
|
|
||||||
("authentik_flows", "0024_flow_authentication"),
|
|
||||||
]
|
|
||||||
|
|
||||||
operations = [
|
|
||||||
migrations.AlterField(
|
|
||||||
model_name="flowstagebinding",
|
|
||||||
name="evaluate_on_plan",
|
|
||||||
field=models.BooleanField(
|
|
||||||
default=False, help_text="Evaluate policies during the Flow planning process."
|
|
||||||
),
|
|
||||||
),
|
|
||||||
migrations.AlterField(
|
|
||||||
model_name="flowstagebinding",
|
|
||||||
name="re_evaluate_policies",
|
|
||||||
field=models.BooleanField(
|
|
||||||
default=True, help_text="Evaluate policies when the Stage is present to the user."
|
|
||||||
),
|
|
||||||
),
|
|
||||||
]
|
|
@ -211,11 +211,14 @@ class FlowStageBinding(SerializerModel, PolicyBindingModel):
|
|||||||
stage = InheritanceForeignKey(Stage, on_delete=models.CASCADE)
|
stage = InheritanceForeignKey(Stage, on_delete=models.CASCADE)
|
||||||
|
|
||||||
evaluate_on_plan = models.BooleanField(
|
evaluate_on_plan = models.BooleanField(
|
||||||
default=False,
|
default=True,
|
||||||
help_text=_("Evaluate policies during the Flow planning process."),
|
help_text=_(
|
||||||
|
"Evaluate policies during the Flow planning process. "
|
||||||
|
"Disable this for input-based policies."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
re_evaluate_policies = models.BooleanField(
|
re_evaluate_policies = models.BooleanField(
|
||||||
default=True,
|
default=False,
|
||||||
help_text=_("Evaluate policies when the Stage is present to the user."),
|
help_text=_("Evaluate policies when the Stage is present to the user."),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -147,6 +147,7 @@ class FlowPlanner:
|
|||||||
) -> FlowPlan:
|
) -> FlowPlan:
|
||||||
"""Check each of the flows' policies, check policies for each stage with PolicyBinding
|
"""Check each of the flows' policies, check policies for each stage with PolicyBinding
|
||||||
and return ordered list"""
|
and return ordered list"""
|
||||||
|
self._check_authentication(request)
|
||||||
with Hub.current.start_span(
|
with Hub.current.start_span(
|
||||||
op="authentik.flow.planner.plan", description=self.flow.slug
|
op="authentik.flow.planner.plan", description=self.flow.slug
|
||||||
) as span:
|
) as span:
|
||||||
@ -164,12 +165,6 @@ class FlowPlanner:
|
|||||||
user = default_context[PLAN_CONTEXT_PENDING_USER]
|
user = default_context[PLAN_CONTEXT_PENDING_USER]
|
||||||
else:
|
else:
|
||||||
user = request.user
|
user = request.user
|
||||||
# We only need to check the flow authentication if it's planned without a user
|
|
||||||
# in the context, as a user in the context can only be set via the explicit code API
|
|
||||||
# or if a flow is restarted due to `invalid_response_action` being set to
|
|
||||||
# `restart_with_context`, which can only happen if the user was already authorized
|
|
||||||
# to use the flow
|
|
||||||
self._check_authentication(request)
|
|
||||||
# First off, check the flow's direct policy bindings
|
# First off, check the flow's direct policy bindings
|
||||||
# to make sure the user even has access to the flow
|
# to make sure the user even has access to the flow
|
||||||
engine = PolicyEngine(self.flow, user, request)
|
engine = PolicyEngine(self.flow, user, request)
|
||||||
@ -266,6 +261,7 @@ class FlowPlanner:
|
|||||||
marker = ReevaluateMarker(binding=binding)
|
marker = ReevaluateMarker(binding=binding)
|
||||||
if stage:
|
if stage:
|
||||||
plan.append(binding, marker)
|
plan.append(binding, marker)
|
||||||
|
HIST_FLOWS_PLAN_TIME.labels(flow_slug=self.flow.slug)
|
||||||
self._logger.debug(
|
self._logger.debug(
|
||||||
"f(plan): finished building",
|
"f(plan): finished building",
|
||||||
)
|
)
|
||||||
|
@ -7,7 +7,6 @@ from django.http.request import QueryDict
|
|||||||
from django.http.response import HttpResponse
|
from django.http.response import HttpResponse
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
from django.views.generic.base import View
|
from django.views.generic.base import View
|
||||||
from prometheus_client import Histogram
|
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
from sentry_sdk.hub import Hub
|
from sentry_sdk.hub import Hub
|
||||||
from structlog.stdlib import BoundLogger, get_logger
|
from structlog.stdlib import BoundLogger, get_logger
|
||||||
@ -32,11 +31,6 @@ if TYPE_CHECKING:
|
|||||||
from authentik.flows.views.executor import FlowExecutorView
|
from authentik.flows.views.executor import FlowExecutorView
|
||||||
|
|
||||||
PLAN_CONTEXT_PENDING_USER_IDENTIFIER = "pending_user_identifier"
|
PLAN_CONTEXT_PENDING_USER_IDENTIFIER = "pending_user_identifier"
|
||||||
HIST_FLOWS_STAGE_TIME = Histogram(
|
|
||||||
"authentik_flows_stage_time",
|
|
||||||
"Duration taken by different parts of stages",
|
|
||||||
["stage_type", "method"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class StageView(View):
|
class StageView(View):
|
||||||
@ -115,24 +109,14 @@ class ChallengeStageView(StageView):
|
|||||||
keep_context=keep_context,
|
keep_context=keep_context,
|
||||||
)
|
)
|
||||||
return self.executor.restart_flow(keep_context)
|
return self.executor.restart_flow(keep_context)
|
||||||
with (
|
with Hub.current.start_span(
|
||||||
Hub.current.start_span(
|
|
||||||
op="authentik.flow.stage.challenge_invalid",
|
op="authentik.flow.stage.challenge_invalid",
|
||||||
description=self.__class__.__name__,
|
description=self.__class__.__name__,
|
||||||
),
|
|
||||||
HIST_FLOWS_STAGE_TIME.labels(
|
|
||||||
stage_type=self.__class__.__name__, method="challenge_invalid"
|
|
||||||
).time(),
|
|
||||||
):
|
):
|
||||||
return self.challenge_invalid(challenge)
|
return self.challenge_invalid(challenge)
|
||||||
with (
|
with Hub.current.start_span(
|
||||||
Hub.current.start_span(
|
|
||||||
op="authentik.flow.stage.challenge_valid",
|
op="authentik.flow.stage.challenge_valid",
|
||||||
description=self.__class__.__name__,
|
description=self.__class__.__name__,
|
||||||
),
|
|
||||||
HIST_FLOWS_STAGE_TIME.labels(
|
|
||||||
stage_type=self.__class__.__name__, method="challenge_valid"
|
|
||||||
).time(),
|
|
||||||
):
|
):
|
||||||
return self.challenge_valid(challenge)
|
return self.challenge_valid(challenge)
|
||||||
|
|
||||||
@ -151,14 +135,9 @@ class ChallengeStageView(StageView):
|
|||||||
return self.executor.flow.title
|
return self.executor.flow.title
|
||||||
|
|
||||||
def _get_challenge(self, *args, **kwargs) -> Challenge:
|
def _get_challenge(self, *args, **kwargs) -> Challenge:
|
||||||
with (
|
with Hub.current.start_span(
|
||||||
Hub.current.start_span(
|
|
||||||
op="authentik.flow.stage.get_challenge",
|
op="authentik.flow.stage.get_challenge",
|
||||||
description=self.__class__.__name__,
|
description=self.__class__.__name__,
|
||||||
),
|
|
||||||
HIST_FLOWS_STAGE_TIME.labels(
|
|
||||||
stage_type=self.__class__.__name__, method="get_challenge"
|
|
||||||
).time(),
|
|
||||||
):
|
):
|
||||||
challenge = self.get_challenge(*args, **kwargs)
|
challenge = self.get_challenge(*args, **kwargs)
|
||||||
with Hub.current.start_span(
|
with Hub.current.start_span(
|
||||||
@ -231,7 +210,7 @@ class AccessDeniedChallengeView(ChallengeStageView):
|
|||||||
def get_challenge(self, *args, **kwargs) -> Challenge:
|
def get_challenge(self, *args, **kwargs) -> Challenge:
|
||||||
return AccessDeniedChallenge(
|
return AccessDeniedChallenge(
|
||||||
data={
|
data={
|
||||||
"error_message": str(self.error_message or "Unknown error"),
|
"error_message": self.error_message or "Unknown error",
|
||||||
"type": ChallengeTypes.NATIVE.value,
|
"type": ChallengeTypes.NATIVE.value,
|
||||||
"component": "ak-stage-access-denied",
|
"component": "ak-stage-access-denied",
|
||||||
}
|
}
|
||||||
|
@ -561,13 +561,9 @@ class ConfigureFlowInitView(LoginRequiredMixin, View):
|
|||||||
LOGGER.debug("Stage has no configure_flow set", stage=stage)
|
LOGGER.debug("Stage has no configure_flow set", stage=stage)
|
||||||
raise Http404
|
raise Http404
|
||||||
|
|
||||||
try:
|
|
||||||
plan = FlowPlanner(stage.configure_flow).plan(
|
plan = FlowPlanner(stage.configure_flow).plan(
|
||||||
request, {PLAN_CONTEXT_PENDING_USER: request.user}
|
request, {PLAN_CONTEXT_PENDING_USER: request.user}
|
||||||
)
|
)
|
||||||
except FlowNonApplicableException:
|
|
||||||
LOGGER.warning("Flow not applicable to user")
|
|
||||||
raise Http404
|
|
||||||
request.session[SESSION_KEY_PLAN] = plan
|
request.session[SESSION_KEY_PLAN] = plan
|
||||||
return redirect_with_qs(
|
return redirect_with_qs(
|
||||||
"authentik_core:if-flow",
|
"authentik_core:if-flow",
|
||||||
|
@ -1,11 +1,10 @@
|
|||||||
"""Avatar utils"""
|
"""Avatar utils"""
|
||||||
from base64 import b64encode
|
from base64 import b64encode
|
||||||
from functools import cache as funccache
|
from functools import cache
|
||||||
from hashlib import md5
|
from hashlib import md5
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
from django.core.cache import cache
|
|
||||||
from django.templatetags.static import static
|
from django.templatetags.static import static
|
||||||
from lxml import etree # nosec
|
from lxml import etree # nosec
|
||||||
from lxml.etree import Element, SubElement # nosec
|
from lxml.etree import Element, SubElement # nosec
|
||||||
@ -16,7 +15,6 @@ from authentik.lib.utils.http import get_http_session
|
|||||||
|
|
||||||
GRAVATAR_URL = "https://secure.gravatar.com"
|
GRAVATAR_URL = "https://secure.gravatar.com"
|
||||||
DEFAULT_AVATAR = static("dist/assets/images/user_default.png")
|
DEFAULT_AVATAR = static("dist/assets/images/user_default.png")
|
||||||
CACHE_KEY_GRAVATAR = "goauthentik.io/lib/avatars/"
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from authentik.core.models import User
|
from authentik.core.models import User
|
||||||
@ -52,24 +50,22 @@ def avatar_mode_gravatar(user: "User", mode: str) -> Optional[str]:
|
|||||||
parameters = [("size", "158"), ("rating", "g"), ("default", "404")]
|
parameters = [("size", "158"), ("rating", "g"), ("default", "404")]
|
||||||
gravatar_url = f"{GRAVATAR_URL}/avatar/{mail_hash}?{urlencode(parameters, doseq=True)}"
|
gravatar_url = f"{GRAVATAR_URL}/avatar/{mail_hash}?{urlencode(parameters, doseq=True)}"
|
||||||
|
|
||||||
full_key = CACHE_KEY_GRAVATAR + mail_hash
|
@cache
|
||||||
if cache.has_key(full_key):
|
def check_non_default(url: str):
|
||||||
cache.touch(full_key)
|
"""Cache HEAD check, based on URL"""
|
||||||
return cache.get(full_key)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Since we specify a default of 404, do a HEAD request
|
# Since we specify a default of 404, do a HEAD request
|
||||||
# (HEAD since we don't need the body)
|
# (HEAD since we don't need the body)
|
||||||
# so if that returns a 404, move onto the next mode
|
# so if that returns a 404, move onto the next mode
|
||||||
res = get_http_session().head(gravatar_url, timeout=5)
|
res = get_http_session().head(url, timeout=5)
|
||||||
if res.status_code == 404:
|
if res.status_code == 404:
|
||||||
cache.set(full_key, None)
|
|
||||||
return None
|
return None
|
||||||
res.raise_for_status()
|
res.raise_for_status()
|
||||||
except RequestException:
|
except RequestException:
|
||||||
return gravatar_url
|
return url
|
||||||
cache.set(full_key, gravatar_url)
|
return url
|
||||||
return gravatar_url
|
|
||||||
|
return check_non_default(gravatar_url)
|
||||||
|
|
||||||
|
|
||||||
def generate_colors(text: str) -> tuple[str, str]:
|
def generate_colors(text: str) -> tuple[str, str]:
|
||||||
@ -87,7 +83,7 @@ def generate_colors(text: str) -> tuple[str, str]:
|
|||||||
return bg_hex, text_hex
|
return bg_hex, text_hex
|
||||||
|
|
||||||
|
|
||||||
@funccache
|
@cache
|
||||||
# pylint: disable=too-many-arguments,too-many-locals
|
# pylint: disable=too-many-arguments,too-many-locals
|
||||||
def generate_avatar_from_name(
|
def generate_avatar_from_name(
|
||||||
name: str,
|
name: str,
|
||||||
@ -154,7 +150,7 @@ def generate_avatar_from_name(
|
|||||||
|
|
||||||
def avatar_mode_generated(user: "User", mode: str) -> Optional[str]:
|
def avatar_mode_generated(user: "User", mode: str) -> Optional[str]:
|
||||||
"""Wrapper that converts generated avatar to base64 svg"""
|
"""Wrapper that converts generated avatar to base64 svg"""
|
||||||
svg = generate_avatar_from_name(user.name if user.name.strip() != "" else "a k")
|
svg = generate_avatar_from_name(user.name if user.name != "" else "a k")
|
||||||
return f"data:image/svg+xml;base64,{b64encode(svg.encode('utf-8')).decode('utf-8')}"
|
return f"data:image/svg+xml;base64,{b64encode(svg.encode('utf-8')).decode('utf-8')}"
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,11 +1,9 @@
|
|||||||
"""authentik expression policy evaluator"""
|
"""authentik expression policy evaluator"""
|
||||||
import re
|
import re
|
||||||
import socket
|
|
||||||
from ipaddress import ip_address, ip_network
|
from ipaddress import ip_address, ip_network
|
||||||
from textwrap import indent
|
from textwrap import indent
|
||||||
from typing import Any, Iterable, Optional
|
from typing import Any, Iterable, Optional
|
||||||
|
|
||||||
from cachetools import TLRUCache, cached
|
|
||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
from django_otp import devices_for_user
|
from django_otp import devices_for_user
|
||||||
from rest_framework.serializers import ValidationError
|
from rest_framework.serializers import ValidationError
|
||||||
@ -43,8 +41,6 @@ class BaseEvaluator:
|
|||||||
"ak_is_group_member": BaseEvaluator.expr_is_group_member,
|
"ak_is_group_member": BaseEvaluator.expr_is_group_member,
|
||||||
"ak_user_by": BaseEvaluator.expr_user_by,
|
"ak_user_by": BaseEvaluator.expr_user_by,
|
||||||
"ak_user_has_authenticator": BaseEvaluator.expr_func_user_has_authenticator,
|
"ak_user_has_authenticator": BaseEvaluator.expr_func_user_has_authenticator,
|
||||||
"resolve_dns": BaseEvaluator.expr_resolve_dns,
|
|
||||||
"reverse_dns": BaseEvaluator.expr_reverse_dns,
|
|
||||||
"ak_create_event": self.expr_event_create,
|
"ak_create_event": self.expr_event_create,
|
||||||
"ak_logger": get_logger(self._filename).bind(),
|
"ak_logger": get_logger(self._filename).bind(),
|
||||||
"requests": get_http_session(),
|
"requests": get_http_session(),
|
||||||
@ -53,39 +49,6 @@ class BaseEvaluator:
|
|||||||
}
|
}
|
||||||
self._context = {}
|
self._context = {}
|
||||||
|
|
||||||
@cached(cache=TLRUCache(maxsize=32, ttu=lambda key, value, now: now + 180))
|
|
||||||
@staticmethod
|
|
||||||
def expr_resolve_dns(host: str, ip_version: Optional[int] = None) -> list[str]:
|
|
||||||
"""Resolve host to a list of IPv4 and/or IPv6 addresses."""
|
|
||||||
# Although it seems to be fine (raising OSError), docs warn
|
|
||||||
# against passing `None` for both the host and the port
|
|
||||||
# https://docs.python.org/3/library/socket.html#socket.getaddrinfo
|
|
||||||
host = host or ""
|
|
||||||
|
|
||||||
ip_list = []
|
|
||||||
|
|
||||||
family = 0
|
|
||||||
if ip_version == 4:
|
|
||||||
family = socket.AF_INET
|
|
||||||
if ip_version == 6:
|
|
||||||
family = socket.AF_INET6
|
|
||||||
|
|
||||||
try:
|
|
||||||
for ip_addr in socket.getaddrinfo(host, None, family=family):
|
|
||||||
ip_list.append(str(ip_addr[4][0]))
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
return list(set(ip_list))
|
|
||||||
|
|
||||||
@cached(cache=TLRUCache(maxsize=32, ttu=lambda key, value, now: now + 180))
|
|
||||||
@staticmethod
|
|
||||||
def expr_reverse_dns(ip_addr: str) -> str:
|
|
||||||
"""Perform a reverse DNS lookup."""
|
|
||||||
try:
|
|
||||||
return socket.getfqdn(ip_addr)
|
|
||||||
except OSError:
|
|
||||||
return ip_addr
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def expr_flatten(value: list[Any] | Any) -> Optional[Any]:
|
def expr_flatten(value: list[Any] | Any) -> Optional[Any]:
|
||||||
"""Flatten `value` if its a list"""
|
"""Flatten `value` if its a list"""
|
||||||
|
@ -1,58 +0,0 @@
|
|||||||
"""Migration helpers"""
|
|
||||||
from typing import Iterable
|
|
||||||
|
|
||||||
from django.apps.registry import Apps
|
|
||||||
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
|
||||||
|
|
||||||
|
|
||||||
def fallback_names(app: str, model: str, field: str):
|
|
||||||
"""Factory function that checks all instances of `app`.`model` instance's `field`
|
|
||||||
to prevent any duplicates"""
|
|
||||||
|
|
||||||
def migrator(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
|
||||||
db_alias = schema_editor.connection.alias
|
|
||||||
|
|
||||||
klass = apps.get_model(app, model)
|
|
||||||
seen_names = []
|
|
||||||
for obj in klass.objects.using(db_alias).all():
|
|
||||||
value = getattr(obj, field)
|
|
||||||
if value not in seen_names:
|
|
||||||
seen_names.append(value)
|
|
||||||
continue
|
|
||||||
new_value = value + "_2"
|
|
||||||
setattr(obj, field, new_value)
|
|
||||||
obj.save()
|
|
||||||
|
|
||||||
return migrator
|
|
||||||
|
|
||||||
|
|
||||||
def progress_bar(iterable: Iterable):
|
|
||||||
"""Call in a loop to create terminal progress bar
|
|
||||||
https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console"""
|
|
||||||
|
|
||||||
prefix = "Writing: "
|
|
||||||
suffix = " finished"
|
|
||||||
decimals = 1
|
|
||||||
length = 100
|
|
||||||
fill = "█"
|
|
||||||
print_end = "\r"
|
|
||||||
|
|
||||||
total = len(iterable)
|
|
||||||
if total < 1:
|
|
||||||
return
|
|
||||||
|
|
||||||
def print_progress_bar(iteration):
|
|
||||||
"""Progress Bar Printing Function"""
|
|
||||||
percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total)))
|
|
||||||
filled_length = int(length * iteration // total)
|
|
||||||
bar = fill * filled_length + "-" * (length - filled_length)
|
|
||||||
print(f"\r{prefix} |{bar}| {percent}% {suffix}", end=print_end)
|
|
||||||
|
|
||||||
# Initial Call
|
|
||||||
print_progress_bar(0)
|
|
||||||
# Update Progress Bar
|
|
||||||
for i, item in enumerate(iterable):
|
|
||||||
yield item
|
|
||||||
print_progress_bar(i + 1)
|
|
||||||
# Print New Line on Complete
|
|
||||||
print()
|
|
@ -74,22 +74,3 @@ class DomainlessURLValidator(URLValidator):
|
|||||||
if scheme not in self.schemes:
|
if scheme not in self.schemes:
|
||||||
value = "default" + value
|
value = "default" + value
|
||||||
super().__call__(value)
|
super().__call__(value)
|
||||||
|
|
||||||
|
|
||||||
class DomainlessFormattedURLValidator(DomainlessURLValidator):
|
|
||||||
"""URL validator which allows for python format strings"""
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs) -> None:
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.formatter_re = r"([%\(\)a-zA-Z])*"
|
|
||||||
self.host_re = "(" + self.formatter_re + self.hostname_re + self.domain_re + "|localhost)"
|
|
||||||
self.regex = _lazy_re_compile(
|
|
||||||
r"^(?:[a-z0-9.+-]*)://" # scheme is validated separately
|
|
||||||
r"(?:[^\s:@/]+(?::[^\s:@/]*)?@)?" # user:pass authentication
|
|
||||||
r"(?:" + self.ipv4_re + "|" + self.ipv6_re + "|" + self.host_re + ")"
|
|
||||||
r"(?::\d{2,5})?" # port
|
|
||||||
r"(?:[/?#][^\s]*)?" # resource path
|
|
||||||
r"\Z",
|
|
||||||
re.IGNORECASE,
|
|
||||||
)
|
|
||||||
self.schemes = ["http", "https", "blank"] + list(self.schemes)
|
|
||||||
|
@ -4,6 +4,7 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
from billiard.exceptions import SoftTimeLimitExceeded, WorkerLostError
|
from billiard.exceptions import SoftTimeLimitExceeded, WorkerLostError
|
||||||
from celery.exceptions import CeleryError
|
from celery.exceptions import CeleryError
|
||||||
|
from channels.middleware import BaseMiddleware
|
||||||
from channels_redis.core import ChannelFull
|
from channels_redis.core import ChannelFull
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation, ValidationError
|
from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation, ValidationError
|
||||||
@ -16,24 +17,37 @@ from ldap3.core.exceptions import LDAPException
|
|||||||
from redis.exceptions import ConnectionError as RedisConnectionError
|
from redis.exceptions import ConnectionError as RedisConnectionError
|
||||||
from redis.exceptions import RedisError, ResponseError
|
from redis.exceptions import RedisError, ResponseError
|
||||||
from rest_framework.exceptions import APIException
|
from rest_framework.exceptions import APIException
|
||||||
from sentry_sdk import HttpTransport
|
from sentry_sdk import HttpTransport, Hub
|
||||||
from sentry_sdk import init as sentry_sdk_init
|
from sentry_sdk import init as sentry_sdk_init
|
||||||
from sentry_sdk.api import set_tag
|
from sentry_sdk.api import set_tag
|
||||||
from sentry_sdk.integrations.celery import CeleryIntegration
|
from sentry_sdk.integrations.celery import CeleryIntegration
|
||||||
from sentry_sdk.integrations.django import DjangoIntegration
|
from sentry_sdk.integrations.django import DjangoIntegration
|
||||||
from sentry_sdk.integrations.redis import RedisIntegration
|
from sentry_sdk.integrations.redis import RedisIntegration
|
||||||
from sentry_sdk.integrations.threading import ThreadingIntegration
|
from sentry_sdk.integrations.threading import ThreadingIntegration
|
||||||
|
from sentry_sdk.tracing import Transaction
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
from websockets.exceptions import WebSocketException
|
from websockets.exceptions import WebSocketException
|
||||||
|
|
||||||
from authentik import __version__, get_build_hash
|
from authentik import __version__, get_build_hash
|
||||||
from authentik.lib.config import CONFIG
|
from authentik.lib.config import CONFIG
|
||||||
from authentik.lib.utils.http import authentik_user_agent
|
from authentik.lib.utils.http import authentik_user_agent
|
||||||
from authentik.lib.utils.reflection import get_env
|
from authentik.lib.utils.reflection import class_to_path, get_env
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class SentryWSMiddleware(BaseMiddleware):
|
||||||
|
"""Sentry Websocket middleweare to set the transaction name based on
|
||||||
|
consumer class path"""
|
||||||
|
|
||||||
|
async def __call__(self, scope, receive, send):
|
||||||
|
transaction: Optional[Transaction] = Hub.current.scope.transaction
|
||||||
|
class_path = class_to_path(self.inner.consumer_class)
|
||||||
|
if transaction:
|
||||||
|
transaction.name = class_path
|
||||||
|
return await self.inner(scope, receive, send)
|
||||||
|
|
||||||
|
|
||||||
class SentryIgnoredException(Exception):
|
class SentryIgnoredException(Exception):
|
||||||
"""Base Class for all errors that are suppressed, and not sent to sentry."""
|
"""Base Class for all errors that are suppressed, and not sent to sentry."""
|
||||||
|
|
||||||
@ -80,12 +94,9 @@ def sentry_init(**sentry_init_kwargs):
|
|||||||
def traces_sampler(sampling_context: dict) -> float:
|
def traces_sampler(sampling_context: dict) -> float:
|
||||||
"""Custom sampler to ignore certain routes"""
|
"""Custom sampler to ignore certain routes"""
|
||||||
path = sampling_context.get("asgi_scope", {}).get("path", "")
|
path = sampling_context.get("asgi_scope", {}).get("path", "")
|
||||||
_type = sampling_context.get("asgi_scope", {}).get("type", "")
|
|
||||||
# Ignore all healthcheck routes
|
# Ignore all healthcheck routes
|
||||||
if path.startswith("/-/health") or path.startswith("/-/metrics"):
|
if path.startswith("/-/health") or path.startswith("/-/metrics"):
|
||||||
return 0
|
return 0
|
||||||
if _type == "websocket":
|
|
||||||
return 0
|
|
||||||
return float(CONFIG.y("error_reporting.sample_rate", 0.1))
|
return float(CONFIG.y("error_reporting.sample_rate", 0.1))
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,7 +1,4 @@
|
|||||||
"""Test utils"""
|
"""Test utils"""
|
||||||
from inspect import currentframe
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from django.contrib.messages.middleware import MessageMiddleware
|
from django.contrib.messages.middleware import MessageMiddleware
|
||||||
from django.contrib.sessions.middleware import SessionMiddleware
|
from django.contrib.sessions.middleware import SessionMiddleware
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
@ -14,21 +11,6 @@ def dummy_get_response(request: HttpRequest): # pragma: no cover
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def load_fixture(path: str, **kwargs) -> str:
|
|
||||||
"""Load fixture, optionally formatting it with kwargs"""
|
|
||||||
current = currentframe()
|
|
||||||
parent = current.f_back
|
|
||||||
calling_file_path = parent.f_globals["__file__"]
|
|
||||||
with open(
|
|
||||||
Path(calling_file_path).resolve().parent / Path(path), "r", encoding="utf-8"
|
|
||||||
) as _fixture:
|
|
||||||
fixture = _fixture.read()
|
|
||||||
try:
|
|
||||||
return fixture % kwargs
|
|
||||||
except TypeError:
|
|
||||||
return fixture
|
|
||||||
|
|
||||||
|
|
||||||
def get_request(*args, user=None, **kwargs):
|
def get_request(*args, user=None, **kwargs):
|
||||||
"""Get a request with usable session"""
|
"""Get a request with usable session"""
|
||||||
request = RequestFactory().get(*args, **kwargs)
|
request = RequestFactory().get(*args, **kwargs)
|
||||||
|
@ -38,17 +38,13 @@ def _get_outpost_override_ip(request: HttpRequest) -> Optional[str]:
|
|||||||
if OUTPOST_REMOTE_IP_HEADER not in request.META or OUTPOST_TOKEN_HEADER not in request.META:
|
if OUTPOST_REMOTE_IP_HEADER not in request.META or OUTPOST_TOKEN_HEADER not in request.META:
|
||||||
return None
|
return None
|
||||||
fake_ip = request.META[OUTPOST_REMOTE_IP_HEADER]
|
fake_ip = request.META[OUTPOST_REMOTE_IP_HEADER]
|
||||||
token = (
|
tokens = Token.filter_not_expired(
|
||||||
Token.filter_not_expired(
|
|
||||||
key=request.META.get(OUTPOST_TOKEN_HEADER), intent=TokenIntents.INTENT_API
|
key=request.META.get(OUTPOST_TOKEN_HEADER), intent=TokenIntents.INTENT_API
|
||||||
)
|
)
|
||||||
.select_related("user")
|
if not tokens.exists():
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if not token:
|
|
||||||
LOGGER.warning("Attempted remote-ip override without token", fake_ip=fake_ip)
|
LOGGER.warning("Attempted remote-ip override without token", fake_ip=fake_ip)
|
||||||
return None
|
return None
|
||||||
user = token.user
|
user = tokens.first().user
|
||||||
if not user.group_attributes(request).get(USER_ATTRIBUTE_CAN_OVERRIDE_IP, False):
|
if not user.group_attributes(request).get(USER_ATTRIBUTE_CAN_OVERRIDE_IP, False):
|
||||||
LOGGER.warning(
|
LOGGER.warning(
|
||||||
"Remote-IP override: user doesn't have permission",
|
"Remote-IP override: user doesn't have permission",
|
||||||
|
@ -9,4 +9,4 @@ def get_lxml_parser():
|
|||||||
|
|
||||||
def lxml_from_string(text: str):
|
def lxml_from_string(text: str):
|
||||||
"""Wrapper around fromstring"""
|
"""Wrapper around fromstring"""
|
||||||
return fromstring(text, parser=get_lxml_parser()) # nosec
|
return fromstring(text, parser=get_lxml_parser())
|
||||||
|
@ -16,6 +16,7 @@ from authentik.outposts.controllers.k8s.triggers import NeedsRecreate, NeedsUpda
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from authentik.outposts.controllers.kubernetes import KubernetesController
|
from authentik.outposts.controllers.kubernetes import KubernetesController
|
||||||
|
|
||||||
|
# pylint: disable=invalid-name
|
||||||
T = TypeVar("T", V1Pod, V1Deployment)
|
T = TypeVar("T", V1Pod, V1Deployment)
|
||||||
|
|
||||||
|
|
||||||
@ -55,7 +56,6 @@ class KubernetesObjectReconciler(Generic[T]):
|
|||||||
}
|
}
|
||||||
).lower()
|
).lower()
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
|
||||||
def up(self):
|
def up(self):
|
||||||
"""Create object if it doesn't exist, update if needed or recreate if needed."""
|
"""Create object if it doesn't exist, update if needed or recreate if needed."""
|
||||||
current = None
|
current = None
|
||||||
|
@ -1,28 +0,0 @@
|
|||||||
# Generated by Django 4.1.7 on 2023-03-07 13:41
|
|
||||||
|
|
||||||
from django.db import migrations, models
|
|
||||||
|
|
||||||
from authentik.lib.migrations import fallback_names
|
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
|
||||||
dependencies = [
|
|
||||||
("authentik_outposts", "0018_kubernetesserviceconnection_verify_ssl"),
|
|
||||||
]
|
|
||||||
|
|
||||||
operations = [
|
|
||||||
migrations.RunPython(fallback_names("authentik_outposts", "outpost", "name")),
|
|
||||||
migrations.RunPython(
|
|
||||||
fallback_names("authentik_outposts", "outpostserviceconnection", "name")
|
|
||||||
),
|
|
||||||
migrations.AlterField(
|
|
||||||
model_name="outpost",
|
|
||||||
name="name",
|
|
||||||
field=models.TextField(unique=True),
|
|
||||||
),
|
|
||||||
migrations.AlterField(
|
|
||||||
model_name="outpostserviceconnection",
|
|
||||||
name="name",
|
|
||||||
field=models.TextField(unique=True),
|
|
||||||
),
|
|
||||||
]
|
|
@ -113,7 +113,7 @@ class OutpostServiceConnection(models.Model):
|
|||||||
"""Connection details for an Outpost Controller, like Docker or Kubernetes"""
|
"""Connection details for an Outpost Controller, like Docker or Kubernetes"""
|
||||||
|
|
||||||
uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True)
|
uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True)
|
||||||
name = models.TextField(unique=True)
|
name = models.TextField()
|
||||||
|
|
||||||
local = models.BooleanField(
|
local = models.BooleanField(
|
||||||
default=False,
|
default=False,
|
||||||
@ -239,7 +239,7 @@ class Outpost(SerializerModel, ManagedModel):
|
|||||||
"""Outpost instance which manages a service user and token"""
|
"""Outpost instance which manages a service user and token"""
|
||||||
|
|
||||||
uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True)
|
uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True)
|
||||||
name = models.TextField(unique=True)
|
name = models.TextField()
|
||||||
|
|
||||||
type = models.TextField(choices=OutpostType.choices, default=OutpostType.PROXY)
|
type = models.TextField(choices=OutpostType.choices, default=OutpostType.PROXY)
|
||||||
service_connection = InheritanceForeignKey(
|
service_connection = InheritanceForeignKey(
|
||||||
|
@ -19,9 +19,9 @@ CELERY_BEAT_SCHEDULE = {
|
|||||||
"schedule": crontab(minute=fqdn_rand("outpost_token_ensurer"), hour="*/8"),
|
"schedule": crontab(minute=fqdn_rand("outpost_token_ensurer"), hour="*/8"),
|
||||||
"options": {"queue": "authentik_scheduled"},
|
"options": {"queue": "authentik_scheduled"},
|
||||||
},
|
},
|
||||||
"outpost_connection_discovery": {
|
"outpost_local_connection": {
|
||||||
"task": "authentik.outposts.tasks.outpost_connection_discovery",
|
"task": "authentik.outposts.tasks.outpost_local_connection",
|
||||||
"schedule": crontab(minute=fqdn_rand("outpost_connection_discovery"), hour="*/8"),
|
"schedule": crontab(minute=fqdn_rand("outpost_local_connection"), hour="*/8"),
|
||||||
"options": {"queue": "authentik_scheduled"},
|
"options": {"queue": "authentik_scheduled"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -236,33 +236,28 @@ def _outpost_single_update(outpost: Outpost):
|
|||||||
async_to_sync(closing_send)(channel, {"type": "event.update"})
|
async_to_sync(closing_send)(channel, {"type": "event.update"})
|
||||||
|
|
||||||
|
|
||||||
@CELERY_APP.task(
|
@CELERY_APP.task()
|
||||||
base=MonitoredTask,
|
def outpost_local_connection():
|
||||||
bind=True,
|
|
||||||
)
|
|
||||||
def outpost_connection_discovery(self: MonitoredTask):
|
|
||||||
"""Checks the local environment and create Service connections."""
|
"""Checks the local environment and create Service connections."""
|
||||||
status = TaskResult(TaskResultStatus.SUCCESSFUL)
|
|
||||||
if not CONFIG.y_bool("outposts.discover"):
|
if not CONFIG.y_bool("outposts.discover"):
|
||||||
status.messages.append("Outpost integration discovery is disabled")
|
LOGGER.info("Outpost integration discovery is disabled")
|
||||||
self.set_status(status)
|
|
||||||
return
|
return
|
||||||
# Explicitly check against token filename, as that's
|
# Explicitly check against token filename, as that's
|
||||||
# only present when the integration is enabled
|
# only present when the integration is enabled
|
||||||
if Path(SERVICE_TOKEN_FILENAME).exists():
|
if Path(SERVICE_TOKEN_FILENAME).exists():
|
||||||
status.messages.append("Detected in-cluster Kubernetes Config")
|
LOGGER.info("Detected in-cluster Kubernetes Config")
|
||||||
if not KubernetesServiceConnection.objects.filter(local=True).exists():
|
if not KubernetesServiceConnection.objects.filter(local=True).exists():
|
||||||
status.messages.append("Created Service Connection for in-cluster")
|
LOGGER.debug("Created Service Connection for in-cluster")
|
||||||
KubernetesServiceConnection.objects.create(
|
KubernetesServiceConnection.objects.create(
|
||||||
name="Local Kubernetes Cluster", local=True, kubeconfig={}
|
name="Local Kubernetes Cluster", local=True, kubeconfig={}
|
||||||
)
|
)
|
||||||
# For development, check for the existence of a kubeconfig file
|
# For development, check for the existence of a kubeconfig file
|
||||||
kubeconfig_path = Path(KUBE_CONFIG_DEFAULT_LOCATION).expanduser()
|
kubeconfig_path = Path(KUBE_CONFIG_DEFAULT_LOCATION).expanduser()
|
||||||
if kubeconfig_path.exists():
|
if kubeconfig_path.exists():
|
||||||
status.messages.append("Detected kubeconfig")
|
LOGGER.info("Detected kubeconfig")
|
||||||
kubeconfig_local_name = f"k8s-{gethostname()}"
|
kubeconfig_local_name = f"k8s-{gethostname()}"
|
||||||
if not KubernetesServiceConnection.objects.filter(name=kubeconfig_local_name).exists():
|
if not KubernetesServiceConnection.objects.filter(name=kubeconfig_local_name).exists():
|
||||||
status.messages.append("Creating kubeconfig Service Connection")
|
LOGGER.debug("Creating kubeconfig Service Connection")
|
||||||
with kubeconfig_path.open("r", encoding="utf8") as _kubeconfig:
|
with kubeconfig_path.open("r", encoding="utf8") as _kubeconfig:
|
||||||
KubernetesServiceConnection.objects.create(
|
KubernetesServiceConnection.objects.create(
|
||||||
name=kubeconfig_local_name,
|
name=kubeconfig_local_name,
|
||||||
@ -271,12 +266,11 @@ def outpost_connection_discovery(self: MonitoredTask):
|
|||||||
unix_socket_path = urlparse(DEFAULT_UNIX_SOCKET).path
|
unix_socket_path = urlparse(DEFAULT_UNIX_SOCKET).path
|
||||||
socket = Path(unix_socket_path)
|
socket = Path(unix_socket_path)
|
||||||
if socket.exists() and access(socket, R_OK):
|
if socket.exists() and access(socket, R_OK):
|
||||||
status.messages.append("Detected local docker socket")
|
LOGGER.info("Detected local docker socket")
|
||||||
if len(DockerServiceConnection.objects.filter(local=True)) == 0:
|
if len(DockerServiceConnection.objects.filter(local=True)) == 0:
|
||||||
status.messages.append("Created Service Connection for docker")
|
LOGGER.debug("Created Service Connection for docker")
|
||||||
DockerServiceConnection.objects.create(
|
DockerServiceConnection.objects.create(
|
||||||
name="Local Docker connection",
|
name="Local Docker connection",
|
||||||
local=True,
|
local=True,
|
||||||
url=unix_socket_path,
|
url=unix_socket_path,
|
||||||
)
|
)
|
||||||
self.set_status(status)
|
|
||||||
|
@ -4,7 +4,6 @@ from rest_framework.test import APITestCase
|
|||||||
|
|
||||||
from authentik.core.models import PropertyMapping
|
from authentik.core.models import PropertyMapping
|
||||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||||
from authentik.lib.generators import generate_id
|
|
||||||
from authentik.outposts.api.outposts import OutpostSerializer
|
from authentik.outposts.api.outposts import OutpostSerializer
|
||||||
from authentik.outposts.models import OutpostType, default_outpost_config
|
from authentik.outposts.models import OutpostType, default_outpost_config
|
||||||
from authentik.providers.ldap.models import LDAPProvider
|
from authentik.providers.ldap.models import LDAPProvider
|
||||||
@ -17,7 +16,7 @@ class TestOutpostServiceConnectionsAPI(APITestCase):
|
|||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.mapping = PropertyMapping.objects.create(
|
self.mapping = PropertyMapping.objects.create(
|
||||||
name=generate_id(), expression="""return {'foo': 'bar'}"""
|
name="dummy", expression="""return {'foo': 'bar'}"""
|
||||||
)
|
)
|
||||||
self.user = create_test_admin_user()
|
self.user = create_test_admin_user()
|
||||||
self.client.force_login(self.user)
|
self.client.force_login(self.user)
|
||||||
@ -26,12 +25,12 @@ class TestOutpostServiceConnectionsAPI(APITestCase):
|
|||||||
"""Test Outpost validation"""
|
"""Test Outpost validation"""
|
||||||
valid = OutpostSerializer(
|
valid = OutpostSerializer(
|
||||||
data={
|
data={
|
||||||
"name": generate_id(),
|
"name": "foo",
|
||||||
"type": OutpostType.PROXY,
|
"type": OutpostType.PROXY,
|
||||||
"config": default_outpost_config(),
|
"config": default_outpost_config(),
|
||||||
"providers": [
|
"providers": [
|
||||||
ProxyProvider.objects.create(
|
ProxyProvider.objects.create(
|
||||||
name=generate_id(), authorization_flow=create_test_flow()
|
name="test", authorization_flow=create_test_flow()
|
||||||
).pk
|
).pk
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
@ -39,12 +38,12 @@ class TestOutpostServiceConnectionsAPI(APITestCase):
|
|||||||
self.assertTrue(valid.is_valid())
|
self.assertTrue(valid.is_valid())
|
||||||
invalid = OutpostSerializer(
|
invalid = OutpostSerializer(
|
||||||
data={
|
data={
|
||||||
"name": generate_id(),
|
"name": "foo",
|
||||||
"type": OutpostType.PROXY,
|
"type": OutpostType.PROXY,
|
||||||
"config": default_outpost_config(),
|
"config": default_outpost_config(),
|
||||||
"providers": [
|
"providers": [
|
||||||
LDAPProvider.objects.create(
|
LDAPProvider.objects.create(
|
||||||
name=generate_id(), authorization_flow=create_test_flow()
|
name="test", authorization_flow=create_test_flow()
|
||||||
).pk
|
).pk
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
@ -61,19 +60,15 @@ class TestOutpostServiceConnectionsAPI(APITestCase):
|
|||||||
|
|
||||||
def test_outpost_config(self):
|
def test_outpost_config(self):
|
||||||
"""Test Outpost's config field"""
|
"""Test Outpost's config field"""
|
||||||
provider = ProxyProvider.objects.create(
|
provider = ProxyProvider.objects.create(name="test", authorization_flow=create_test_flow())
|
||||||
name=generate_id(), authorization_flow=create_test_flow()
|
invalid = OutpostSerializer(data={"name": "foo", "providers": [provider.pk], "config": ""})
|
||||||
)
|
|
||||||
invalid = OutpostSerializer(
|
|
||||||
data={"name": generate_id(), "providers": [provider.pk], "config": ""}
|
|
||||||
)
|
|
||||||
self.assertFalse(invalid.is_valid())
|
self.assertFalse(invalid.is_valid())
|
||||||
self.assertIn("config", invalid.errors)
|
self.assertIn("config", invalid.errors)
|
||||||
valid = OutpostSerializer(
|
valid = OutpostSerializer(
|
||||||
data={
|
data={
|
||||||
"name": generate_id(),
|
"name": "foo",
|
||||||
"providers": [provider.pk],
|
"providers": [provider.pk],
|
||||||
"config": default_outpost_config(generate_id()),
|
"config": default_outpost_config("foo"),
|
||||||
"type": OutpostType.PROXY,
|
"type": OutpostType.PROXY,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -7,6 +7,11 @@ GAUGE_POLICIES_CACHED = Gauge(
|
|||||||
"authentik_policies_cached",
|
"authentik_policies_cached",
|
||||||
"Cached Policies",
|
"Cached Policies",
|
||||||
)
|
)
|
||||||
|
HIST_POLICIES_BUILD_TIME = Histogram(
|
||||||
|
"authentik_policies_build_time",
|
||||||
|
"Execution times complete policy result to an object",
|
||||||
|
["object_pk", "object_type"],
|
||||||
|
)
|
||||||
|
|
||||||
HIST_POLICIES_EXECUTION_TIME = Histogram(
|
HIST_POLICIES_EXECUTION_TIME = Histogram(
|
||||||
"authentik_policies_execution_time",
|
"authentik_policies_execution_time",
|
||||||
|
@ -10,6 +10,7 @@ from sentry_sdk.tracing import Span
|
|||||||
from structlog.stdlib import BoundLogger, get_logger
|
from structlog.stdlib import BoundLogger, get_logger
|
||||||
|
|
||||||
from authentik.core.models import User
|
from authentik.core.models import User
|
||||||
|
from authentik.policies.apps import HIST_POLICIES_BUILD_TIME
|
||||||
from authentik.policies.exceptions import PolicyEngineException
|
from authentik.policies.exceptions import PolicyEngineException
|
||||||
from authentik.policies.models import Policy, PolicyBinding, PolicyBindingModel, PolicyEngineMode
|
from authentik.policies.models import Policy, PolicyBinding, PolicyBindingModel, PolicyEngineMode
|
||||||
from authentik.policies.process import PolicyProcess, cache_key
|
from authentik.policies.process import PolicyProcess, cache_key
|
||||||
@ -85,6 +86,10 @@ class PolicyEngine:
|
|||||||
op="authentik.policy.engine.build",
|
op="authentik.policy.engine.build",
|
||||||
description=self.__pbm,
|
description=self.__pbm,
|
||||||
) as span,
|
) as span,
|
||||||
|
HIST_POLICIES_BUILD_TIME.labels(
|
||||||
|
object_pk=str(self.__pbm.pk),
|
||||||
|
object_type=f"{self.__pbm._meta.app_label}.{self.__pbm._meta.model_name}",
|
||||||
|
).time(),
|
||||||
):
|
):
|
||||||
span: Span
|
span: Span
|
||||||
span.set_data("pbm", self.__pbm)
|
span.set_data("pbm", self.__pbm)
|
||||||
|
@ -1,20 +0,0 @@
|
|||||||
# Generated by Django 4.1.7 on 2023-03-07 13:41
|
|
||||||
|
|
||||||
from django.db import migrations, models
|
|
||||||
|
|
||||||
from authentik.lib.migrations import fallback_names
|
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
|
||||||
dependencies = [
|
|
||||||
("authentik_policies", "0009_alter_policy_name"),
|
|
||||||
]
|
|
||||||
|
|
||||||
operations = [
|
|
||||||
migrations.RunPython(fallback_names("authentik_policies", "policy", "name")),
|
|
||||||
migrations.AlterField(
|
|
||||||
model_name="policy",
|
|
||||||
name="name",
|
|
||||||
field=models.TextField(unique=True),
|
|
||||||
),
|
|
||||||
]
|
|
@ -158,7 +158,7 @@ class Policy(SerializerModel, CreatedUpdatedModel):
|
|||||||
|
|
||||||
policy_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
policy_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
||||||
|
|
||||||
name = models.TextField(unique=True)
|
name = models.TextField()
|
||||||
|
|
||||||
execution_logging = models.BooleanField(
|
execution_logging = models.BooleanField(
|
||||||
default=False,
|
default=False,
|
||||||
|
@ -2,8 +2,7 @@
|
|||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
from authentik.core.tests.utils import create_test_admin_user
|
from authentik.core.models import User
|
||||||
from authentik.lib.generators import generate_id
|
|
||||||
from authentik.policies.dummy.models import DummyPolicy
|
from authentik.policies.dummy.models import DummyPolicy
|
||||||
from authentik.policies.engine import PolicyEngine
|
from authentik.policies.engine import PolicyEngine
|
||||||
from authentik.policies.exceptions import PolicyEngineException
|
from authentik.policies.exceptions import PolicyEngineException
|
||||||
@ -18,17 +17,11 @@ class TestPolicyEngine(TestCase):
|
|||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
clear_policy_cache()
|
clear_policy_cache()
|
||||||
self.user = create_test_admin_user()
|
self.user = User.objects.create_user(username="policyuser")
|
||||||
self.policy_false = DummyPolicy.objects.create(
|
self.policy_false = DummyPolicy.objects.create(result=False, wait_min=0, wait_max=1)
|
||||||
name=generate_id(), result=False, wait_min=0, wait_max=1
|
self.policy_true = DummyPolicy.objects.create(result=True, wait_min=0, wait_max=1)
|
||||||
)
|
self.policy_wrong_type = Policy.objects.create(name="wrong_type")
|
||||||
self.policy_true = DummyPolicy.objects.create(
|
self.policy_raises = ExpressionPolicy.objects.create(name="raises", expression="{{ 0/0 }}")
|
||||||
name=generate_id(), result=True, wait_min=0, wait_max=1
|
|
||||||
)
|
|
||||||
self.policy_wrong_type = Policy.objects.create(name=generate_id())
|
|
||||||
self.policy_raises = ExpressionPolicy.objects.create(
|
|
||||||
name=generate_id(), expression="{{ 0/0 }}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_engine_empty(self):
|
def test_engine_empty(self):
|
||||||
"""Ensure empty policy list passes"""
|
"""Ensure empty policy list passes"""
|
||||||
|
@ -26,7 +26,6 @@ class LDAPProviderSerializer(ProviderSerializer):
|
|||||||
"search_mode",
|
"search_mode",
|
||||||
"bind_mode",
|
"bind_mode",
|
||||||
]
|
]
|
||||||
extra_kwargs = ProviderSerializer.Meta.extra_kwargs
|
|
||||||
|
|
||||||
|
|
||||||
class LDAPProviderViewSet(UsedByMixin, ModelViewSet):
|
class LDAPProviderViewSet(UsedByMixin, ModelViewSet):
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
"""OAuth2Provider API Views"""
|
"""OAuth2Provider API Views"""
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
from django.utils import timezone
|
|
||||||
from drf_spectacular.utils import OpenApiResponse, extend_schema
|
from drf_spectacular.utils import OpenApiResponse, extend_schema
|
||||||
from rest_framework.decorators import action
|
from rest_framework.decorators import action
|
||||||
from rest_framework.fields import CharField
|
from rest_framework.fields import CharField
|
||||||
@ -39,7 +38,6 @@ class OAuth2ProviderSerializer(ProviderSerializer):
|
|||||||
"issuer_mode",
|
"issuer_mode",
|
||||||
"jwks_sources",
|
"jwks_sources",
|
||||||
]
|
]
|
||||||
extra_kwargs = ProviderSerializer.Meta.extra_kwargs
|
|
||||||
|
|
||||||
|
|
||||||
class OAuth2ProviderSetupURLs(PassiveSerializer):
|
class OAuth2ProviderSetupURLs(PassiveSerializer):
|
||||||
@ -155,7 +153,6 @@ class OAuth2ProviderViewSet(UsedByMixin, ModelViewSet):
|
|||||||
user=request.user,
|
user=request.user,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
_scope=" ".join(scope_names),
|
_scope=" ".join(scope_names),
|
||||||
auth_time=timezone.now(),
|
|
||||||
),
|
),
|
||||||
request,
|
request,
|
||||||
)
|
)
|
||||||
|
@ -141,20 +141,15 @@ class AuthorizeError(OAuth2Error):
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
# pylint: disable=too-many-arguments
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
redirect_uri: str,
|
redirect_uri: str,
|
||||||
error: str,
|
error: str,
|
||||||
grant_type: str,
|
grant_type: str,
|
||||||
state: str,
|
state: str,
|
||||||
description: Optional[str] = None,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.error = error
|
self.error = error
|
||||||
if description:
|
|
||||||
self.description = description
|
|
||||||
else:
|
|
||||||
self.description = self.errors[error]
|
self.description = self.errors[error]
|
||||||
self.redirect_uri = redirect_uri
|
self.redirect_uri = redirect_uri
|
||||||
self.grant_type = grant_type
|
self.grant_type = grant_type
|
||||||
@ -174,12 +169,10 @@ class AuthorizeError(OAuth2Error):
|
|||||||
|
|
||||||
# See:
|
# See:
|
||||||
# http://openid.net/specs/openid-connect-core-1_0.html#ImplicitAuthError
|
# http://openid.net/specs/openid-connect-core-1_0.html#ImplicitAuthError
|
||||||
fragment_or_query = (
|
hash_or_question = "#" if self.grant_type == GrantTypes.IMPLICIT else "?"
|
||||||
"#" if self.grant_type in [GrantTypes.IMPLICIT, GrantTypes.HYBRID] else "?"
|
|
||||||
)
|
|
||||||
|
|
||||||
uri = (
|
uri = (
|
||||||
f"{self.redirect_uri}{fragment_or_query}error="
|
f"{self.redirect_uri}{hash_or_question}error="
|
||||||
f"{self.error}&error_description={description}"
|
f"{self.error}&error_description={description}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -110,11 +110,12 @@ class IDToken:
|
|||||||
# Convert datetimes into timestamps.
|
# Convert datetimes into timestamps.
|
||||||
now = timezone.now()
|
now = timezone.now()
|
||||||
id_token.iat = int(now.timestamp())
|
id_token.iat = int(now.timestamp())
|
||||||
id_token.auth_time = int(token.auth_time.timestamp())
|
|
||||||
|
|
||||||
# We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time
|
# We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time
|
||||||
auth_event = get_login_event(request)
|
auth_event = get_login_event(request)
|
||||||
if auth_event:
|
if auth_event:
|
||||||
|
auth_time = auth_event.created
|
||||||
|
id_token.auth_time = int(auth_time.timestamp())
|
||||||
# Also check which method was used for authentication
|
# Also check which method was used for authentication
|
||||||
method = auth_event.context.get(PLAN_CONTEXT_METHOD, "")
|
method = auth_event.context.get(PLAN_CONTEXT_METHOD, "")
|
||||||
method_args = auth_event.context.get(PLAN_CONTEXT_METHOD_ARGS, {})
|
method_args = auth_event.context.get(PLAN_CONTEXT_METHOD_ARGS, {})
|
||||||
|
@ -1,40 +0,0 @@
|
|||||||
# Generated by Django 4.1.7 on 2023-02-22 22:23
|
|
||||||
|
|
||||||
import django.utils.timezone
|
|
||||||
from django.db import migrations, models
|
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
|
||||||
dependencies = [
|
|
||||||
("authentik_providers_oauth2", "0014_alter_refreshtoken_options_and_more"),
|
|
||||||
]
|
|
||||||
|
|
||||||
operations = [
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="accesstoken",
|
|
||||||
name="auth_time",
|
|
||||||
field=models.DateTimeField(
|
|
||||||
default=django.utils.timezone.now,
|
|
||||||
verbose_name="Authentication time",
|
|
||||||
),
|
|
||||||
preserve_default=False,
|
|
||||||
),
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="authorizationcode",
|
|
||||||
name="auth_time",
|
|
||||||
field=models.DateTimeField(
|
|
||||||
default=django.utils.timezone.now,
|
|
||||||
verbose_name="Authentication time",
|
|
||||||
),
|
|
||||||
preserve_default=False,
|
|
||||||
),
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="refreshtoken",
|
|
||||||
name="auth_time",
|
|
||||||
field=models.DateTimeField(
|
|
||||||
default=django.utils.timezone.now,
|
|
||||||
verbose_name="Authentication time",
|
|
||||||
),
|
|
||||||
preserve_default=False,
|
|
||||||
),
|
|
||||||
]
|
|
@ -226,7 +226,7 @@ class OAuth2Provider(Provider):
|
|||||||
def get_issuer(self, request: HttpRequest) -> Optional[str]:
|
def get_issuer(self, request: HttpRequest) -> Optional[str]:
|
||||||
"""Get issuer, based on request"""
|
"""Get issuer, based on request"""
|
||||||
if self.issuer_mode == IssuerMode.GLOBAL:
|
if self.issuer_mode == IssuerMode.GLOBAL:
|
||||||
return request.build_absolute_uri(reverse("authentik_core:root-redirect"))
|
return request.build_absolute_uri("/")
|
||||||
try:
|
try:
|
||||||
url = reverse(
|
url = reverse(
|
||||||
"authentik_providers_oauth2:provider-root",
|
"authentik_providers_oauth2:provider-root",
|
||||||
@ -282,7 +282,6 @@ class BaseGrantModel(models.Model):
|
|||||||
user = models.ForeignKey(User, verbose_name=_("User"), on_delete=models.CASCADE)
|
user = models.ForeignKey(User, verbose_name=_("User"), on_delete=models.CASCADE)
|
||||||
revoked = models.BooleanField(default=False)
|
revoked = models.BooleanField(default=False)
|
||||||
_scope = models.TextField(default="", verbose_name=_("Scopes"))
|
_scope = models.TextField(default="", verbose_name=_("Scopes"))
|
||||||
auth_time = models.DateTimeField(verbose_name="Authentication time")
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def scope(self) -> list[str]:
|
def scope(self) -> list[str]:
|
||||||
|
@ -204,7 +204,6 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
"redirect_uri": "http://local.invalid/Foo",
|
"redirect_uri": "http://local.invalid/Foo",
|
||||||
"scope": "openid",
|
"scope": "openid",
|
||||||
"state": "foo",
|
"state": "foo",
|
||||||
"nonce": generate_id(),
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@ -326,7 +325,6 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
"state": state,
|
"state": state,
|
||||||
"scope": "openid",
|
"scope": "openid",
|
||||||
"redirect_uri": "http://localhost",
|
"redirect_uri": "http://localhost",
|
||||||
"nonce": generate_id(),
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
response = self.client.get(
|
response = self.client.get(
|
||||||
@ -355,62 +353,6 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
delta=5,
|
delta=5,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_full_fragment_code(self):
|
|
||||||
"""Test full authorization"""
|
|
||||||
flow = create_test_flow()
|
|
||||||
provider: OAuth2Provider = OAuth2Provider.objects.create(
|
|
||||||
name=generate_id(),
|
|
||||||
client_id="test",
|
|
||||||
client_secret=generate_key(),
|
|
||||||
authorization_flow=flow,
|
|
||||||
redirect_uris="http://localhost",
|
|
||||||
signing_key=self.keypair,
|
|
||||||
)
|
|
||||||
Application.objects.create(name="app", slug="app", provider=provider)
|
|
||||||
state = generate_id()
|
|
||||||
user = create_test_admin_user()
|
|
||||||
self.client.force_login(user)
|
|
||||||
with patch(
|
|
||||||
"authentik.providers.oauth2.id_token.get_login_event",
|
|
||||||
MagicMock(
|
|
||||||
return_value=Event(
|
|
||||||
action=EventAction.LOGIN,
|
|
||||||
context={PLAN_CONTEXT_METHOD: "password"},
|
|
||||||
created=now(),
|
|
||||||
)
|
|
||||||
),
|
|
||||||
):
|
|
||||||
# Step 1, initiate params and get redirect to flow
|
|
||||||
self.client.get(
|
|
||||||
reverse("authentik_providers_oauth2:authorize"),
|
|
||||||
data={
|
|
||||||
"response_type": "code",
|
|
||||||
"response_mode": "fragment",
|
|
||||||
"client_id": "test",
|
|
||||||
"state": state,
|
|
||||||
"scope": "openid",
|
|
||||||
"redirect_uri": "http://localhost",
|
|
||||||
"nonce": generate_id(),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
response = self.client.get(
|
|
||||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
|
||||||
)
|
|
||||||
code: AuthorizationCode = AuthorizationCode.objects.filter(user=user).first()
|
|
||||||
self.assertJSONEqual(
|
|
||||||
response.content.decode(),
|
|
||||||
{
|
|
||||||
"component": "xak-flow-redirect",
|
|
||||||
"type": ChallengeTypes.REDIRECT.value,
|
|
||||||
"to": (f"http://localhost#code={code.code}" f"&state={state}"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
self.assertAlmostEqual(
|
|
||||||
code.expires.timestamp() - now().timestamp(),
|
|
||||||
timedelta_from_string(provider.access_code_validity).total_seconds(),
|
|
||||||
delta=5,
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_full_form_post_id_token(self):
|
def test_full_form_post_id_token(self):
|
||||||
"""Test full authorization (form_post response)"""
|
"""Test full authorization (form_post response)"""
|
||||||
flow = create_test_flow()
|
flow = create_test_flow()
|
||||||
@ -436,7 +378,6 @@ class TestAuthorize(OAuthTestCase):
|
|||||||
"state": state,
|
"state": state,
|
||||||
"scope": "openid",
|
"scope": "openid",
|
||||||
"redirect_uri": "http://localhost",
|
"redirect_uri": "http://localhost",
|
||||||
"nonce": generate_id(),
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
response = self.client.get(
|
response = self.client.get(
|
||||||
|
@ -4,7 +4,6 @@ from base64 import b64encode
|
|||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
|
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
from django.utils import timezone
|
|
||||||
|
|
||||||
from authentik.core.models import Application
|
from authentik.core.models import Application
|
||||||
from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
|
from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
|
||||||
@ -42,7 +41,6 @@ class TesOAuth2Introspection(OAuthTestCase):
|
|||||||
provider=self.provider,
|
provider=self.provider,
|
||||||
user=self.user,
|
user=self.user,
|
||||||
token=generate_id(),
|
token=generate_id(),
|
||||||
auth_time=timezone.now(),
|
|
||||||
_scope="openid user profile",
|
_scope="openid user profile",
|
||||||
_id_token=json.dumps(
|
_id_token=json.dumps(
|
||||||
asdict(
|
asdict(
|
||||||
@ -74,7 +72,6 @@ class TesOAuth2Introspection(OAuthTestCase):
|
|||||||
provider=self.provider,
|
provider=self.provider,
|
||||||
user=self.user,
|
user=self.user,
|
||||||
token=generate_id(),
|
token=generate_id(),
|
||||||
auth_time=timezone.now(),
|
|
||||||
_scope="openid user profile",
|
_scope="openid user profile",
|
||||||
_id_token=json.dumps(
|
_id_token=json.dumps(
|
||||||
asdict(
|
asdict(
|
||||||
|
@ -4,7 +4,6 @@ from base64 import b64encode
|
|||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
|
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
from django.utils import timezone
|
|
||||||
|
|
||||||
from authentik.core.models import Application
|
from authentik.core.models import Application
|
||||||
from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
|
from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
|
||||||
@ -41,7 +40,6 @@ class TesOAuth2Revoke(OAuthTestCase):
|
|||||||
provider=self.provider,
|
provider=self.provider,
|
||||||
user=self.user,
|
user=self.user,
|
||||||
token=generate_id(),
|
token=generate_id(),
|
||||||
auth_time=timezone.now(),
|
|
||||||
_scope="openid user profile",
|
_scope="openid user profile",
|
||||||
_id_token=json.dumps(
|
_id_token=json.dumps(
|
||||||
asdict(
|
asdict(
|
||||||
@ -64,7 +62,6 @@ class TesOAuth2Revoke(OAuthTestCase):
|
|||||||
provider=self.provider,
|
provider=self.provider,
|
||||||
user=self.user,
|
user=self.user,
|
||||||
token=generate_id(),
|
token=generate_id(),
|
||||||
auth_time=timezone.now(),
|
|
||||||
_scope="openid user profile",
|
_scope="openid user profile",
|
||||||
_id_token=json.dumps(
|
_id_token=json.dumps(
|
||||||
asdict(
|
asdict(
|
||||||
|
@ -4,7 +4,6 @@ from json import dumps
|
|||||||
|
|
||||||
from django.test import RequestFactory
|
from django.test import RequestFactory
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
from django.utils import timezone
|
|
||||||
|
|
||||||
from authentik.core.models import Application
|
from authentik.core.models import Application
|
||||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||||
@ -46,9 +45,7 @@ class TestToken(OAuthTestCase):
|
|||||||
)
|
)
|
||||||
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
|
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
|
||||||
user = create_test_admin_user()
|
user = create_test_admin_user()
|
||||||
code = AuthorizationCode.objects.create(
|
code = AuthorizationCode.objects.create(code="foobar", provider=provider, user=user)
|
||||||
code="foobar", provider=provider, user=user, auth_time=timezone.now()
|
|
||||||
)
|
|
||||||
request = self.factory.post(
|
request = self.factory.post(
|
||||||
"/",
|
"/",
|
||||||
data={
|
data={
|
||||||
@ -102,7 +99,6 @@ class TestToken(OAuthTestCase):
|
|||||||
provider=provider,
|
provider=provider,
|
||||||
user=user,
|
user=user,
|
||||||
token=generate_id(),
|
token=generate_id(),
|
||||||
auth_time=timezone.now(),
|
|
||||||
)
|
)
|
||||||
request = self.factory.post(
|
request = self.factory.post(
|
||||||
"/",
|
"/",
|
||||||
@ -131,9 +127,7 @@ class TestToken(OAuthTestCase):
|
|||||||
self.app.save()
|
self.app.save()
|
||||||
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
|
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
|
||||||
user = create_test_admin_user()
|
user = create_test_admin_user()
|
||||||
code = AuthorizationCode.objects.create(
|
code = AuthorizationCode.objects.create(code="foobar", provider=provider, user=user)
|
||||||
code="foobar", provider=provider, user=user, auth_time=timezone.now()
|
|
||||||
)
|
|
||||||
response = self.client.post(
|
response = self.client.post(
|
||||||
reverse("authentik_providers_oauth2:token"),
|
reverse("authentik_providers_oauth2:token"),
|
||||||
data={
|
data={
|
||||||
@ -179,7 +173,6 @@ class TestToken(OAuthTestCase):
|
|||||||
user=user,
|
user=user,
|
||||||
token=generate_id(),
|
token=generate_id(),
|
||||||
_id_token=dumps({}),
|
_id_token=dumps({}),
|
||||||
auth_time=timezone.now(),
|
|
||||||
)
|
)
|
||||||
response = self.client.post(
|
response = self.client.post(
|
||||||
reverse("authentik_providers_oauth2:token"),
|
reverse("authentik_providers_oauth2:token"),
|
||||||
@ -228,7 +221,6 @@ class TestToken(OAuthTestCase):
|
|||||||
user=user,
|
user=user,
|
||||||
token=generate_id(),
|
token=generate_id(),
|
||||||
_id_token=dumps({}),
|
_id_token=dumps({}),
|
||||||
auth_time=timezone.now(),
|
|
||||||
)
|
)
|
||||||
response = self.client.post(
|
response = self.client.post(
|
||||||
reverse("authentik_providers_oauth2:token"),
|
reverse("authentik_providers_oauth2:token"),
|
||||||
@ -279,7 +271,6 @@ class TestToken(OAuthTestCase):
|
|||||||
user=user,
|
user=user,
|
||||||
token=generate_id(),
|
token=generate_id(),
|
||||||
_id_token=dumps({}),
|
_id_token=dumps({}),
|
||||||
auth_time=timezone.now(),
|
|
||||||
)
|
)
|
||||||
# Create initial refresh token
|
# Create initial refresh token
|
||||||
response = self.client.post(
|
response = self.client.post(
|
||||||
|
@ -3,7 +3,6 @@ import json
|
|||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
|
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
from django.utils import timezone
|
|
||||||
|
|
||||||
from authentik.blueprints.tests import apply_blueprint
|
from authentik.blueprints.tests import apply_blueprint
|
||||||
from authentik.core.models import Application
|
from authentik.core.models import Application
|
||||||
@ -38,7 +37,6 @@ class TestUserinfo(OAuthTestCase):
|
|||||||
provider=self.provider,
|
provider=self.provider,
|
||||||
user=self.user,
|
user=self.user,
|
||||||
token=generate_id(),
|
token=generate_id(),
|
||||||
auth_time=timezone.now(),
|
|
||||||
_scope="openid user profile",
|
_scope="openid user profile",
|
||||||
_id_token=json.dumps(
|
_id_token=json.dumps(
|
||||||
asdict(
|
asdict(
|
||||||
@ -58,6 +56,7 @@ class TestUserinfo(OAuthTestCase):
|
|||||||
{
|
{
|
||||||
"name": self.user.name,
|
"name": self.user.name,
|
||||||
"given_name": self.user.name,
|
"given_name": self.user.name,
|
||||||
|
"family_name": "",
|
||||||
"preferred_username": self.user.name,
|
"preferred_username": self.user.name,
|
||||||
"nickname": self.user.name,
|
"nickname": self.user.name,
|
||||||
"groups": [group.name for group in self.user.ak_groups.all()],
|
"groups": [group.name for group in self.user.ak_groups.all()],
|
||||||
@ -80,6 +79,7 @@ class TestUserinfo(OAuthTestCase):
|
|||||||
{
|
{
|
||||||
"name": self.user.name,
|
"name": self.user.name,
|
||||||
"given_name": self.user.name,
|
"given_name": self.user.name,
|
||||||
|
"family_name": "",
|
||||||
"preferred_username": self.user.name,
|
"preferred_username": self.user.name,
|
||||||
"nickname": self.user.name,
|
"nickname": self.user.name,
|
||||||
"groups": [group.name for group in self.user.ak_groups.all()],
|
"groups": [group.name for group in self.user.ak_groups.all()],
|
||||||
|
@ -42,7 +42,7 @@ urlpatterns = [
|
|||||||
path("<slug:application_slug>/jwks/", JWKSView.as_view(), name="jwks"),
|
path("<slug:application_slug>/jwks/", JWKSView.as_view(), name="jwks"),
|
||||||
path(
|
path(
|
||||||
"<slug:application_slug>/",
|
"<slug:application_slug>/",
|
||||||
RedirectView.as_view(pattern_name="authentik_providers_oauth2:provider-info"),
|
RedirectView.as_view(pattern_name="authentk_providers_oauth2:provider-info"),
|
||||||
name="provider-root",
|
name="provider-root",
|
||||||
),
|
),
|
||||||
path(
|
path(
|
||||||
|
@ -146,10 +146,9 @@ def protected_resource_view(scopes: list[str]):
|
|||||||
LOGGER.warning("Revoked token was used", access_token=access_token)
|
LOGGER.warning("Revoked token was used", access_token=access_token)
|
||||||
Event.new(
|
Event.new(
|
||||||
action=EventAction.SUSPICIOUS_REQUEST,
|
action=EventAction.SUSPICIOUS_REQUEST,
|
||||||
message="Revoked access token was used",
|
message="Revoked refresh token was used",
|
||||||
token=token,
|
token=access_token,
|
||||||
provider=token.provider,
|
).from_http(request)
|
||||||
).from_http(request, user=token.user)
|
|
||||||
raise BearerTokenError("invalid_token")
|
raise BearerTokenError("invalid_token")
|
||||||
|
|
||||||
if not set(scopes).issubset(set(token.scope)):
|
if not set(scopes).issubset(set(token.scope)):
|
||||||
|
@ -17,14 +17,13 @@ from structlog.stdlib import get_logger
|
|||||||
|
|
||||||
from authentik.core.models import Application
|
from authentik.core.models import Application
|
||||||
from authentik.events.models import Event, EventAction
|
from authentik.events.models import Event, EventAction
|
||||||
from authentik.events.signals import get_login_event
|
from authentik.events.utils import get_user
|
||||||
from authentik.flows.challenge import (
|
from authentik.flows.challenge import (
|
||||||
PLAN_CONTEXT_TITLE,
|
PLAN_CONTEXT_TITLE,
|
||||||
AutosubmitChallenge,
|
AutosubmitChallenge,
|
||||||
ChallengeTypes,
|
ChallengeTypes,
|
||||||
HttpChallengeResponse,
|
HttpChallengeResponse,
|
||||||
)
|
)
|
||||||
from authentik.flows.exceptions import FlowNonApplicableException
|
|
||||||
from authentik.flows.models import in_memory_stage
|
from authentik.flows.models import in_memory_stage
|
||||||
from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_SSO, FlowPlanner
|
from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_SSO, FlowPlanner
|
||||||
from authentik.flows.stage import StageView
|
from authentik.flows.stage import StageView
|
||||||
@ -65,11 +64,12 @@ from authentik.stages.consent.stage import (
|
|||||||
PLAN_CONTEXT_CONSENT_PERMISSIONS,
|
PLAN_CONTEXT_CONSENT_PERMISSIONS,
|
||||||
ConsentStageView,
|
ConsentStageView,
|
||||||
)
|
)
|
||||||
|
from authentik.stages.user_login.stage import USER_LOGIN_AUTHENTICATED
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
|
||||||
PLAN_CONTEXT_PARAMS = "params"
|
PLAN_CONTEXT_PARAMS = "params"
|
||||||
SESSION_KEY_LAST_LOGIN_UID = "authentik/providers/oauth2/last_login_uid"
|
SESSION_KEY_NEEDS_LOGIN = "authentik/providers/oauth2/needs_login"
|
||||||
|
|
||||||
ALLOWED_PROMPT_PARAMS = {PROMPT_NONE, PROMPT_CONSENT, PROMPT_LOGIN}
|
ALLOWED_PROMPT_PARAMS = {PROMPT_NONE, PROMPT_CONSENT, PROMPT_LOGIN}
|
||||||
|
|
||||||
@ -158,14 +158,13 @@ class OAuthAuthorizationParams:
|
|||||||
request=query_dict.get("request", None),
|
request=query_dict.get("request", None),
|
||||||
max_age=int(max_age) if max_age else None,
|
max_age=int(max_age) if max_age else None,
|
||||||
code_challenge=query_dict.get("code_challenge"),
|
code_challenge=query_dict.get("code_challenge"),
|
||||||
code_challenge_method=query_dict.get("code_challenge_method", "plain"),
|
code_challenge_method=query_dict.get("code_challenge_method"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.provider: OAuth2Provider = OAuth2Provider.objects.filter(
|
try:
|
||||||
client_id=self.client_id
|
self.provider: OAuth2Provider = OAuth2Provider.objects.get(client_id=self.client_id)
|
||||||
).first()
|
except OAuth2Provider.DoesNotExist:
|
||||||
if not self.provider:
|
|
||||||
LOGGER.warning("Invalid client identifier", client_id=self.client_id)
|
LOGGER.warning("Invalid client identifier", client_id=self.client_id)
|
||||||
raise ClientIdError(client_id=self.client_id)
|
raise ClientIdError(client_id=self.client_id)
|
||||||
self.check_redirect_uri()
|
self.check_redirect_uri()
|
||||||
@ -235,54 +234,40 @@ class OAuthAuthorizationParams:
|
|||||||
|
|
||||||
def check_nonce(self):
|
def check_nonce(self):
|
||||||
"""Nonce parameter validation."""
|
"""Nonce parameter validation."""
|
||||||
# nonce is required for all flows that return an id_token from the authorization endpoint,
|
# https://openid.net/specs/openid-connect-core-1_0.html#ImplicitIDTValidation
|
||||||
# see https://openid.net/specs/openid-connect-core-1_0.html#ImplicitAuthRequest or
|
# Nonce is only required for Implicit flows
|
||||||
# https://openid.net/specs/openid-connect-core-1_0.html#HybridIDToken and
|
if self.grant_type != GrantTypes.IMPLICIT:
|
||||||
# https://bitbucket.org/openid/connect/issues/972/nonce-requirement-in-hybrid-auth-request
|
|
||||||
if self.response_type not in [
|
|
||||||
ResponseTypes.ID_TOKEN,
|
|
||||||
ResponseTypes.ID_TOKEN_TOKEN,
|
|
||||||
ResponseTypes.CODE_ID_TOKEN,
|
|
||||||
ResponseTypes.CODE_ID_TOKEN_TOKEN,
|
|
||||||
]:
|
|
||||||
return
|
|
||||||
if SCOPE_OPENID not in self.scope:
|
|
||||||
return
|
return
|
||||||
if not self.nonce:
|
if not self.nonce:
|
||||||
|
self.nonce = self.state
|
||||||
|
LOGGER.warning("Using state as nonce for OpenID Request")
|
||||||
|
if not self.nonce:
|
||||||
|
if SCOPE_OPENID in self.scope:
|
||||||
LOGGER.warning("Missing nonce for OpenID Request")
|
LOGGER.warning("Missing nonce for OpenID Request")
|
||||||
raise AuthorizeError(self.redirect_uri, "invalid_request", self.grant_type, self.state)
|
raise AuthorizeError(
|
||||||
|
self.redirect_uri, "invalid_request", self.grant_type, self.state
|
||||||
|
)
|
||||||
|
|
||||||
def check_code_challenge(self):
|
def check_code_challenge(self):
|
||||||
"""PKCE validation of the transformation method."""
|
"""PKCE validation of the transformation method."""
|
||||||
if self.code_challenge and self.code_challenge_method not in ["plain", "S256"]:
|
if self.code_challenge and self.code_challenge_method not in ["plain", "S256"]:
|
||||||
raise AuthorizeError(
|
raise AuthorizeError(self.redirect_uri, "invalid_request", self.grant_type, self.state)
|
||||||
self.redirect_uri,
|
|
||||||
"invalid_request",
|
|
||||||
self.grant_type,
|
|
||||||
self.state,
|
|
||||||
f"Unsupported challenge method {self.code_challenge_method}",
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_code(self, request: HttpRequest) -> AuthorizationCode:
|
def create_code(self, request: HttpRequest) -> AuthorizationCode:
|
||||||
"""Create an AuthorizationCode object for the request"""
|
"""Create an AuthorizationCode object for the request"""
|
||||||
auth_event = get_login_event(request)
|
code = AuthorizationCode()
|
||||||
|
code.user = request.user
|
||||||
|
code.provider = self.provider
|
||||||
|
|
||||||
now = timezone.now()
|
code.code = uuid4().hex
|
||||||
|
|
||||||
code = AuthorizationCode(
|
|
||||||
user=request.user,
|
|
||||||
provider=self.provider,
|
|
||||||
auth_time=auth_event.created if auth_event else now,
|
|
||||||
code=uuid4().hex,
|
|
||||||
expires=now + timedelta_from_string(self.provider.access_code_validity),
|
|
||||||
scope=self.scope,
|
|
||||||
nonce=self.nonce,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.code_challenge and self.code_challenge_method:
|
if self.code_challenge and self.code_challenge_method:
|
||||||
code.code_challenge = self.code_challenge
|
code.code_challenge = self.code_challenge
|
||||||
code.code_challenge_method = self.code_challenge_method
|
code.code_challenge_method = self.code_challenge_method
|
||||||
|
|
||||||
|
code.expires = timezone.now() + timedelta_from_string(self.provider.access_code_validity)
|
||||||
|
code.scope = self.scope
|
||||||
|
code.nonce = self.nonce
|
||||||
return code
|
return code
|
||||||
|
|
||||||
|
|
||||||
@ -317,6 +302,7 @@ class AuthorizationFlowInitView(PolicyAccessView):
|
|||||||
self.params.grant_type,
|
self.params.grant_type,
|
||||||
self.params.state,
|
self.params.state,
|
||||||
)
|
)
|
||||||
|
error.to_event(redirect_uri=error.redirect_uri).from_http(self.request)
|
||||||
raise RequestValidationError(error.get_response(self.request))
|
raise RequestValidationError(error.get_response(self.request))
|
||||||
|
|
||||||
def resolve_provider_application(self):
|
def resolve_provider_application(self):
|
||||||
@ -336,45 +322,32 @@ class AuthorizationFlowInitView(PolicyAccessView):
|
|||||||
|
|
||||||
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
||||||
"""Start FlowPLanner, return to flow executor shell"""
|
"""Start FlowPLanner, return to flow executor shell"""
|
||||||
# Require a login event to be set, otherwise make the user re-login
|
|
||||||
login_event = get_login_event(request)
|
|
||||||
if not login_event:
|
|
||||||
LOGGER.warning("request with no login event")
|
|
||||||
return self.handle_no_permission()
|
|
||||||
login_uid = str(login_event.pk)
|
|
||||||
# After we've checked permissions, and the user has access, check if we need
|
# After we've checked permissions, and the user has access, check if we need
|
||||||
# to re-authenticate the user
|
# to re-authenticate the user
|
||||||
if self.params.max_age:
|
if self.params.max_age:
|
||||||
# Attempt to check via the session's login event if set, otherwise we can't
|
current_age: timedelta = (
|
||||||
# check
|
timezone.now()
|
||||||
login_time = login_event.created
|
- Event.objects.filter(action=EventAction.LOGIN, user=get_user(self.request.user))
|
||||||
current_age: timedelta = timezone.now() - login_time
|
.latest("created")
|
||||||
if current_age.total_seconds() > self.params.max_age:
|
.created
|
||||||
LOGGER.debug(
|
|
||||||
"Triggering authentication as max_age requirement",
|
|
||||||
max_age=self.params.max_age,
|
|
||||||
ago=int(current_age.total_seconds()),
|
|
||||||
)
|
)
|
||||||
# Since we already need to re-authenticate the user, set the old login UID
|
if current_age.total_seconds() > self.params.max_age:
|
||||||
# in case this request has both max_age and prompt=login
|
|
||||||
self.request.session[SESSION_KEY_LAST_LOGIN_UID] = login_uid
|
|
||||||
return self.handle_no_permission()
|
return self.handle_no_permission()
|
||||||
# If prompt=login, we need to re-authenticate the user regardless
|
# If prompt=login, we need to re-authenticate the user regardless
|
||||||
# Check if we're not already doing the re-authentication
|
|
||||||
if PROMPT_LOGIN in self.params.prompt:
|
|
||||||
# No previous login UID saved, so save the current uid and trigger
|
|
||||||
# re-login, or previous login UID matches current one, so no re-login happened yet
|
|
||||||
if (
|
if (
|
||||||
SESSION_KEY_LAST_LOGIN_UID not in self.request.session
|
PROMPT_LOGIN in self.params.prompt
|
||||||
or login_uid == self.request.session[SESSION_KEY_LAST_LOGIN_UID]
|
and SESSION_KEY_NEEDS_LOGIN not in self.request.session
|
||||||
|
# To prevent the user from having to double login when prompt is set to login
|
||||||
|
# and the user has just signed it. This session variable is set in the UserLoginStage
|
||||||
|
# and is (quite hackily) removed from the session in applications's API's List method
|
||||||
|
and USER_LOGIN_AUTHENTICATED not in self.request.session
|
||||||
):
|
):
|
||||||
self.request.session[SESSION_KEY_LAST_LOGIN_UID] = login_uid
|
self.request.session[SESSION_KEY_NEEDS_LOGIN] = True
|
||||||
return self.handle_no_permission()
|
return self.handle_no_permission()
|
||||||
scope_descriptions = UserInfoView().get_scope_descriptions(self.params.scope)
|
scope_descriptions = UserInfoView().get_scope_descriptions(self.params.scope)
|
||||||
# Regardless, we start the planner and return to it
|
# Regardless, we start the planner and return to it
|
||||||
planner = FlowPlanner(self.provider.authorization_flow)
|
planner = FlowPlanner(self.provider.authorization_flow)
|
||||||
planner.allow_empty_flows = True
|
planner.allow_empty_flows = True
|
||||||
try:
|
|
||||||
plan = planner.plan(
|
plan = planner.plan(
|
||||||
self.request,
|
self.request,
|
||||||
{
|
{
|
||||||
@ -388,8 +361,6 @@ class AuthorizationFlowInitView(PolicyAccessView):
|
|||||||
PLAN_CONTEXT_CONSENT_PERMISSIONS: scope_descriptions,
|
PLAN_CONTEXT_CONSENT_PERMISSIONS: scope_descriptions,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
except FlowNonApplicableException:
|
|
||||||
return self.handle_no_permission_authenticated()
|
|
||||||
# OpenID clients can specify a `prompt` parameter, and if its set to consent we
|
# OpenID clients can specify a `prompt` parameter, and if its set to consent we
|
||||||
# need to inject a consent stage
|
# need to inject a consent stage
|
||||||
if PROMPT_CONSENT in self.params.prompt:
|
if PROMPT_CONSENT in self.params.prompt:
|
||||||
@ -514,11 +485,6 @@ class OAuthFulfillmentStage(StageView):
|
|||||||
return urlunsplit(uri)
|
return urlunsplit(uri)
|
||||||
|
|
||||||
if self.params.response_mode == ResponseMode.FRAGMENT:
|
if self.params.response_mode == ResponseMode.FRAGMENT:
|
||||||
query_fragment = {}
|
|
||||||
if self.params.grant_type in [GrantTypes.AUTHORIZATION_CODE]:
|
|
||||||
query_fragment["code"] = code.code
|
|
||||||
query_fragment["state"] = [str(self.params.state) if self.params.state else ""]
|
|
||||||
else:
|
|
||||||
query_fragment = self.create_implicit_response(code)
|
query_fragment = self.create_implicit_response(code)
|
||||||
|
|
||||||
uri = uri._replace(
|
uri = uri._replace(
|
||||||
@ -552,7 +518,6 @@ class OAuthFulfillmentStage(StageView):
|
|||||||
def create_implicit_response(self, code: Optional[AuthorizationCode]) -> dict:
|
def create_implicit_response(self, code: Optional[AuthorizationCode]) -> dict:
|
||||||
"""Create implicit response's URL Fragment dictionary"""
|
"""Create implicit response's URL Fragment dictionary"""
|
||||||
query_fragment = {}
|
query_fragment = {}
|
||||||
auth_event = get_login_event(self.request)
|
|
||||||
|
|
||||||
now = timezone.now()
|
now = timezone.now()
|
||||||
access_token_expiry = now + timedelta_from_string(self.provider.access_token_validity)
|
access_token_expiry = now + timedelta_from_string(self.provider.access_token_validity)
|
||||||
@ -561,7 +526,6 @@ class OAuthFulfillmentStage(StageView):
|
|||||||
scope=self.params.scope,
|
scope=self.params.scope,
|
||||||
expires=access_token_expiry,
|
expires=access_token_expiry,
|
||||||
provider=self.provider,
|
provider=self.provider,
|
||||||
auth_time=auth_event.created if auth_event else now,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
id_token = IDToken.new(self.provider, token, self.request)
|
id_token = IDToken.new(self.provider, token, self.request)
|
||||||
@ -582,8 +546,6 @@ class OAuthFulfillmentStage(StageView):
|
|||||||
ResponseTypes.CODE_TOKEN,
|
ResponseTypes.CODE_TOKEN,
|
||||||
]:
|
]:
|
||||||
query_fragment["access_token"] = token.token
|
query_fragment["access_token"] = token.token
|
||||||
# Get at_hash of the current token and update the id_token
|
|
||||||
id_token.at_hash = token.at_hash
|
|
||||||
|
|
||||||
# Check if response_type must include id_token in the response.
|
# Check if response_type must include id_token in the response.
|
||||||
if self.params.response_type in [
|
if self.params.response_type in [
|
||||||
@ -592,6 +554,8 @@ class OAuthFulfillmentStage(StageView):
|
|||||||
ResponseTypes.CODE_ID_TOKEN,
|
ResponseTypes.CODE_ID_TOKEN,
|
||||||
ResponseTypes.CODE_ID_TOKEN_TOKEN,
|
ResponseTypes.CODE_ID_TOKEN_TOKEN,
|
||||||
]:
|
]:
|
||||||
|
# Get at_hash of the current token and update the id_token
|
||||||
|
id_token.at_hash = token.at_hash
|
||||||
query_fragment["id_token"] = self.provider.encode(id_token.to_dict())
|
query_fragment["id_token"] = self.provider.encode(id_token.to_dict())
|
||||||
token._id_token = dumps(id_token.to_dict())
|
token._id_token = dumps(id_token.to_dict())
|
||||||
|
|
||||||
|
@ -10,7 +10,6 @@ from structlog.stdlib import get_logger
|
|||||||
|
|
||||||
from authentik.core.models import Application
|
from authentik.core.models import Application
|
||||||
from authentik.flows.challenge import Challenge, ChallengeResponse, ChallengeTypes
|
from authentik.flows.challenge import Challenge, ChallengeResponse, ChallengeTypes
|
||||||
from authentik.flows.exceptions import FlowNonApplicableException
|
|
||||||
from authentik.flows.models import in_memory_stage
|
from authentik.flows.models import in_memory_stage
|
||||||
from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_SSO, FlowPlanner
|
from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_SSO, FlowPlanner
|
||||||
from authentik.flows.stage import ChallengeStageView
|
from authentik.flows.stage import ChallengeStageView
|
||||||
@ -58,7 +57,6 @@ def validate_code(code: int, request: HttpRequest) -> Optional[HttpResponse]:
|
|||||||
scope_descriptions = UserInfoView().get_scope_descriptions(token.scope)
|
scope_descriptions = UserInfoView().get_scope_descriptions(token.scope)
|
||||||
planner = FlowPlanner(token.provider.authorization_flow)
|
planner = FlowPlanner(token.provider.authorization_flow)
|
||||||
planner.allow_empty_flows = True
|
planner.allow_empty_flows = True
|
||||||
try:
|
|
||||||
plan = planner.plan(
|
plan = planner.plan(
|
||||||
request,
|
request,
|
||||||
{
|
{
|
||||||
@ -72,9 +70,6 @@ def validate_code(code: int, request: HttpRequest) -> Optional[HttpResponse]:
|
|||||||
PLAN_CONTEXT_CONSENT_PERMISSIONS: scope_descriptions,
|
PLAN_CONTEXT_CONSENT_PERMISSIONS: scope_descriptions,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
except FlowNonApplicableException:
|
|
||||||
LOGGER.warning("Flow not applicable to user")
|
|
||||||
return None
|
|
||||||
plan.insert_stage(in_memory_stage(OAuthDeviceCodeFinishStage))
|
plan.insert_stage(in_memory_stage(OAuthDeviceCodeFinishStage))
|
||||||
request.session[SESSION_KEY_PLAN] = plan
|
request.session[SESSION_KEY_PLAN] = plan
|
||||||
return redirect_with_qs(
|
return redirect_with_qs(
|
||||||
@ -102,11 +97,7 @@ class DeviceEntryView(View):
|
|||||||
# Regardless, we start the planner and return to it
|
# Regardless, we start the planner and return to it
|
||||||
planner = FlowPlanner(device_flow)
|
planner = FlowPlanner(device_flow)
|
||||||
planner.allow_empty_flows = True
|
planner.allow_empty_flows = True
|
||||||
try:
|
|
||||||
plan = planner.plan(self.request)
|
plan = planner.plan(self.request)
|
||||||
except FlowNonApplicableException:
|
|
||||||
LOGGER.warning("Flow not applicable to user")
|
|
||||||
return HttpResponse(status=404)
|
|
||||||
plan.append_stage(in_memory_stage(OAuthDeviceCodeStage))
|
plan.append_stage(in_memory_stage(OAuthDeviceCodeStage))
|
||||||
|
|
||||||
self.request.session[SESSION_KEY_PLAN] = plan
|
self.request.session[SESSION_KEY_PLAN] = plan
|
||||||
|
@ -26,7 +26,6 @@ from authentik.core.models import (
|
|||||||
User,
|
User,
|
||||||
)
|
)
|
||||||
from authentik.events.models import Event, EventAction
|
from authentik.events.models import Event, EventAction
|
||||||
from authentik.events.signals import get_login_event
|
|
||||||
from authentik.flows.planner import PLAN_CONTEXT_APPLICATION
|
from authentik.flows.planner import PLAN_CONTEXT_APPLICATION
|
||||||
from authentik.lib.utils.time import timedelta_from_string
|
from authentik.lib.utils.time import timedelta_from_string
|
||||||
from authentik.policies.engine import PolicyEngine
|
from authentik.policies.engine import PolicyEngine
|
||||||
@ -263,9 +262,8 @@ class TokenParams:
|
|||||||
Event.new(
|
Event.new(
|
||||||
action=EventAction.SUSPICIOUS_REQUEST,
|
action=EventAction.SUSPICIOUS_REQUEST,
|
||||||
message="Revoked refresh token was used",
|
message="Revoked refresh token was used",
|
||||||
token=self.refresh_token,
|
token=raw_token,
|
||||||
provider=self.refresh_token.provider,
|
).from_http(request)
|
||||||
).from_http(request, user=self.refresh_token.user)
|
|
||||||
raise TokenError("invalid_grant")
|
raise TokenError("invalid_grant")
|
||||||
|
|
||||||
def __post_init_client_credentials(self, request: HttpRequest):
|
def __post_init_client_credentials(self, request: HttpRequest):
|
||||||
@ -480,7 +478,6 @@ class TokenView(View):
|
|||||||
expires=access_token_expiry,
|
expires=access_token_expiry,
|
||||||
# Keep same scopes as previous token
|
# Keep same scopes as previous token
|
||||||
scope=self.params.authorization_code.scope,
|
scope=self.params.authorization_code.scope,
|
||||||
auth_time=self.params.authorization_code.auth_time,
|
|
||||||
)
|
)
|
||||||
access_token.id_token = IDToken.new(
|
access_token.id_token = IDToken.new(
|
||||||
self.provider,
|
self.provider,
|
||||||
@ -495,7 +492,6 @@ class TokenView(View):
|
|||||||
scope=self.params.authorization_code.scope,
|
scope=self.params.authorization_code.scope,
|
||||||
expires=refresh_token_expiry,
|
expires=refresh_token_expiry,
|
||||||
provider=self.provider,
|
provider=self.provider,
|
||||||
auth_time=self.params.authorization_code.auth_time,
|
|
||||||
)
|
)
|
||||||
id_token = IDToken.new(
|
id_token = IDToken.new(
|
||||||
self.provider,
|
self.provider,
|
||||||
@ -524,6 +520,7 @@ class TokenView(View):
|
|||||||
unauthorized_scopes = set(self.params.scope) - set(self.params.refresh_token.scope)
|
unauthorized_scopes = set(self.params.scope) - set(self.params.refresh_token.scope)
|
||||||
if unauthorized_scopes:
|
if unauthorized_scopes:
|
||||||
raise TokenError("invalid_scope")
|
raise TokenError("invalid_scope")
|
||||||
|
|
||||||
now = timezone.now()
|
now = timezone.now()
|
||||||
access_token_expiry = now + timedelta_from_string(self.provider.access_token_validity)
|
access_token_expiry = now + timedelta_from_string(self.provider.access_token_validity)
|
||||||
access_token = AccessToken(
|
access_token = AccessToken(
|
||||||
@ -532,7 +529,6 @@ class TokenView(View):
|
|||||||
expires=access_token_expiry,
|
expires=access_token_expiry,
|
||||||
# Keep same scopes as previous token
|
# Keep same scopes as previous token
|
||||||
scope=self.params.refresh_token.scope,
|
scope=self.params.refresh_token.scope,
|
||||||
auth_time=self.params.refresh_token.auth_time,
|
|
||||||
)
|
)
|
||||||
access_token.id_token = IDToken.new(
|
access_token.id_token = IDToken.new(
|
||||||
self.provider,
|
self.provider,
|
||||||
@ -547,7 +543,6 @@ class TokenView(View):
|
|||||||
scope=self.params.refresh_token.scope,
|
scope=self.params.refresh_token.scope,
|
||||||
expires=refresh_token_expiry,
|
expires=refresh_token_expiry,
|
||||||
provider=self.provider,
|
provider=self.provider,
|
||||||
auth_time=self.params.refresh_token.auth_time,
|
|
||||||
)
|
)
|
||||||
id_token = IDToken.new(
|
id_token = IDToken.new(
|
||||||
self.provider,
|
self.provider,
|
||||||
@ -582,7 +577,6 @@ class TokenView(View):
|
|||||||
user=self.params.user,
|
user=self.params.user,
|
||||||
expires=access_token_expiry,
|
expires=access_token_expiry,
|
||||||
scope=self.params.scope,
|
scope=self.params.scope,
|
||||||
auth_time=now,
|
|
||||||
)
|
)
|
||||||
access_token.id_token = IDToken.new(
|
access_token.id_token = IDToken.new(
|
||||||
self.provider,
|
self.provider,
|
||||||
@ -605,13 +599,11 @@ class TokenView(View):
|
|||||||
raise DeviceCodeError("authorization_pending")
|
raise DeviceCodeError("authorization_pending")
|
||||||
now = timezone.now()
|
now = timezone.now()
|
||||||
access_token_expiry = now + timedelta_from_string(self.provider.access_token_validity)
|
access_token_expiry = now + timedelta_from_string(self.provider.access_token_validity)
|
||||||
auth_event = get_login_event(self.request)
|
|
||||||
access_token = AccessToken(
|
access_token = AccessToken(
|
||||||
provider=self.provider,
|
provider=self.provider,
|
||||||
user=self.params.device_code.user,
|
user=self.params.device_code.user,
|
||||||
expires=access_token_expiry,
|
expires=access_token_expiry,
|
||||||
scope=self.params.device_code.scope,
|
scope=self.params.device_code.scope,
|
||||||
auth_time=auth_event.created if auth_event else now,
|
|
||||||
)
|
)
|
||||||
access_token.id_token = IDToken.new(
|
access_token.id_token = IDToken.new(
|
||||||
self.provider,
|
self.provider,
|
||||||
@ -626,7 +618,6 @@ class TokenView(View):
|
|||||||
scope=self.params.device_code.scope,
|
scope=self.params.device_code.scope,
|
||||||
expires=refresh_token_expiry,
|
expires=refresh_token_expiry,
|
||||||
provider=self.provider,
|
provider=self.provider,
|
||||||
auth_time=auth_event.created if auth_event else now,
|
|
||||||
)
|
)
|
||||||
id_token = IDToken.new(
|
id_token = IDToken.new(
|
||||||
self.provider,
|
self.provider,
|
||||||
|
@ -95,7 +95,6 @@ class ProxyProviderSerializer(ProviderSerializer):
|
|||||||
"refresh_token_validity",
|
"refresh_token_validity",
|
||||||
"outpost_set",
|
"outpost_set",
|
||||||
]
|
]
|
||||||
extra_kwargs = ProviderSerializer.Meta.extra_kwargs
|
|
||||||
|
|
||||||
|
|
||||||
class ProxyProviderViewSet(UsedByMixin, ModelViewSet):
|
class ProxyProviderViewSet(UsedByMixin, ModelViewSet):
|
||||||
|
@ -154,7 +154,6 @@ class SAMLProviderSerializer(ProviderSerializer):
|
|||||||
"url_slo_post",
|
"url_slo_post",
|
||||||
"url_slo_redirect",
|
"url_slo_redirect",
|
||||||
]
|
]
|
||||||
extra_kwargs = ProviderSerializer.Meta.extra_kwargs
|
|
||||||
|
|
||||||
|
|
||||||
class SAMLMetadataSerializer(PassiveSerializer):
|
class SAMLMetadataSerializer(PassiveSerializer):
|
||||||
|
@ -73,9 +73,9 @@ class AssertionProcessor:
|
|||||||
# https://commons.lbl.gov/display/IDMgmt/Attribute+Definitions
|
# https://commons.lbl.gov/display/IDMgmt/Attribute+Definitions
|
||||||
attribute_statement = Element(f"{{{NS_SAML_ASSERTION}}}AttributeStatement")
|
attribute_statement = Element(f"{{{NS_SAML_ASSERTION}}}AttributeStatement")
|
||||||
user = self.http_request.user
|
user = self.http_request.user
|
||||||
for mapping in SAMLPropertyMapping.objects.filter(provider=self.provider).order_by(
|
for mapping in self.provider.property_mappings.all().select_subclasses():
|
||||||
"saml_name"
|
if not isinstance(mapping, SAMLPropertyMapping):
|
||||||
):
|
continue
|
||||||
try:
|
try:
|
||||||
mapping: SAMLPropertyMapping
|
mapping: SAMLPropertyMapping
|
||||||
value = mapping.evaluate(
|
value = mapping.evaluate(
|
||||||
|
6
authentik/providers/saml/settings.py
Normal file
6
authentik/providers/saml/settings.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
"""saml provider settings"""
|
||||||
|
|
||||||
|
AUTHENTIK_PROVIDERS_SAML_PROCESSORS = [
|
||||||
|
"authentik.providers.saml.processors.generic",
|
||||||
|
"authentik.providers.saml.processors.salesforce",
|
||||||
|
]
|
@ -10,8 +10,8 @@ from authentik.core.models import Application
|
|||||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||||
from authentik.flows.models import FlowDesignation
|
from authentik.flows.models import FlowDesignation
|
||||||
from authentik.lib.generators import generate_id
|
from authentik.lib.generators import generate_id
|
||||||
from authentik.lib.tests.utils import load_fixture
|
|
||||||
from authentik.providers.saml.models import SAMLPropertyMapping, SAMLProvider
|
from authentik.providers.saml.models import SAMLPropertyMapping, SAMLProvider
|
||||||
|
from authentik.providers.saml.tests.test_metadata import load_fixture
|
||||||
|
|
||||||
|
|
||||||
class TestSAMLProviderAPI(APITestCase):
|
class TestSAMLProviderAPI(APITestCase):
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
"""Test Service-Provider Metadata Parser"""
|
"""Test Service-Provider Metadata Parser"""
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import xmlsec
|
import xmlsec
|
||||||
from defusedxml.lxml import fromstring
|
from defusedxml.lxml import fromstring
|
||||||
from django.test import RequestFactory, TestCase
|
from django.test import RequestFactory, TestCase
|
||||||
@ -7,7 +9,6 @@ from lxml import etree # nosec
|
|||||||
from authentik.core.models import Application
|
from authentik.core.models import Application
|
||||||
from authentik.core.tests.utils import create_test_cert, create_test_flow
|
from authentik.core.tests.utils import create_test_cert, create_test_flow
|
||||||
from authentik.lib.generators import generate_id
|
from authentik.lib.generators import generate_id
|
||||||
from authentik.lib.tests.utils import load_fixture
|
|
||||||
from authentik.lib.xml import lxml_from_string
|
from authentik.lib.xml import lxml_from_string
|
||||||
from authentik.providers.saml.models import SAMLBindings, SAMLPropertyMapping, SAMLProvider
|
from authentik.providers.saml.models import SAMLBindings, SAMLPropertyMapping, SAMLProvider
|
||||||
from authentik.providers.saml.processors.metadata import MetadataProcessor
|
from authentik.providers.saml.processors.metadata import MetadataProcessor
|
||||||
@ -15,6 +16,12 @@ from authentik.providers.saml.processors.metadata_parser import ServiceProviderM
|
|||||||
from authentik.sources.saml.processors.constants import NS_MAP
|
from authentik.sources.saml.processors.constants import NS_MAP
|
||||||
|
|
||||||
|
|
||||||
|
def load_fixture(path: str, **kwargs) -> str:
|
||||||
|
"""Load fixture"""
|
||||||
|
with open(Path(__file__).resolve().parent / Path(path), "r", encoding="utf-8") as _fixture:
|
||||||
|
return _fixture.read()
|
||||||
|
|
||||||
|
|
||||||
class TestServiceProviderMetadataParser(TestCase):
|
class TestServiceProviderMetadataParser(TestCase):
|
||||||
"""Test ServiceProviderMetadataParser parsing and creation of SAML Provider"""
|
"""Test ServiceProviderMetadataParser parsing and creation of SAML Provider"""
|
||||||
|
|
||||||
@ -52,7 +59,7 @@ class TestServiceProviderMetadataParser(TestCase):
|
|||||||
request = self.factory.get("/")
|
request = self.factory.get("/")
|
||||||
metadata = lxml_from_string(MetadataProcessor(provider, request).build_entity_descriptor())
|
metadata = lxml_from_string(MetadataProcessor(provider, request).build_entity_descriptor())
|
||||||
|
|
||||||
schema = etree.XMLSchema(etree.parse("schemas/saml-schema-metadata-2.0.xsd")) # nosec
|
schema = etree.XMLSchema(etree.parse("xml/saml-schema-metadata-2.0.xsd")) # nosec
|
||||||
self.assertTrue(schema.validate(metadata))
|
self.assertTrue(schema.validate(metadata))
|
||||||
|
|
||||||
def test_simple(self):
|
def test_simple(self):
|
||||||
|
@ -46,7 +46,7 @@ class TestSchema(TestCase):
|
|||||||
|
|
||||||
metadata = lxml_from_string(request)
|
metadata = lxml_from_string(request)
|
||||||
|
|
||||||
schema = etree.XMLSchema(etree.parse("schemas/saml-schema-protocol-2.0.xsd")) # nosec
|
schema = etree.XMLSchema(etree.parse("xml/saml-schema-protocol-2.0.xsd")) # nosec
|
||||||
self.assertTrue(schema.validate(metadata))
|
self.assertTrue(schema.validate(metadata))
|
||||||
|
|
||||||
def test_response_schema(self):
|
def test_response_schema(self):
|
||||||
@ -67,5 +67,5 @@ class TestSchema(TestCase):
|
|||||||
|
|
||||||
metadata = lxml_from_string(response)
|
metadata = lxml_from_string(response)
|
||||||
|
|
||||||
schema = etree.XMLSchema(etree.parse("schemas/saml-schema-protocol-2.0.xsd")) # nosec
|
schema = etree.XMLSchema(etree.parse("xml/saml-schema-protocol-2.0.xsd"))
|
||||||
self.assertTrue(schema.validate(metadata))
|
self.assertTrue(schema.validate(metadata))
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
"""authentik SAML IDP Views"""
|
"""authentik SAML IDP Views"""
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from django.http import Http404, HttpRequest, HttpResponse
|
from django.http import HttpRequest, HttpResponse
|
||||||
from django.shortcuts import get_object_or_404
|
from django.shortcuts import get_object_or_404
|
||||||
from django.utils.decorators import method_decorator
|
from django.utils.decorators import method_decorator
|
||||||
from django.utils.translation import gettext as _
|
from django.utils.translation import gettext as _
|
||||||
@ -11,7 +11,6 @@ from structlog.stdlib import get_logger
|
|||||||
|
|
||||||
from authentik.core.models import Application
|
from authentik.core.models import Application
|
||||||
from authentik.events.models import Event, EventAction
|
from authentik.events.models import Event, EventAction
|
||||||
from authentik.flows.exceptions import FlowNonApplicableException
|
|
||||||
from authentik.flows.models import in_memory_stage
|
from authentik.flows.models import in_memory_stage
|
||||||
from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_SSO, FlowPlanner
|
from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_SSO, FlowPlanner
|
||||||
from authentik.flows.views.executor import SESSION_KEY_PLAN, SESSION_KEY_POST
|
from authentik.flows.views.executor import SESSION_KEY_PLAN, SESSION_KEY_POST
|
||||||
@ -61,7 +60,6 @@ class SAMLSSOView(PolicyAccessView):
|
|||||||
# Regardless, we start the planner and return to it
|
# Regardless, we start the planner and return to it
|
||||||
planner = FlowPlanner(self.provider.authorization_flow)
|
planner = FlowPlanner(self.provider.authorization_flow)
|
||||||
planner.allow_empty_flows = True
|
planner.allow_empty_flows = True
|
||||||
try:
|
|
||||||
plan = planner.plan(
|
plan = planner.plan(
|
||||||
request,
|
request,
|
||||||
{
|
{
|
||||||
@ -72,8 +70,6 @@ class SAMLSSOView(PolicyAccessView):
|
|||||||
PLAN_CONTEXT_CONSENT_PERMISSIONS: [],
|
PLAN_CONTEXT_CONSENT_PERMISSIONS: [],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
except FlowNonApplicableException:
|
|
||||||
raise Http404
|
|
||||||
plan.append_stage(in_memory_stage(SAMLFlowFinalView))
|
plan.append_stage(in_memory_stage(SAMLFlowFinalView))
|
||||||
request.session[SESSION_KEY_PLAN] = plan
|
request.session[SESSION_KEY_PLAN] = plan
|
||||||
return redirect_with_qs(
|
return redirect_with_qs(
|
||||||
|
@ -1,38 +0,0 @@
|
|||||||
"""scim Property mappings API Views"""
|
|
||||||
from django_filters.filters import AllValuesMultipleFilter
|
|
||||||
from django_filters.filterset import FilterSet
|
|
||||||
from drf_spectacular.types import OpenApiTypes
|
|
||||||
from drf_spectacular.utils import extend_schema_field
|
|
||||||
from rest_framework.viewsets import ModelViewSet
|
|
||||||
|
|
||||||
from authentik.core.api.propertymappings import PropertyMappingSerializer
|
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
|
||||||
from authentik.providers.scim.models import SCIMMapping
|
|
||||||
|
|
||||||
|
|
||||||
class SCIMMappingSerializer(PropertyMappingSerializer):
|
|
||||||
"""SCIMMapping Serializer"""
|
|
||||||
|
|
||||||
class Meta:
|
|
||||||
model = SCIMMapping
|
|
||||||
fields = PropertyMappingSerializer.Meta.fields
|
|
||||||
|
|
||||||
|
|
||||||
class SCIMMappingFilter(FilterSet):
|
|
||||||
"""Filter for SCIMMapping"""
|
|
||||||
|
|
||||||
managed = extend_schema_field(OpenApiTypes.STR)(AllValuesMultipleFilter(field_name="managed"))
|
|
||||||
|
|
||||||
class Meta:
|
|
||||||
model = SCIMMapping
|
|
||||||
fields = "__all__"
|
|
||||||
|
|
||||||
|
|
||||||
class SCIMMappingViewSet(UsedByMixin, ModelViewSet):
|
|
||||||
"""SCIMMapping Viewset"""
|
|
||||||
|
|
||||||
queryset = SCIMMapping.objects.all()
|
|
||||||
serializer_class = SCIMMappingSerializer
|
|
||||||
filterset_class = SCIMMappingFilter
|
|
||||||
search_fields = ["name"]
|
|
||||||
ordering = ["name"]
|
|
@ -1,62 +0,0 @@
|
|||||||
"""SCIM Provider API Views"""
|
|
||||||
from django.utils.text import slugify
|
|
||||||
from drf_spectacular.utils import OpenApiResponse, extend_schema
|
|
||||||
from rest_framework.decorators import action
|
|
||||||
from rest_framework.request import Request
|
|
||||||
from rest_framework.response import Response
|
|
||||||
from rest_framework.viewsets import ModelViewSet
|
|
||||||
|
|
||||||
from authentik.admin.api.tasks import TaskSerializer
|
|
||||||
from authentik.core.api.providers import ProviderSerializer
|
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
|
||||||
from authentik.events.monitored_tasks import TaskInfo
|
|
||||||
from authentik.providers.scim.models import SCIMProvider
|
|
||||||
|
|
||||||
|
|
||||||
class SCIMProviderSerializer(ProviderSerializer):
|
|
||||||
"""SCIMProvider Serializer"""
|
|
||||||
|
|
||||||
class Meta:
|
|
||||||
model = SCIMProvider
|
|
||||||
fields = [
|
|
||||||
"pk",
|
|
||||||
"name",
|
|
||||||
"property_mappings",
|
|
||||||
"property_mappings_group",
|
|
||||||
"component",
|
|
||||||
"assigned_application_slug",
|
|
||||||
"assigned_application_name",
|
|
||||||
"verbose_name",
|
|
||||||
"verbose_name_plural",
|
|
||||||
"meta_model_name",
|
|
||||||
"url",
|
|
||||||
"token",
|
|
||||||
"exclude_users_service_account",
|
|
||||||
"filter_group",
|
|
||||||
]
|
|
||||||
extra_kwargs = {}
|
|
||||||
|
|
||||||
|
|
||||||
class SCIMProviderViewSet(UsedByMixin, ModelViewSet):
|
|
||||||
"""SCIMProvider Viewset"""
|
|
||||||
|
|
||||||
queryset = SCIMProvider.objects.all()
|
|
||||||
serializer_class = SCIMProviderSerializer
|
|
||||||
filterset_fields = ["name", "exclude_users_service_account", "url", "filter_group"]
|
|
||||||
search_fields = ["name", "url"]
|
|
||||||
ordering = ["name", "url"]
|
|
||||||
|
|
||||||
@extend_schema(
|
|
||||||
responses={
|
|
||||||
200: TaskSerializer(),
|
|
||||||
404: OpenApiResponse(description="Task not found"),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
@action(methods=["GET"], detail=True, pagination_class=None, filter_backends=[])
|
|
||||||
def sync_status(self, request: Request, pk: int) -> Response:
|
|
||||||
"""Get provider's sync status"""
|
|
||||||
provider = self.get_object()
|
|
||||||
task = TaskInfo.by_name(f"scim_sync:{slugify(provider.name)}")
|
|
||||||
if not task:
|
|
||||||
return Response(status=404)
|
|
||||||
return Response(TaskSerializer(task).data)
|
|
@ -1,15 +0,0 @@
|
|||||||
"""authentik SCIM Provider app config"""
|
|
||||||
from authentik.blueprints.apps import ManagedAppConfig
|
|
||||||
|
|
||||||
|
|
||||||
class AuthentikProviderSCIMConfig(ManagedAppConfig):
|
|
||||||
"""authentik SCIM Provider app config"""
|
|
||||||
|
|
||||||
name = "authentik.providers.scim"
|
|
||||||
label = "authentik_providers_scim"
|
|
||||||
verbose_name = "authentik Providers.SCIM"
|
|
||||||
default = True
|
|
||||||
|
|
||||||
def reconcile_load_signals(self):
|
|
||||||
"""Load signals"""
|
|
||||||
self.import_module("authentik.providers.scim.signals")
|
|
@ -1,2 +0,0 @@
|
|||||||
"""SCIM constants"""
|
|
||||||
PAGE_SIZE = 100
|
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user