Compare commits
2 commits: hack-close...version-20

| Author | SHA1 | Date |
| --- | --- | --- |
|  | aba857753b |  |
|  | 022ff9b3a8 |  |
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 2023.6.1
+current_version = 2023.6.2
 tag = True
 commit = True
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
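The `parse` pattern in the hunk above is a plain regular expression with named groups; a minimal sketch of how it splits the bumped version string (illustration only, not code from either branch):

import re

# The named-group pattern from the [bumpversion] hunk above.
pattern = re.compile(r"(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)")

# "2023.6.2" is the new current_version in this diff.
match = pattern.fullmatch("2023.6.2")
assert match is not None
print(match.groupdict())  # {'major': '2023', 'minor': '6', 'patch': '2'}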
.github/workflows/ci-main.yml (2 changes, vendored)
@@ -112,7 +112,7 @@ jobs:
       - name: Setup authentik env
         uses: ./.github/actions/setup
       - name: Create k8s Kind Cluster
-        uses: helm/kind-action@v1.8.0
+        uses: helm/kind-action@v1.7.0
       - name: run integration
         run: |
           poetry run coverage run manage.py test tests/integration
.github/workflows/translation-rename.yml (39 changes, vendored)
@@ -1,39 +0,0 @@
-# Rename transifex pull requests to have a correct naming
-name: authentik-translation-transifex-rename
-
-on:
-  pull_request:
-    types: [opened, reopened]
-
-jobs:
-  rename_pr:
-    runs-on: ubuntu-latest
-    if: ${{ github.event.pull_request.user.login == 'transifex-integration[bot]'}}
-    steps:
-      - id: generate_token
-        uses: tibdex/github-app-token@v1
-        with:
-          app_id: ${{ secrets.GH_APP_ID }}
-          private_key: ${{ secrets.GH_APP_PRIVATE_KEY }}
-      - name: Get current title
-        id: title
-        env:
-          GH_TOKEN: ${{ steps.generate_token.outputs.token }}
-        run: |
-          title=$(curl -q -L \
-            -H "Accept: application/vnd.github+json" \
-            -H "Authorization: Bearer ${GH_TOKEN}" \
-            -H "X-GitHub-Api-Version: 2022-11-28" \
-            https://api.github.com/repos/${GITHUB_REPOSITORY}/pulls/${{ github.event.pull_request.number }} | jq -r .title)
-          echo "title=${title}" >> "$GITHUB_OUTPUT"
-      - name: Rename
-        env:
-          GH_TOKEN: ${{ steps.generate_token.outputs.token }}
-        run: |
-          curl -L \
-            -X PATCH \
-            -H "Accept: application/vnd.github+json" \
-            -H "Authorization: Bearer ${GH_TOKEN}" \
-            -H "X-GitHub-Api-Version: 2022-11-28" \
-            https://api.github.com/repos/${GITHUB_REPOSITORY}/pulls/${{ github.event.pull_request.number }} \
-            -d "{\"title\":\"translate: ${{ steps.title.outputs.title }}\"}"
.gitignore (1 change, vendored)
@@ -204,4 +204,3 @@ data/

 # Local Netlify folder
 .netlify
-.ruff_cache
Dockerfile (12 changes)
@@ -31,7 +31,7 @@ RUN pip install --no-cache-dir poetry && \
     poetry export -f requirements.txt --dev --output requirements-dev.txt

 # Stage 4: Build go proxy
-FROM docker.io/golang:1.20.6-bullseye AS go-builder
+FROM docker.io/golang:1.20.5-bullseye AS go-builder

 WORKDIR /work

@@ -47,18 +47,20 @@ COPY ./go.sum /work/go.sum
 RUN go build -o /work/authentik ./cmd/server/

 # Stage 5: MaxMind GeoIP
-FROM ghcr.io/maxmind/geoipupdate:v6.0 as geoip
+FROM ghcr.io/maxmind/geoipupdate:v5.1 as geoip

 ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City"
 ENV GEOIPUPDATE_VERBOSE="true"
-ENV GEOIPUPDATE_ACCOUNT_ID_FILE="/run/secrets/GEOIPUPDATE_ACCOUNT_ID"
-ENV GEOIPUPDATE_LICENSE_KEY_FILE="/run/secrets/GEOIPUPDATE_LICENSE_KEY"

 USER root
 RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
     --mount=type=secret,id=GEOIPUPDATE_LICENSE_KEY \
     mkdir -p /usr/share/GeoIP && \
-    /bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
+    /bin/sh -c "\
+        export GEOIPUPDATE_ACCOUNT_ID=$(cat /run/secrets/GEOIPUPDATE_ACCOUNT_ID); \
+        export GEOIPUPDATE_LICENSE_KEY=$(cat /run/secrets/GEOIPUPDATE_LICENSE_KEY); \
+        /usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0 \
+    "

 # Stage 6: Run
 FROM docker.io/python:3.11.4-slim-bullseye AS final-image
Makefile (3 changes)
@@ -145,8 +145,7 @@ web-lint-fix:

 web-lint:
 	cd web && npm run lint
-	# TODO: The analyzer hasn't run correctly in awhile.
-	# cd web && npm run lit-analyse
+	cd web && npm run lit-analyse

 web-check-compile:
 	cd web && npm run tsc
SECURITY.md (50 changes)
@@ -1,48 +1,44 @@
-authentik takes security very seriously. We follow the rules of [responsible disclosure](https://en.wikipedia.org/wiki/Responsible_disclosure), and we urge our community to do so as well, instead of reporting vulnerabilities publicly. This allows us to patch the issue quickly, announce it's existence and release the fixed version.
-
-## What authentik classifies as a CVE
-
-CVE (Common Vulnerability and Exposure) is a system designed to aggregate all vulnerabilities. As such, a CVE will be issued when there is a either vulnerability or exposure. Per NIST, A vulnerability is:
-
-“Weakness in an information system, system security procedures, internal controls, or implementation that could be exploited or triggered by a threat source.”
-
-If it is determined that the issue does qualify as a CVE, a CVE number will be issued to the reporter from GitHub.
-
-Even if the issue is not a CVE, we still greatly appreciate your help in hardening authentik.
+authentik takes security very seriously. We follow the rules of [responsible disclosure](https://en.wikipedia.org/wiki/Responsible_disclosure), and we urge our community to do so as well, instead of reporting vulnerabilities publicly. This allows us to patch the issue quickly, announce it's existence and release the fixed version.

 ## Supported Versions

 (.x being the latest patch release for each version)

-| Version | Supported |
-| --- | --- |
-| 2023.5.x | ✅ |
-| 2023.6.x | ✅ |
+| Version   | Supported          |
+| --------- | ------------------ |
+| 2023.4.x  | :white_check_mark: |
+| 2023.5.x  | :white_check_mark: |

 ## Reporting a Vulnerability

-To report a vulnerability, send an email to [security@goauthentik.io](mailto:security@goauthentik.io). Be sure to include relevant information like which version you've found the issue in, instructions on how to reproduce the issue, and anything else that might make it easier for us to find the issue.
+To report a vulnerability, send an email to [security@goauthentik.io](mailto:security@goauthentik.io). Be sure to include relevant information like which version you've found the issue in, instructions on how to reproduce the issue, and anything else that might make it easier for us to find the bug.

-## Severity levels
+## Criticality levels

-authentik reserves the right to reclassify CVSS as necessary. To determine severity, we will use the CVSS calculator from NVD (https://nvd.nist.gov/vuln-metrics/cvss/v3-calculator). The calculated CVSS score will then be translated into one of the following categories:
-
-| 0.0 | None |
-| 0.1 – 3.9 | Low |
-| 4.0 – 6.9 | Medium |
-| 7.0 – 8.9 | High |
-| 9.0 – 10.0 | Critical |
+### High
+
+- Authorization bypass
+- Circumvention of policies
+
+### Moderate
+
+- Denial-of-Service attacks
+
+### Low
+
+- Unvalidated redirects
+- Issues requiring uncommon setups

 ## Disclosure process

-1. Report from Github or Issue is reported via Email as listed above.
+1. Issue is reported via Email as listed above.
 2. The authentik Security team will try to reproduce the issue and ask for more information if required.
-3. A severity level is assigned.
+3. A criticality level is assigned.
 4. A fix is created, and if possible tested by the issue reporter.
 5. The fix is backported to other supported versions, and if possible a workaround for other versions is created.
-6. An announcement is sent out with a fixed release date and severity level of the issue. The announcement will be sent at least 24 hours before the release of the security fix.
+6. An announcement is sent out with a fixed release date and criticality level of the issue. The announcement will be sent at least 24 hours before the release of the fix
 7. The fixed version is released for the supported versions.

 ## Getting security notifications

-To get security notifications, subscribe to the mailing list [here](https://groups.google.com/g/authentik-security-announcements) or join the [discord](https://goauthentik.io/discord) server.
+To get security notifications, subscribe to the mailing list [here](https://groups.google.com/g/authentik-security-announcements) or join the [discord](https://goauthentik.io/discord) server.
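The severity table removed above follows the standard CVSS v3 rating scale referenced in the deleted paragraph; a minimal sketch of that score-to-category mapping (hypothetical helper, not code from either branch):

def cvss_severity(score: float) -> str:
    """Translate a CVSS v3 base score into a category per the removed table."""
    if not 0.0 <= score <= 10.0:
        raise ValueError("CVSS base scores range from 0.0 to 10.0")
    if score == 0.0:
        return "None"
    if score <= 3.9:
        return "Low"
    if score <= 6.9:
        return "Medium"
    if score <= 8.9:
        return "High"
    return "Critical"

print(cvss_severity(7.5))  # "High"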
@@ -2,7 +2,7 @@
 from os import environ
 from typing import Optional

-__version__ = "2023.6.1"
+__version__ = "2023.6.2"
 ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"


@@ -58,7 +58,7 @@ def clear_update_notifications():
 @prefill_task
 def update_latest_version(self: MonitoredTask):
     """Update latest version info"""
-    if CONFIG.get_bool("disable_update_check"):
+    if CONFIG.y_bool("disable_update_check"):
         cache.set(VERSION_CACHE_KEY, "0.0.0", VERSION_CACHE_TIMEOUT)
         self.set_status(TaskResult(TaskResultStatus.WARNING, messages=["Version check disabled."]))
         return
@@ -9,7 +9,7 @@ from rest_framework.exceptions import AuthenticationFailed

 from authentik.api.authentication import bearer_auth
 from authentik.blueprints.tests import reconcile_app
-from authentik.core.models import Token, TokenIntents, User, UserTypes
+from authentik.core.models import USER_ATTRIBUTE_SA, Token, TokenIntents
 from authentik.core.tests.utils import create_test_admin_user, create_test_flow
 from authentik.lib.generators import generate_id
 from authentik.providers.oauth2.constants import SCOPE_AUTHENTIK_API
@@ -57,8 +57,8 @@ class TestAPIAuth(TestCase):
     @reconcile_app("authentik_outposts")
     def test_managed_outpost_success(self):
         """Test managed outpost"""
-        user: User = bearer_auth(f"Bearer {settings.SECRET_KEY}".encode())
-        self.assertEqual(user.type, UserTypes.INTERNAL_SERVICE_ACCOUNT)
+        user = bearer_auth(f"Bearer {settings.SECRET_KEY}".encode())
+        self.assertEqual(user.attributes[USER_ATTRIBUTE_SA], True)

     def test_jwt_valid(self):
         """Test valid JWT"""
@@ -3,7 +3,6 @@ from pathlib import Path

 from django.conf import settings
 from django.db import models
-from django.dispatch import Signal
 from drf_spectacular.utils import extend_schema
 from rest_framework.fields import (
     BooleanField,
@@ -22,8 +21,6 @@ from authentik.core.api.utils import PassiveSerializer
 from authentik.events.geo import GEOIP_READER
 from authentik.lib.config import CONFIG

-capabilities = Signal()
-

 class Capabilities(models.TextChoices):
     """Define capabilities which influence which APIs can/should be used"""
@@ -70,15 +67,12 @@ class ConfigView(APIView):
             caps.append(Capabilities.CAN_SAVE_MEDIA)
         if GEOIP_READER.enabled:
             caps.append(Capabilities.CAN_GEO_IP)
-        if CONFIG.get_bool("impersonation"):
+        if CONFIG.y_bool("impersonation"):
             caps.append(Capabilities.CAN_IMPERSONATE)
         if settings.DEBUG:  # pragma: no cover
             caps.append(Capabilities.CAN_DEBUG)
         if "authentik.enterprise" in settings.INSTALLED_APPS:
             caps.append(Capabilities.IS_ENTERPRISE)
-        for _, result in capabilities.send(sender=self):
-            if result:
-                caps.append(result)
         return caps

     def get_config(self) -> ConfigSerializer:
@@ -86,17 +80,17 @@ class ConfigView(APIView):
         return ConfigSerializer(
             {
                 "error_reporting": {
-                    "enabled": CONFIG.get("error_reporting.enabled"),
-                    "sentry_dsn": CONFIG.get("error_reporting.sentry_dsn"),
-                    "environment": CONFIG.get("error_reporting.environment"),
-                    "send_pii": CONFIG.get("error_reporting.send_pii"),
-                    "traces_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.4)),
+                    "enabled": CONFIG.y("error_reporting.enabled"),
+                    "sentry_dsn": CONFIG.y("error_reporting.sentry_dsn"),
+                    "environment": CONFIG.y("error_reporting.environment"),
+                    "send_pii": CONFIG.y("error_reporting.send_pii"),
+                    "traces_sample_rate": float(CONFIG.y("error_reporting.sample_rate", 0.4)),
                 },
                 "capabilities": self.get_capabilities(),
-                "cache_timeout": int(CONFIG.get("redis.cache_timeout")),
-                "cache_timeout_flows": int(CONFIG.get("redis.cache_timeout_flows")),
-                "cache_timeout_policies": int(CONFIG.get("redis.cache_timeout_policies")),
-                "cache_timeout_reputation": int(CONFIG.get("redis.cache_timeout_reputation")),
+                "cache_timeout": int(CONFIG.y("redis.cache_timeout")),
+                "cache_timeout_flows": int(CONFIG.y("redis.cache_timeout_flows")),
+                "cache_timeout_policies": int(CONFIG.y("redis.cache_timeout_policies")),
+                "cache_timeout_reputation": int(CONFIG.y("redis.cache_timeout_reputation")),
             }
         )

@@ -21,14 +21,9 @@ _other_urls = []
 for _authentik_app in get_apps():
     try:
         api_urls = import_module(f"{_authentik_app.name}.urls")
-    except (ModuleNotFoundError, ImportError) as exc:
-        LOGGER.warning("Could not import app's URLs", app_name=_authentik_app.name, exc=exc)
+    except (ModuleNotFoundError, ImportError):
         continue
     if not hasattr(api_urls, "api_urlpatterns"):
-        LOGGER.debug(
-            "App does not define API URLs",
-            app_name=_authentik_app.name,
-        )
         continue
     urls: list = getattr(api_urls, "api_urlpatterns")
     for url in urls:
@@ -30,7 +30,7 @@ def check_blueprint_v1_file(BlueprintInstance: type, path: Path):
         return
     blueprint_file.seek(0)
     instance: BlueprintInstance = BlueprintInstance.objects.filter(path=path).first()
-    rel_path = path.relative_to(Path(CONFIG.get("blueprints_dir")))
+    rel_path = path.relative_to(Path(CONFIG.y("blueprints_dir")))
     meta = None
     if metadata:
         meta = from_dict(BlueprintMetadata, metadata)
@@ -55,7 +55,7 @@ def migration_blueprint_import(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
     Flow = apps.get_model("authentik_flows", "Flow")

     db_alias = schema_editor.connection.alias
-    for file in glob(f"{CONFIG.get('blueprints_dir')}/**/*.yaml", recursive=True):
+    for file in glob(f"{CONFIG.y('blueprints_dir')}/**/*.yaml", recursive=True):
         check_blueprint_v1_file(BlueprintInstance, Path(file))

     for blueprint in BlueprintInstance.objects.using(db_alias).all():
@@ -82,7 +82,7 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel):
     def retrieve_file(self) -> str:
         """Get blueprint from path"""
         try:
-            base = Path(CONFIG.get("blueprints_dir"))
+            base = Path(CONFIG.y("blueprints_dir"))
             full_path = base.joinpath(Path(self.path)).resolve()
             if not str(full_path).startswith(str(base.resolve())):
                 raise BlueprintRetrievalFailed("Invalid blueprint path")
@@ -62,7 +62,7 @@ def start_blueprint_watcher():
     if _file_watcher_started:
         return
     observer = Observer()
-    observer.schedule(BlueprintEventHandler(), CONFIG.get("blueprints_dir"), recursive=True)
+    observer.schedule(BlueprintEventHandler(), CONFIG.y("blueprints_dir"), recursive=True)
     observer.start()
     _file_watcher_started = True

@@ -80,7 +80,7 @@ class BlueprintEventHandler(FileSystemEventHandler):
             blueprints_discovery.delay()
         if isinstance(event, FileModifiedEvent):
             path = Path(event.src_path)
-            root = Path(CONFIG.get("blueprints_dir")).absolute()
+            root = Path(CONFIG.y("blueprints_dir")).absolute()
             rel_path = str(path.relative_to(root))
             for instance in BlueprintInstance.objects.filter(path=rel_path):
                 LOGGER.debug("modified blueprint file, starting apply", instance=instance)
@@ -101,7 +101,7 @@ def blueprints_find_dict():
 def blueprints_find():
     """Find blueprints and return valid ones"""
     blueprints = []
-    root = Path(CONFIG.get("blueprints_dir"))
+    root = Path(CONFIG.y("blueprints_dir"))
     for path in root.rglob("**/*.yaml"):
         # Check if any part in the path starts with a dot and assume a hidden file
         if any(part for part in path.parts if part.startswith(".")):
@@ -59,6 +59,7 @@ from authentik.core.middleware import (
     SESSION_KEY_IMPERSONATE_USER,
 )
 from authentik.core.models import (
+    USER_ATTRIBUTE_SA,
     USER_ATTRIBUTE_TOKEN_EXPIRING,
     USER_PATH_SERVICE_ACCOUNT,
     AuthenticatedSession,
@@ -66,7 +67,6 @@ from authentik.core.models import (
     Token,
     TokenIntents,
     User,
-    UserTypes,
 )
 from authentik.events.models import Event, EventAction
 from authentik.flows.exceptions import FlowNonApplicableException
@@ -147,18 +147,6 @@ class UserSerializer(ModelSerializer):
             raise ValidationError(_("No empty segments in user path allowed."))
         return path

-    def validate_type(self, user_type: str) -> str:
-        """Validate user type, internal_service_account is an internal value"""
-        if (
-            self.instance
-            and self.instance.type == UserTypes.INTERNAL_SERVICE_ACCOUNT
-            and user_type != UserTypes.INTERNAL_SERVICE_ACCOUNT.value
-        ):
-            raise ValidationError("Can't change internal service account to other user type.")
-        if not self.instance and user_type == UserTypes.INTERNAL_SERVICE_ACCOUNT.value:
-            raise ValidationError("Setting a user to internal service account is not allowed.")
-        return user_type
-
     class Meta:
         model = User
         fields = [
@@ -175,7 +163,6 @@ class UserSerializer(ModelSerializer):
             "attributes",
             "uid",
             "path",
-            "type",
         ]
         extra_kwargs = {
             "name": {"allow_blank": True},
@@ -224,7 +211,6 @@ class UserSelfSerializer(ModelSerializer):
             "avatar",
             "uid",
             "settings",
-            "type",
         ]
         extra_kwargs = {
             "is_active": {"read_only": True},
@@ -343,7 +329,6 @@ class UsersFilter(FilterSet):
             "attributes",
             "groups_by_name",
             "groups_by_pk",
-            "type",
         ]


@@ -436,8 +421,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
                 user: User = User.objects.create(
                     username=username,
                     name=username,
-                    type=UserTypes.SERVICE_ACCOUNT,
-                    attributes={USER_ATTRIBUTE_TOKEN_EXPIRING: expiring},
+                    attributes={USER_ATTRIBUTE_SA: True, USER_ATTRIBUTE_TOKEN_EXPIRING: expiring},
                     path=USER_PATH_SERVICE_ACCOUNT,
                 )
                 user.set_unusable_password()
@@ -596,7 +580,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
     @action(detail=True, methods=["POST"])
     def impersonate(self, request: Request, pk: int) -> Response:
         """Impersonate a user"""
-        if not CONFIG.get_bool("impersonation"):
+        if not CONFIG.y_bool("impersonation"):
             LOGGER.debug("User attempted to impersonate", user=request.user)
             return Response(status=401)
         if not request.user.has_perm("impersonate"):
@@ -18,7 +18,7 @@ class Command(BaseCommand):

     def handle(self, **options):
         close_old_connections()
-        if CONFIG.get_bool("remote_debug"):
+        if CONFIG.y_bool("remote_debug"):
            import debugpy

            debugpy.listen(("0.0.0.0", 6900))  # nosec
@@ -1,43 +0,0 @@
-# Generated by Django 4.1.7 on 2023-05-21 11:44
-
-from django.apps.registry import Apps
-from django.db import migrations, models
-from django.db.backends.base.schema import BaseDatabaseSchemaEditor
-
-
-def migrate_user_type(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
-    db_alias = schema_editor.connection.alias
-    User = apps.get_model("authentik_core", "User")
-
-    from authentik.core.models import UserTypes
-
-    for user in User.objects.using(db_alias).all():
-        user.type = UserTypes.INTERNAL
-        if "goauthentik.io/user/service-account" in user.attributes:
-            user.type = UserTypes.SERVICE_ACCOUNT
-        if "goauthentik.io/user/override-ips" in user.attributes:
-            user.type = UserTypes.INTERNAL_SERVICE_ACCOUNT
-        user.save()
-
-
-class Migration(migrations.Migration):
-    dependencies = [
-        ("authentik_core", "0029_provider_backchannel_applications_and_more"),
-    ]
-
-    operations = [
-        migrations.AddField(
-            model_name="user",
-            name="type",
-            field=models.TextField(
-                choices=[
-                    ("default", "Default"),
-                    ("external", "External"),
-                    ("service_account", "Service Account"),
-                    ("internal_service_account", "Internal Service Account"),
-                ],
-                default="default",
-            ),
-        ),
-        migrations.RunPython(migrate_user_type),
-    ]
@@ -1,41 +0,0 @@
-# Generated by Django 4.1.10 on 2023-07-21 12:54
-
-from django.apps.registry import Apps
-from django.db import migrations, models
-from django.db.backends.base.schema import BaseDatabaseSchemaEditor
-
-
-def migrate_user_type_v2(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
-    db_alias = schema_editor.connection.alias
-    User = apps.get_model("authentik_core", "User")
-
-    from authentik.core.models import UserTypes
-
-    for user in User.objects.using(db_alias).all():
-        if user.type != "default":
-            continue
-        user.type = UserTypes.INTERNAL
-        user.save()
-
-
-class Migration(migrations.Migration):
-    dependencies = [
-        ("authentik_core", "0030_user_type"),
-    ]
-
-    operations = [
-        migrations.AlterField(
-            model_name="user",
-            name="type",
-            field=models.TextField(
-                choices=[
-                    ("internal", "Internal"),
-                    ("external", "External"),
-                    ("service_account", "Service Account"),
-                    ("internal_service_account", "Internal Service Account"),
-                ],
-                default="internal",
-            ),
-        ),
-        migrations.RunPython(migrate_user_type_v2),
-    ]
@@ -36,6 +36,7 @@ from authentik.root.install_id import get_install_id

 LOGGER = get_logger()
 USER_ATTRIBUTE_DEBUG = "goauthentik.io/user/debug"
+USER_ATTRIBUTE_SA = "goauthentik.io/user/service-account"
 USER_ATTRIBUTE_GENERATED = "goauthentik.io/user/generated"
 USER_ATTRIBUTE_EXPIRES = "goauthentik.io/user/expires"
 USER_ATTRIBUTE_DELETE_ON_LOGOUT = "goauthentik.io/user/delete-on-logout"
@@ -44,6 +45,8 @@ USER_ATTRIBUTE_TOKEN_EXPIRING = "goauthentik.io/user/token-expires"  # nosec
 USER_ATTRIBUTE_CHANGE_USERNAME = "goauthentik.io/user/can-change-username"
 USER_ATTRIBUTE_CHANGE_NAME = "goauthentik.io/user/can-change-name"
 USER_ATTRIBUTE_CHANGE_EMAIL = "goauthentik.io/user/can-change-email"
+USER_ATTRIBUTE_CAN_OVERRIDE_IP = "goauthentik.io/user/override-ips"
+
 USER_PATH_SYSTEM_PREFIX = "goauthentik.io"
 USER_PATH_SERVICE_ACCOUNT = USER_PATH_SYSTEM_PREFIX + "/service-accounts"
@@ -60,22 +63,7 @@ def default_token_key():
     """Default token key"""
     # We use generate_id since the chars in the key should be easy
     # to use in Emails (for verification) and URLs (for recovery)
-    return generate_id(int(CONFIG.get("default_token_length")))
-
-
-class UserTypes(models.TextChoices):
-    """User types, both for grouping, licensing and permissions in the case
-    of the internal_service_account"""
-
-    INTERNAL = "internal"
-    EXTERNAL = "external"
-
-    # User-created service accounts
-    SERVICE_ACCOUNT = "service_account"
-
-    # Special user type for internally managed and created service
-    # accounts, such as outpost users
-    INTERNAL_SERVICE_ACCOUNT = "internal_service_account"
+    return generate_id(int(CONFIG.y("default_token_length")))


 class Group(SerializerModel):
@@ -161,7 +149,6 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser):
     uuid = models.UUIDField(default=uuid4, editable=False, unique=True)
     name = models.TextField(help_text=_("User's display name."))
     path = models.TextField(default="users")
-    type = models.TextField(choices=UserTypes.choices, default=UserTypes.INTERNAL)

     sources = models.ManyToManyField("Source", through="UserSourceConnection")
     ak_groups = models.ManyToManyField("Group", related_name="users")
@@ -1,4 +1,6 @@
 """authentik core signals"""
+from typing import TYPE_CHECKING
+
 from django.contrib.auth.signals import user_logged_in, user_logged_out
 from django.contrib.sessions.backends.cache import KEY_PREFIX
 from django.core.cache import cache
@@ -8,13 +10,16 @@ from django.db.models.signals import post_save, pre_delete, pre_save
 from django.dispatch import receiver
 from django.http.request import HttpRequest

-from authentik.core.models import Application, AuthenticatedSession, BackchannelProvider, User
+from authentik.core.models import Application, AuthenticatedSession, BackchannelProvider

 # Arguments: user: User, password: str
 password_changed = Signal()
 # Arguments: credentials: dict[str, any], request: HttpRequest, stage: Stage
 login_failed = Signal()

+if TYPE_CHECKING:
+    from authentik.core.models import User
+

 @receiver(post_save, sender=Application)
 def post_save_application(sender: type[Model], instance, created: bool, **_):
@@ -30,7 +35,7 @@ def post_save_application(sender: type[Model], instance, created: bool, **_):


 @receiver(user_logged_in)
-def user_logged_in_session(sender, request: HttpRequest, user: User, **_):
+def user_logged_in_session(sender, request: HttpRequest, user: "User", **_):
     """Create an AuthenticatedSession from request"""

     session = AuthenticatedSession.from_request(request, user)
@@ -39,7 +44,7 @@ def user_logged_in_session(sender, request: HttpRequest, user: User, **_):


 @receiver(user_logged_out)
-def user_logged_out_session(sender, request: HttpRequest, user: User, **_):
+def user_logged_out_session(sender, request: HttpRequest, user: "User", **_):
     """Delete AuthenticatedSession if it exists"""
     AuthenticatedSession.objects.filter(session_key=request.session.session_key).delete()

@@ -8,11 +8,11 @@ from django.urls.base import reverse
 from rest_framework.test import APITestCase

 from authentik.core.models import (
+    USER_ATTRIBUTE_SA,
     USER_ATTRIBUTE_TOKEN_EXPIRING,
     AuthenticatedSession,
     Token,
     User,
-    UserTypes,
 )
 from authentik.core.tests.utils import create_test_admin_user, create_test_flow, create_test_tenant
 from authentik.flows.models import FlowDesignation
@@ -141,8 +141,7 @@ class TestUsersAPI(APITestCase):

         user_filter = User.objects.filter(
             username="test-sa",
-            type=UserTypes.SERVICE_ACCOUNT,
-            attributes={USER_ATTRIBUTE_TOKEN_EXPIRING: True},
+            attributes={USER_ATTRIBUTE_TOKEN_EXPIRING: True, USER_ATTRIBUTE_SA: True},
         )
         self.assertTrue(user_filter.exists())
         user: User = user_filter.first()
@@ -167,8 +166,7 @@ class TestUsersAPI(APITestCase):

         user_filter = User.objects.filter(
             username="test-sa",
-            type=UserTypes.SERVICE_ACCOUNT,
-            attributes={USER_ATTRIBUTE_TOKEN_EXPIRING: False},
+            attributes={USER_ATTRIBUTE_TOKEN_EXPIRING: False, USER_ATTRIBUTE_SA: True},
         )
         self.assertTrue(user_filter.exists())
         user: User = user_filter.first()
@@ -194,8 +192,7 @@ class TestUsersAPI(APITestCase):

         user_filter = User.objects.filter(
             username="test-sa",
-            type=UserTypes.SERVICE_ACCOUNT,
-            attributes={USER_ATTRIBUTE_TOKEN_EXPIRING: True},
+            attributes={USER_ATTRIBUTE_TOKEN_EXPIRING: True, USER_ATTRIBUTE_SA: True},
         )
         self.assertTrue(user_filter.exists())
         user: User = user_filter.first()
@@ -221,8 +218,7 @@ class TestUsersAPI(APITestCase):

         user_filter = User.objects.filter(
             username="test-sa",
-            type=UserTypes.SERVICE_ACCOUNT,
-            attributes={USER_ATTRIBUTE_TOKEN_EXPIRING: True},
+            attributes={USER_ATTRIBUTE_TOKEN_EXPIRING: True, USER_ATTRIBUTE_SA: True},
         )
         self.assertTrue(user_filter.exists())
         user: User = user_filter.first()
@@ -46,7 +46,7 @@ def certificate_discovery(self: MonitoredTask):
     certs = {}
     private_keys = {}
     discovered = 0
-    for file in glob(CONFIG.get("cert_discovery_dir") + "/**", recursive=True):
+    for file in glob(CONFIG.y("cert_discovery_dir") + "/**", recursive=True):
         path = Path(file)
         if not path.exists():
             continue
@@ -1,154 +0,0 @@
-"""Enterprise API Views"""
-from datetime import datetime, timedelta
-
-from django.utils.timezone import now
-from drf_spectacular.types import OpenApiTypes
-from drf_spectacular.utils import extend_schema, inline_serializer
-from rest_framework.decorators import action
-from rest_framework.fields import BooleanField, CharField, DateTimeField, IntegerField
-from rest_framework.permissions import IsAdminUser, IsAuthenticated
-from rest_framework.request import Request
-from rest_framework.response import Response
-from rest_framework.serializers import ModelSerializer
-from rest_framework.viewsets import ModelViewSet
-
-from authentik.api.decorators import permission_required
-from authentik.core.api.used_by import UsedByMixin
-from authentik.core.api.utils import PassiveSerializer
-from authentik.core.models import User, UserTypes
-from authentik.enterprise.models import License, LicenseKey
-from authentik.root.install_id import get_install_id
-
-
-class LicenseSerializer(ModelSerializer):
-    """License Serializer"""
-
-    def validate_key(self, key: str) -> str:
-        """Validate the license key (install_id and signature)"""
-        LicenseKey.validate(key)
-        return key
-
-    class Meta:
-        model = License
-        fields = [
-            "license_uuid",
-            "name",
-            "key",
-            "expiry",
-            "users",
-            "external_users",
-        ]
-        extra_kwargs = {
-            "name": {"read_only": True},
-            "expiry": {"read_only": True},
-            "users": {"read_only": True},
-            "external_users": {"read_only": True},
-        }
-
-
-class LicenseSummary(PassiveSerializer):
-    """Serializer for license status"""
-
-    users = IntegerField(required=True)
-    external_users = IntegerField(required=True)
-    valid = BooleanField()
-    show_admin_warning = BooleanField()
-    show_user_warning = BooleanField()
-    read_only = BooleanField()
-    latest_valid = DateTimeField()
-    has_license = BooleanField()
-
-
-class LicenseForecastSerializer(PassiveSerializer):
-    """Serializer for license forecast"""
-
-    users = IntegerField(required=True)
-    external_users = IntegerField(required=True)
-    forecasted_users = IntegerField(required=True)
-    forecasted_external_users = IntegerField(required=True)
-
-
-class LicenseViewSet(UsedByMixin, ModelViewSet):
-    """License Viewset"""
-
-    queryset = License.objects.all()
-    serializer_class = LicenseSerializer
-    search_fields = ["name"]
-    ordering = ["name"]
-    filterset_fields = ["name"]
-
-    @permission_required(None, ["authentik_enterprise.view_license"])
-    @extend_schema(
-        request=OpenApiTypes.NONE,
-        responses={
-            200: inline_serializer("InstallIDSerializer", {"install_id": CharField(required=True)}),
-        },
-    )
-    @action(detail=False, methods=["GET"], permission_classes=[IsAdminUser])
-    def get_install_id(self, request: Request) -> Response:
-        """Get install_id"""
-        return Response(
-            data={
-                "install_id": get_install_id(),
-            }
-        )
-
-    @extend_schema(
-        request=OpenApiTypes.NONE,
-        responses={
-            200: LicenseSummary(),
-        },
-    )
-    @action(detail=False, methods=["GET"], permission_classes=[IsAuthenticated])
-    def summary(self, request: Request) -> Response:
-        """Get the total license status"""
-        total = LicenseKey.get_total()
-        last_valid = LicenseKey.last_valid_date()
-        # TODO: move this to a different place?
-        show_admin_warning = last_valid < now() - timedelta(weeks=2)
-        show_user_warning = last_valid < now() - timedelta(weeks=4)
-        read_only = last_valid < now() - timedelta(weeks=6)
-        latest_valid = datetime.fromtimestamp(total.exp)
-        response = LicenseSummary(
-            data={
-                "users": total.users,
-                "external_users": total.external_users,
-                "valid": total.is_valid(),
-                "show_admin_warning": show_admin_warning,
-                "show_user_warning": show_user_warning,
-                "read_only": read_only,
-                "latest_valid": latest_valid,
-                "has_license": License.objects.all().count() > 0,
-            }
-        )
-        response.is_valid(raise_exception=True)
-        return Response(response.data)
-
-    @permission_required(None, ["authentik_enterprise.view_license"])
-    @extend_schema(
-        request=OpenApiTypes.NONE,
-        responses={
-            200: LicenseForecastSerializer(),
-        },
-    )
-    @action(detail=False, methods=["GET"])
-    def forecast(self, request: Request) -> Response:
-        """Forecast how many users will be required in a year"""
-        last_month = now() - timedelta(days=30)
-        # Forecast for default users
-        users_in_last_month = User.objects.filter(
-            type=UserTypes.INTERNAL, date_joined__gte=last_month
-        ).count()
-        # Forecast for external users
-        external_in_last_month = LicenseKey.get_external_user_count()
-        forecast_for_months = 12
-        response = LicenseForecastSerializer(
-            data={
-                "users": LicenseKey.get_default_user_count(),
-                "external_users": LicenseKey.get_external_user_count(),
-                "forecasted_users": (users_in_last_month * forecast_for_months),
-                "forecasted_external_users": (external_in_last_month * forecast_for_months),
-            }
-        )
-        response.is_valid(raise_exception=True)
-        return Response(response.data)
@@ -9,7 +9,3 @@ class AuthentikEnterpriseConfig(ManagedAppConfig):
     label = "authentik_enterprise"
     verbose_name = "authentik Enterprise"
     default = True
-
-    def reconcile_load_enterprise_signals(self):
-        """Load enterprise signals"""
-        self.import_module("authentik.enterprise.signals")
@@ -1,52 +0,0 @@
-# Generated by Django 4.1.10 on 2023-07-06 12:51
-
-import uuid
-
-from django.db import migrations, models
-
-import authentik.enterprise.models
-
-
-class Migration(migrations.Migration):
-    initial = True
-
-    dependencies = []
-
-    operations = [
-        migrations.CreateModel(
-            name="License",
-            fields=[
-                (
-                    "license_uuid",
-                    models.UUIDField(
-                        default=uuid.uuid4, editable=False, primary_key=True, serialize=False
-                    ),
-                ),
-                ("key", models.TextField(unique=True)),
-                ("name", models.TextField()),
-                ("expiry", models.DateTimeField()),
-                ("users", models.BigIntegerField()),
-                ("external_users", models.BigIntegerField()),
-            ],
-        ),
-        migrations.CreateModel(
-            name="LicenseUsage",
-            fields=[
-                ("expiring", models.BooleanField(default=True)),
-                ("expires", models.DateTimeField(default=authentik.enterprise.models.usage_expiry)),
-                (
-                    "usage_uuid",
-                    models.UUIDField(
-                        default=uuid.uuid4, editable=False, primary_key=True, serialize=False
-                    ),
-                ),
-                ("user_count", models.BigIntegerField()),
-                ("external_user_count", models.BigIntegerField()),
-                ("within_limits", models.BooleanField()),
-                ("record_date", models.DateTimeField(auto_now_add=True)),
-            ],
-            options={
-                "abstract": False,
-            },
-        ),
-    ]
@@ -1,185 +0,0 @@
-"""Enterprise models"""
-from base64 import b64decode
-from binascii import Error
-from dataclasses import dataclass, field
-from datetime import datetime, timedelta
-from enum import Enum
-from functools import lru_cache
-from time import mktime
-from uuid import uuid4
-
-from cryptography.exceptions import InvalidSignature
-from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate
-from dacite import from_dict
-from django.db import models
-from django.db.models.query import QuerySet
-from django.utils.timezone import now
-from guardian.shortcuts import get_anonymous_user
-from jwt import PyJWTError, decode, get_unverified_header
-from rest_framework.exceptions import ValidationError
-
-from authentik.core.models import ExpiringModel, User, UserTypes
-from authentik.root.install_id import get_install_id
-
-
-@lru_cache()
-def get_licensing_key() -> Certificate:
-    """Get Root CA PEM"""
-    with open("authentik/enterprise/public.pem", "rb") as _key:
-        return load_pem_x509_certificate(_key.read())
-
-
-def get_license_aud() -> str:
-    """Get the JWT audience field"""
-    return f"enterprise.goauthentik.io/license/{get_install_id()}"
-
-
-class LicenseFlags(Enum):
-    """License flags"""
-
-
-@dataclass
-class LicenseKey:
-    """License JWT claims"""
-
-    aud: str
-    exp: int
-
-    name: str
-    users: int
-    external_users: int
-    flags: list[LicenseFlags] = field(default_factory=list)
-
-    @staticmethod
-    def validate(jwt: str) -> "LicenseKey":
-        """Validate the license from a given JWT"""
-        try:
-            headers = get_unverified_header(jwt)
-        except PyJWTError:
-            raise ValidationError("Unable to verify license")
-        x5c: list[str] = headers.get("x5c", [])
-        if len(x5c) < 1:
-            raise ValidationError("Unable to verify license")
-        try:
-            our_cert = load_der_x509_certificate(b64decode(x5c[0]))
-            intermediate = load_der_x509_certificate(b64decode(x5c[1]))
-            our_cert.verify_directly_issued_by(intermediate)
-            intermediate.verify_directly_issued_by(get_licensing_key())
-        except (InvalidSignature, TypeError, ValueError, Error):
-            raise ValidationError("Unable to verify license")
-        try:
-            body = from_dict(
-                LicenseKey,
-                decode(
-                    jwt,
-                    our_cert.public_key(),
-                    algorithms=["ES512"],
-                    audience=get_license_aud(),
-                ),
-            )
-        except PyJWTError:
-            raise ValidationError("Unable to verify license")
-        return body
-
-    @staticmethod
-    def get_total() -> "LicenseKey":
-        """Get a summarized version of all (not expired) licenses"""
-        active_licenses = License.objects.filter(expiry__gte=now())
-        total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
-        for lic in active_licenses:
-            total.users += lic.users
-            total.external_users += lic.external_users
-            exp_ts = int(mktime(lic.expiry.timetuple()))
-            if total.exp == 0:
-                total.exp = exp_ts
-            if exp_ts <= total.exp:
-                total.exp = exp_ts
-            total.flags.extend(lic.status.flags)
-        return total
-
-    @staticmethod
-    def base_user_qs() -> QuerySet:
-        """Base query set for all users"""
-        return User.objects.all().exclude(pk=get_anonymous_user().pk)
-
-    @staticmethod
-    def get_default_user_count():
-        """Get current default user count"""
-        return LicenseKey.base_user_qs().filter(type=UserTypes.INTERNAL).count()
-
-    @staticmethod
-    def get_external_user_count():
-        """Get current external user count"""
-        # Count since start of the month
-        last_month = now().replace(day=1)
-        return (
-            LicenseKey.base_user_qs()
-            .filter(type=UserTypes.EXTERNAL, last_login__gte=last_month)
-            .count()
-        )
-
-    def is_valid(self) -> bool:
-        """Check if the given license body covers all users
-
-        Only checks the current count, no historical data is checked"""
-        default_users = self.get_default_user_count()
-        if default_users > self.users:
-            return False
-        active_users = self.get_external_user_count()
-        if active_users > self.external_users:
-            return False
-        return True
-
-    def record_usage(self):
-        """Capture the current validity status and metrics and save them"""
-        LicenseUsage.objects.create(
-            user_count=self.get_default_user_count(),
-            external_user_count=self.get_external_user_count(),
-            within_limits=self.is_valid(),
-        )
-
-    @staticmethod
-    def last_valid_date() -> datetime:
-        """Get the last date the license was valid"""
-        usage: LicenseUsage = (
-            LicenseUsage.filter_not_expired(within_limits=True).order_by("-record_date").first()
-        )
-        if not usage:
-            return now()
-        return usage.record_date
-
-
-class License(models.Model):
-    """An authentik enterprise license"""
-
-    license_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
-    key = models.TextField(unique=True)
-
-    name = models.TextField()
-    expiry = models.DateTimeField()
-    users = models.BigIntegerField()
-    external_users = models.BigIntegerField()
-
-    @property
-    def status(self) -> LicenseKey:
-        """Get parsed license status"""
-        return LicenseKey.validate(self.key)
-
-
-def usage_expiry():
-    """Keep license usage records for 3 months"""
-    return now() + timedelta(days=30 * 3)
-
-
-class LicenseUsage(ExpiringModel):
-    """a single license usage record"""
-
-    expires = models.DateTimeField(default=usage_expiry)
-
-    usage_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
-
-    user_count = models.BigIntegerField()
-    external_user_count = models.BigIntegerField()
-    within_limits = models.BooleanField()
-
-    record_date = models.DateTimeField(auto_now_add=True)
@@ -1,46 +0,0 @@
-"""Enterprise license policies"""
-from typing import Optional
-
-from rest_framework.serializers import BaseSerializer
-
-from authentik.core.models import User, UserTypes
-from authentik.enterprise.models import LicenseKey
-from authentik.policies.models import Policy
-from authentik.policies.types import PolicyRequest, PolicyResult
-from authentik.policies.views import PolicyAccessView
-
-
-class EnterprisePolicy(Policy):
-    """Check that a user is correctly licensed for the request"""
-
-    @property
-    def component(self) -> str:
-        return ""
-
-    @property
-    def serializer(self) -> type[BaseSerializer]:
-        raise NotImplementedError
-
-    def passes(self, request: PolicyRequest) -> PolicyResult:
-        if not LicenseKey.get_total().is_valid():
-            return PolicyResult(False)
-        if request.user.type != UserTypes.INTERNAL:
-            return PolicyResult(False)
-        return PolicyResult(True)
-
-
-class EnterprisePolicyAccessView(PolicyAccessView):
-    """PolicyAccessView which also checks enterprise licensing"""
-
-    def user_has_access(self, user: Optional[User] = None) -> PolicyResult:
-        user = user or self.request.user
-        request = PolicyRequest(user)
-        request.http_request = self.request
-        result = super().user_has_access(user)
-        enterprise_result = EnterprisePolicy().passes(request)
-        if not enterprise_result.passing:
-            return enterprise_result
-        return result
-
-    def resolve_provider_application(self):
-        raise NotImplementedError
@@ -1,26 +0,0 @@
------BEGIN CERTIFICATE-----
-MIIEdzCCA/6gAwIBAgIUQrj1jxn4q/BB38B2SwTrvGyrZLMwCgYIKoZIzj0EAwMw
-ge8xCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1T
-YW4gRnJhbmNpc2NvMSQwIgYDVQQJExs1NDggTWFya2V0IFN0cmVldCBQbWIgNzAx
-NDgxDjAMBgNVBBETBTk0MTA0MSAwHgYDVQQKExdBdXRoZW50aWsgU2VjdXJpdHkg
-SW5jLjEcMBoGA1UECxMTRW50ZXJwcmlzZSBMaWNlbnNlczE9MDsGA1UEAxM0QXV0
-aGVudGlrIFNlY3VyaXR5IEluYy4gRW50ZXJwcmlzZSBMaWNlbnNpbmcgUm9vdCBY
-MTAgFw0yMzA3MDQxNzQ3NDBaGA8yMTIzMDYxMDE3NDgxMFowge8xCzAJBgNVBAYT
-AlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1TYW4gRnJhbmNpc2Nv
-MSQwIgYDVQQJExs1NDggTWFya2V0IFN0cmVldCBQbWIgNzAxNDgxDjAMBgNVBBET
-BTk0MTA0MSAwHgYDVQQKExdBdXRoZW50aWsgU2VjdXJpdHkgSW5jLjEcMBoGA1UE
-CxMTRW50ZXJwcmlzZSBMaWNlbnNlczE9MDsGA1UEAxM0QXV0aGVudGlrIFNlY3Vy
-aXR5IEluYy4gRW50ZXJwcmlzZSBMaWNlbnNpbmcgUm9vdCBYMTB2MBAGByqGSM49
-AgEGBSuBBAAiA2IABNbPJH6nDbSshpDsDHBRL0UcZVXWCK30txqcMKU+YFmLB6iR
-PJiHjHA8Z+5aP4eNH6onA5xqykQf65tvbFBA1LB/6HqMArU/tYVVQx4+o9hRBxF5
-RrzXucUg2br+RX8aa6OCAVUwggFRMA4GA1UdDwEB/wQEAwIBBjAPBgNVHRMBAf8E
-BTADAQH/MB0GA1UdDgQWBBRHpR3/ptPgN0yHVfUjyJOEmsPZqTAfBgNVHSMEGDAW
-gBRHpR3/ptPgN0yHVfUjyJOEmsPZqTCBoAYIKwYBBQUHAQEEgZMwgZAwRwYIKwYB
-BQUHMAGGO2h0dHBzOi8vdmF1bHQuY3VzdG9tZXJzLmdvYXV0aGVudGlrLmlvL3Yx
-L2xpY2Vuc2luZy1jYS9vY3NwMEUGCCsGAQUFBzAChjlodHRwczovL3ZhdWx0LmN1
-c3RvbWVycy5nb2F1dGhlbnRpay5pby92MS9saWNlbnNpbmctY2EvY2EwSwYDVR0f
-BEQwQjBAoD6gPIY6aHR0cHM6Ly92YXVsdC5jdXN0b21lcnMuZ29hdXRoZW50aWsu
-aW8vdjEvbGljZW5zaW5nLWNhL2NybDAKBggqhkjOPQQDAwNnADBkAjB0+YA1yjEO
-g43CCYUJXz9m9CNIkjOPUI0jO4UtvSj8j067TKRbX6IL/29HxPtQoYACME8eZHBJ
-Ljcog0oeBgjr4wK8bobgknr5wrm70rrNNpbSAjDvTvXMQeAShGgsftEquQ==
------END CERTIFICATE-----
@@ -1,12 +1 @@
 """Enterprise additional settings"""
-from celery.schedules import crontab
-
-from authentik.lib.utils.time import fqdn_rand
-
-CELERY_BEAT_SCHEDULE = {
-    "enterprise_calculate_license": {
-        "task": "authentik.enterprise.tasks.calculate_license",
-        "schedule": crontab(minute=fqdn_rand("calculate_license"), hour="*/8"),
-        "options": {"queue": "authentik_scheduled"},
-    }
-}
@@ -1,18 +0,0 @@
-"""Enterprise signals"""
-from datetime import datetime
-
-from django.db.models.signals import pre_save
-from django.dispatch import receiver
-from django.utils.timezone import get_current_timezone
-
-from authentik.enterprise.models import License
-
-
-@receiver(pre_save, sender=License)
-def pre_save_license(sender: type[License], instance: License, **_):
-    """Extract data from license jwt and save it into model"""
-    status = instance.status
-    instance.name = status.name
-    instance.users = status.users
-    instance.external_users = status.external_users
-    instance.expiry = datetime.fromtimestamp(status.exp, tz=get_current_timezone())
@@ -1,10 +0,0 @@
-"""Enterprise tasks"""
-from authentik.enterprise.models import LicenseKey
-from authentik.root.celery import CELERY_APP
-
-
-@CELERY_APP.task()
-def calculate_license():
-    """Calculate licensing status"""
-    total = LicenseKey.get_total()
-    total.record_usage()
@@ -1,64 +0,0 @@
-"""Enterprise license tests"""
-from datetime import timedelta
-from time import mktime
-from unittest.mock import MagicMock, patch
-
-from django.test import TestCase
-from django.utils.timezone import now
-from rest_framework.exceptions import ValidationError
-
-from authentik.enterprise.models import License, LicenseKey
-from authentik.lib.generators import generate_id
-
-_exp = int(mktime((now() + timedelta(days=3000)).timetuple()))
-
-
-class TestEnterpriseLicense(TestCase):
-    """Enterprise license tests"""
-
-    @patch(
-        "authentik.enterprise.models.LicenseKey.validate",
-        MagicMock(
-            return_value=LicenseKey(
-                aud="",
-                exp=_exp,
-                name=generate_id(),
-                users=100,
-                external_users=100,
-            )
-        ),
-    )
-    def test_valid(self):
-        """Check license verification"""
-        lic = License.objects.create(key=generate_id())
-        self.assertTrue(lic.status.is_valid())
-        self.assertEqual(lic.users, 100)
-
-    def test_invalid(self):
-        """Test invalid license"""
-        with self.assertRaises(ValidationError):
-            License.objects.create(key=generate_id())
-
-    @patch(
-        "authentik.enterprise.models.LicenseKey.validate",
-        MagicMock(
-            return_value=LicenseKey(
-                aud="",
-                exp=_exp,
-                name=generate_id(),
-                users=100,
-                external_users=100,
-            )
-        ),
-    )
-    def test_valid_multiple(self):
-        """Check license verification"""
-        lic = License.objects.create(key=generate_id())
-        self.assertTrue(lic.status.is_valid())
-        lic2 = License.objects.create(key=generate_id())
-        self.assertTrue(lic2.status.is_valid())
-        total = LicenseKey.get_total()
-        self.assertEqual(total.users, 200)
-        self.assertEqual(total.external_users, 200)
-        self.assertEqual(total.exp, _exp)
-        self.assertTrue(total.is_valid())
@@ -1,7 +0,0 @@
-"""API URLs"""
-
-from authentik.enterprise.api import LicenseViewSet
-
-api_urlpatterns = [
-    ("enterprise/license", LicenseViewSet),
-]
@@ -33,7 +33,7 @@ class GeoIPReader:

     def __open(self):
         """Get GeoIP Reader, if configured, otherwise none"""
-        path = CONFIG.get("geoip")
+        path = CONFIG.y("geoip")
         if path == "" or not path:
             return
         try:
@@ -46,7 +46,7 @@ class GeoIPReader:
     def __check_expired(self):
         """Check if the modification date of the GeoIP database has
         changed, and reload it if so"""
-        path = CONFIG.get("geoip")
+        path = CONFIG.y("geoip")
         try:
             mtime = stat(path).st_mtime
             diff = self.__last_mtime < mtime
@@ -76,20 +76,9 @@ class TaskInfo:
             return cache.get_many(cache.keys(CACHE_KEY_PREFIX + name)).values()
         return cache.get(CACHE_KEY_PREFIX + name, None)

-    @property
-    def full_name(self) -> str:
-        """Get the full cache key with task name and UID"""
-        key = CACHE_KEY_PREFIX + self.task_name
-        if self.result.uid:
-            uid_suffix = f":{self.result.uid}"
-            key += uid_suffix
-            if not self.task_name.endswith(uid_suffix):
-                self.task_name += uid_suffix
-        return key
-
     def delete(self):
         """Delete task info from cache"""
-        return cache.delete(self.full_name)
+        return cache.delete(CACHE_KEY_PREFIX + self.task_name)

     def update_metrics(self):
         """Update prometheus metrics"""
@@ -103,13 +92,17 @@ class TaskInfo:
         GAUGE_TASKS.labels(
             task_name=self.task_name.split(":")[0],
             task_uid=self.result.uid or "",
-            status=self.result.status.name.lower(),
+            status=self.result.status.value,
         ).set(duration)

     def save(self, timeout_hours=6):
         """Save task into cache"""
+        key = CACHE_KEY_PREFIX + self.task_name
+        if self.result.uid:
+            key += f":{self.result.uid}"
+            self.task_name += f":{self.result.uid}"
         self.update_metrics()
-        cache.set(self.full_name, self, timeout=timeout_hours * 60 * 60)
+        cache.set(key, self, timeout=timeout_hours * 60 * 60)


 class MonitoredTask(Task):
@@ -1,43 +0,0 @@
-"""Test Monitored tasks"""
-from django.test import TestCase
-
-from authentik.events.monitored_tasks import MonitoredTask, TaskInfo, TaskResult, TaskResultStatus
-from authentik.lib.generators import generate_id
-from authentik.root.celery import CELERY_APP
-
-
-class TestMonitoredTasks(TestCase):
-    """Test Monitored tasks"""
-
-    def test_failed_successful_remove_state(self):
-        """Test that a task with `save_on_success` set to `False` that failed saves
-        a state, and upon successful completion will delete the state"""
-        should_fail = True
-        uid = generate_id()
-
-        @CELERY_APP.task(
-            bind=True,
-            base=MonitoredTask,
-        )
-        def test_task(self: MonitoredTask):
-            self.save_on_success = False
-            self.set_uid(uid)
-            self.set_status(
-                TaskResult(TaskResultStatus.ERROR if should_fail else TaskResultStatus.SUCCESSFUL)
-            )
-
-        # First test successful run
-        should_fail = False
-        test_task.delay().get()
-        self.assertIsNone(TaskInfo.by_name(f"test_task:{uid}"))
-
-        # Then test failed
-        should_fail = True
-        test_task.delay().get()
-        info = TaskInfo.by_name(f"test_task:{uid}")
-        self.assertEqual(info.result.status, TaskResultStatus.ERROR)
-
-        # Then after that, the state should be removed
-        should_fail = False
-        test_task.delay().get()
-        self.assertIsNone(TaskInfo.by_name(f"test_task:{uid}"))
@@ -33,7 +33,7 @@ PLAN_CONTEXT_SOURCE = "source"
 # Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan
 # was restored.
 PLAN_CONTEXT_IS_RESTORED = "is_restored"
-CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_flows"))
+CACHE_TIMEOUT = int(CONFIG.y("redis.cache_timeout_flows"))
 CACHE_PREFIX = "goauthentik.io/flows/planner/"


@@ -18,6 +18,7 @@ from authentik.flows.planner import FlowPlan, FlowPlanner
 from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, StageView
 from authentik.flows.tests import FlowTestCase
 from authentik.flows.views.executor import NEXT_ARG_NAME, SESSION_KEY_PLAN, FlowExecutorView
+from authentik.lib.config import CONFIG
 from authentik.lib.generators import generate_id
 from authentik.policies.dummy.models import DummyPolicy
 from authentik.policies.models import PolicyBinding
@@ -84,6 +85,7 @@ class TestFlowExecutor(FlowTestCase):
             FlowDesignation.AUTHENTICATION,
         )

+        CONFIG.update_from_dict({"domain": "testserver"})
         response = self.client.get(
             reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
         )
@@ -109,6 +111,7 @@ class TestFlowExecutor(FlowTestCase):
             denied_action=FlowDeniedAction.CONTINUE,
         )

+        CONFIG.update_from_dict({"domain": "testserver"})
         response = self.client.get(
             reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
         )
@@ -125,6 +128,7 @@ class TestFlowExecutor(FlowTestCase):
             FlowDesignation.AUTHENTICATION,
         )

+        CONFIG.update_from_dict({"domain": "testserver"})
         dest = "/unique-string"
         url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
         response = self.client.get(url + f"?{NEXT_ARG_NAME}={dest}")
@@ -141,6 +145,7 @@ class TestFlowExecutor(FlowTestCase):
             FlowDesignation.AUTHENTICATION,
         )

+        CONFIG.update_from_dict({"domain": "testserver"})
         response = self.client.get(
             reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
         )
authentik/lib/apps.py (10 additions, new file)
@ -0,0 +1,10 @@
|
||||
"""authentik lib app config"""
|
||||
from django.apps import AppConfig
|
||||
|
||||
|
||||
class AuthentikLibConfig(AppConfig):
|
||||
"""authentik lib app config"""
|
||||
|
||||
name = "authentik.lib"
|
||||
label = "authentik_lib"
|
||||
verbose_name = "authentik lib"
|
@ -175,7 +175,7 @@ def get_avatar(user: "User") -> str:
        "initials": avatar_mode_generated,
        "gravatar": avatar_mode_gravatar,
    }
    modes: str = CONFIG.get("avatars", "none")
    modes: str = CONFIG.y("avatars", "none")
    for mode in modes.split(","):
        avatar = None
        if mode in mode_map:
@ -2,15 +2,13 @@
import os
from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from glob import glob
from json import JSONEncoder, dumps, loads
from json import dumps, loads
from json.decoder import JSONDecodeError
from pathlib import Path
from sys import argv, stderr
from time import time
from typing import Any, Optional
from typing import Any
from urllib.parse import urlparse

import yaml
@ -34,44 +32,15 @@ def get_path_from_dict(root: dict, path: str, sep=".", default=None) -> Any:
    return root


@dataclass
class Attr:
    """Single configuration attribute"""

    class Source(Enum):
        """Sources a configuration attribute can come from, determines what should be done with
        Attr.source (and if it's set at all)"""

        UNSPECIFIED = "unspecified"
        ENV = "env"
        CONFIG_FILE = "config_file"
        URI = "uri"

    value: Any

    source_type: Source = field(default=Source.UNSPECIFIED)

    # depending on source_type, might contain the environment variable or the path
    # to the config file containing this change or the file containing this value
    source: Optional[str] = field(default=None)


class AttrEncoder(JSONEncoder):
    """JSON encoder that can deal with `Attr` classes"""

    def default(self, o: Any) -> Any:
        if isinstance(o, Attr):
            return o.value
        return super().default(o)


class ConfigLoader:
    """Search through SEARCH_PATHS and load configuration. Environment variables starting with
    `ENV_PREFIX` are also applied.

    A variable like AUTHENTIK_POSTGRESQL__HOST would translate to postgresql.host"""

    def __init__(self, **kwargs):
    loaded_file = []

    def __init__(self):
        super().__init__()
        self.__config = {}
        base_dir = Path(__file__).parent.joinpath(Path("../..")).resolve()
@ -96,7 +65,6 @@ class ConfigLoader:
        # Update config with env file
        self.update_from_file(env_file)
        self.update_from_env()
        self.update(self.__config, kwargs)

    def log(self, level: str, message: str, **kwargs):
        """Custom Log method, we want to ensure ConfigLoader always logs JSON even when
@ -118,34 +86,22 @@ class ConfigLoader:
            else:
                if isinstance(value, str):
                    value = self.parse_uri(value)
                elif isinstance(value, Attr) and isinstance(value.value, str):
                    value = self.parse_uri(value.value)
                elif not isinstance(value, Attr):
                    value = Attr(value)
                root[key] = value
        return root

    def refresh(self, key: str):
        """Update a single value"""
        attr: Attr = get_path_from_dict(self.raw, key)
        if attr.source_type != Attr.Source.URI:
            return
        attr.value = self.parse_uri(attr.source).value

    def parse_uri(self, value: str) -> Attr:
    def parse_uri(self, value: str) -> str:
        """Parse string values which start with a URI"""
        url = urlparse(value)
        parsed_value = value
        if url.scheme == "env":
            parsed_value = os.getenv(url.netloc, url.query)
            value = os.getenv(url.netloc, url.query)
        if url.scheme == "file":
            try:
                with open(url.path, "r", encoding="utf8") as _file:
                    parsed_value = _file.read().strip()
                    value = _file.read().strip()
            except OSError as exc:
                self.log("error", f"Failed to read config value from {url.path}: {exc}")
                parsed_value = url.query
        return Attr(parsed_value, Attr.Source.URI, value)
                value = url.query
        return value

    def update_from_file(self, path: Path):
        """Update config from file contents"""
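
An aside, not part of the diff: the URI schemes handled by parse_uri() let any configuration value defer to an environment variable (env://) or a file on disk (file://), with the query string as fallback. A minimal usage sketch; the variable name and fallback are examples only, and the Attr wrapper only exists on the branch that keeps it:

# Illustrative only: how URI-style values resolve.
#   env://EDITOR?vim                     -> $EDITOR, or "vim" if unset
#   file:///run/secrets/db_pass?default  -> file contents, or "default" on read error
from authentik.lib.config import CONFIG

attr = CONFIG.parse_uri("env://EDITOR?vim")
print(attr.value)        # the resolved string
print(attr.source_type)  # Attr.Source.URI (on the other branch, parse_uri returns the plain string)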
@ -154,6 +110,7 @@ class ConfigLoader:
            try:
                self.update(self.__config, yaml.safe_load(file))
                self.log("debug", "Loaded config", file=str(path))
                self.loaded_file.append(path)
            except yaml.YAMLError as exc:
                raise ImproperlyConfigured from exc
            except PermissionError as exc:
@ -164,6 +121,10 @@ class ConfigLoader:
                error=str(exc),
            )

    def update_from_dict(self, update: dict):
        """Update config from dict"""
        self.__config.update(update)

    def update_from_env(self):
        """Check environment variables"""
        outer = {}
@ -184,7 +145,7 @@ class ConfigLoader:
                value = loads(value)
            except JSONDecodeError:
                pass
            current_obj[dot_parts[-1]] = Attr(value, Attr.Source.ENV, key)
            current_obj[dot_parts[-1]] = value
            idx += 1
        if idx > 0:
            self.log("debug", "Loaded environment variables", count=idx)
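
For orientation, not part of the diff: update_from_env() splits variable names on the double underscore and lowercases them, so environment variables map onto the nested config exactly as the class docstring describes. A sketch, with example names and values:

# Assumes the "AUTHENTIK" prefix used by this loader; names and values are examples.
#   AUTHENTIK_POSTGRESQL__HOST=db.example.com -> {"postgresql": {"host": "db.example.com"}}
# Values that parse as JSON are decoded first:
#   AUTHENTIK_FOO__BAR='["a", "b"]'           -> {"foo": {"bar": ["a", "b"]}}
import os

from authentik.lib.config import CONFIG

os.environ["AUTHENTIK_POSTGRESQL__HOST"] = "db.example.com"
CONFIG.update_from_env()
print(CONFIG.y("postgresql.host"))  # "db.example.com" (spelled get() on the other branch)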
@ -193,32 +154,28 @@ class ConfigLoader:
    @contextmanager
    def patch(self, path: str, value: Any):
        """Context manager for unittests to patch a value"""
        original_value = self.get(path)
        self.set(path, value)
        original_value = self.y(path)
        self.y_set(path, value)
        try:
            yield
        finally:
            self.set(path, original_value)
            self.y_set(path, original_value)

    @property
    def raw(self) -> dict:
        """Get raw config dictionary"""
        return self.__config

    def get(self, path: str, default=None, sep=".") -> Any:
    # pylint: disable=invalid-name
    def y(self, path: str, default=None, sep=".") -> Any:
        """Access attribute by using yaml path"""
        # Walk sub_dicts before parsing path
        root = self.raw
        # Walk each component of the path
        attr: Attr = get_path_from_dict(root, path, sep=sep, default=Attr(default))
        return attr.value
        return get_path_from_dict(root, path, sep=sep, default=default)

    def get_bool(self, path: str, default=False) -> bool:
        """Wrapper for get that converts value into boolean"""
        return str(self.get(path, default)).lower() == "true"

    def set(self, path: str, value: Any, sep="."):
        """Set value using same syntax as get()"""
    def y_set(self, path: str, value: Any, sep="."):
        """Set value using same syntax as y()"""
        # Walk sub_dicts before parsing path
        root = self.raw
        # Walk each component of the path
@ -227,14 +184,17 @@ class ConfigLoader:
            if comp not in root:
                root[comp] = {}
            root = root.get(comp, {})
        root[path_parts[-1]] = Attr(value)
        root[path_parts[-1]] = value

    def y_bool(self, path: str, default=False) -> bool:
        """Wrapper for y that converts value into boolean"""
        return str(self.y(path, default)).lower() == "true"


CONFIG = ConfigLoader()


if __name__ == "__main__":
    if len(argv) < 2:
        print(dumps(CONFIG.raw, indent=4, cls=AttrEncoder))
        print(dumps(CONFIG.raw, indent=4))
    else:
        print(CONFIG.get(argv[1]))
        print(CONFIG.y(argv[1]))
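
A quick usage sketch, not part of the diff, of the dot-path accessors this file revolves around; this branch spells them y()/y_set()/y_bool(), the other spells them get()/set()/get_bool():

from authentik.lib.config import ConfigLoader

config = ConfigLoader()
config.y_set("outposts.discover", "true")  # walks/creates {"outposts": {"discover": ...}}
print(config.y("outposts.discover"))       # "true"
print(config.y_bool("outposts.discover"))  # True, via the string comparison above
print(config.y("outposts.missing", 42))    # 42; the default is passed through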
@ -51,18 +51,18 @@ class SentryTransport(HttpTransport):

def sentry_init(**sentry_init_kwargs):
    """Configure sentry SDK"""
    sentry_env = CONFIG.get("error_reporting.environment", "customer")
    sentry_env = CONFIG.y("error_reporting.environment", "customer")
    kwargs = {
        "environment": sentry_env,
        "send_default_pii": CONFIG.get_bool("error_reporting.send_pii", False),
        "send_default_pii": CONFIG.y_bool("error_reporting.send_pii", False),
        "_experiments": {
            "profiles_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.1)),
            "profiles_sample_rate": float(CONFIG.y("error_reporting.sample_rate", 0.1)),
        },
    }
    kwargs.update(**sentry_init_kwargs)
    # pylint: disable=abstract-class-instantiated
    sentry_sdk_init(
        dsn=CONFIG.get("error_reporting.sentry_dsn"),
        dsn=CONFIG.y("error_reporting.sentry_dsn"),
        integrations=[
            ArgvIntegration(),
            StdlibIntegration(),
@ -92,7 +92,7 @@ def traces_sampler(sampling_context: dict) -> float:
        return 0
    if _type == "websocket":
        return 0
    return float(CONFIG.get("error_reporting.sample_rate", 0.1))
    return float(CONFIG.y("error_reporting.sample_rate", 0.1))


def before_send(event: dict, hint: dict) -> Optional[dict]:
@ -16,23 +16,23 @@ class TestConfig(TestCase):
        config = ConfigLoader()
        environ[ENV_PREFIX + "_test__test"] = "bar"
        config.update_from_env()
        self.assertEqual(config.get("test.test"), "bar")
        self.assertEqual(config.y("test.test"), "bar")

    def test_patch(self):
        """Test patch decorator"""
        config = ConfigLoader()
        config.set("foo.bar", "bar")
        self.assertEqual(config.get("foo.bar"), "bar")
        config.y_set("foo.bar", "bar")
        self.assertEqual(config.y("foo.bar"), "bar")
        with config.patch("foo.bar", "baz"):
            self.assertEqual(config.get("foo.bar"), "baz")
        self.assertEqual(config.get("foo.bar"), "bar")
            self.assertEqual(config.y("foo.bar"), "baz")
        self.assertEqual(config.y("foo.bar"), "bar")

    def test_uri_env(self):
        """Test URI parsing (environment)"""
        config = ConfigLoader()
        environ["foo"] = "bar"
        self.assertEqual(config.parse_uri("env://foo").value, "bar")
        self.assertEqual(config.parse_uri("env://foo?bar").value, "bar")
        self.assertEqual(config.parse_uri("env://foo"), "bar")
        self.assertEqual(config.parse_uri("env://foo?bar"), "bar")

    def test_uri_file(self):
        """Test URI parsing (file load)"""
@ -41,31 +41,11 @@ class TestConfig(TestCase):
        write(file, "foo".encode())
        _, file2_name = mkstemp()
        chmod(file2_name, 0o000)  # Remove all permissions so we can't read the file
        self.assertEqual(config.parse_uri(f"file://{file_name}").value, "foo")
        self.assertEqual(config.parse_uri(f"file://{file2_name}?def").value, "def")
        self.assertEqual(config.parse_uri(f"file://{file_name}"), "foo")
        self.assertEqual(config.parse_uri(f"file://{file2_name}?def"), "def")
        unlink(file_name)
        unlink(file2_name)

    def test_uri_file_update(self):
        """Test URI parsing (file load and update)"""
        file, file_name = mkstemp()
        write(file, "foo".encode())
        config = ConfigLoader(file_test=f"file://{file_name}")
        self.assertEqual(config.get("file_test"), "foo")

        # Update config file
        write(file, "bar".encode())
        config.refresh("file_test")
        self.assertEqual(config.get("file_test"), "foobar")

        unlink(file_name)

    def test_uri_env_full(self):
        """Test URI set as env variable"""
        environ["AUTHENTIK_TEST_VAR"] = "file:///foo?bar"
        config = ConfigLoader()
        self.assertEqual(config.get("test_var"), "bar")

    def test_file_update(self):
        """Test update_from_file"""
        config = ConfigLoader()
@ -1,7 +1,7 @@
"""Test HTTP Helpers"""
from django.test import RequestFactory, TestCase

from authentik.core.models import Token, TokenIntents, UserTypes
from authentik.core.models import USER_ATTRIBUTE_CAN_OVERRIDE_IP, Token, TokenIntents
from authentik.core.tests.utils import create_test_admin_user
from authentik.lib.utils.http import OUTPOST_REMOTE_IP_HEADER, OUTPOST_TOKEN_HEADER, get_client_ip
from authentik.lib.views import bad_request_message
@ -53,7 +53,7 @@ class TestHTTP(TestCase):
        )
        self.assertEqual(get_client_ip(request), "127.0.0.1")
        # Valid
        self.user.type = UserTypes.INTERNAL_SERVICE_ACCOUNT
        self.user.attributes[USER_ATTRIBUTE_CAN_OVERRIDE_IP] = True
        self.user.save()
        request = self.factory.get(
            "/",
@ -33,8 +33,9 @@ def _get_client_ip_from_meta(meta: dict[str, Any]) -> str:

def _get_outpost_override_ip(request: HttpRequest) -> Optional[str]:
    """Get the actual remote IP when set by an outpost. Only
    allowed when the request is authenticated, by an outpost internal service account"""
    from authentik.core.models import Token, TokenIntents, UserTypes
    allowed when the request is authenticated, by a user with USER_ATTRIBUTE_CAN_OVERRIDE_IP set
    to outpost"""
    from authentik.core.models import USER_ATTRIBUTE_CAN_OVERRIDE_IP, Token, TokenIntents

    if OUTPOST_REMOTE_IP_HEADER not in request.META or OUTPOST_TOKEN_HEADER not in request.META:
        return None
@ -50,7 +51,7 @@ def _get_outpost_override_ip(request: HttpRequest) -> Optional[str]:
        LOGGER.warning("Attempted remote-ip override without token", fake_ip=fake_ip)
        return None
    user = token.user
    if user.type != UserTypes.INTERNAL_SERVICE_ACCOUNT:
    if not user.group_attributes(request).get(USER_ATTRIBUTE_CAN_OVERRIDE_IP, False):
        LOGGER.warning(
            "Remote-IP override: user doesn't have permission",
            user=user,
@ -50,7 +50,7 @@ def get_env() -> str:
    """Get environment in which authentik is currently running"""
    if "CI" in os.environ:
        return "ci"
    if CONFIG.get_bool("debug"):
    if CONFIG.y_bool("debug"):
        return "dev"
    if SERVICE_HOST_ENV_NAME in os.environ:
        return "kubernetes"
@ -97,7 +97,7 @@ class BaseController:
        if self.outpost.config.container_image is not None:
            return self.outpost.config.container_image

        image_name_template: str = CONFIG.get("outposts.container_image_base")
        image_name_template: str = CONFIG.y("outposts.container_image_base")
        return image_name_template % {
            "type": self.outpost.type,
            "version": __version__,
@ -1,22 +1,16 @@
"""Base Kubernetes Reconciler"""
from dataclasses import asdict
from json import dumps
from typing import TYPE_CHECKING, Generic, Optional, TypeVar

from dacite.core import from_dict
from django.utils.text import slugify
from jsonpatch import JsonPatchConflict, JsonPatchException, JsonPatchTestFailed, apply_patch
from kubernetes.client import ApiClient, V1ObjectMeta
from kubernetes.client import V1ObjectMeta
from kubernetes.client.exceptions import ApiException, OpenApiException
from kubernetes.client.models.v1_deployment import V1Deployment
from kubernetes.client.models.v1_pod import V1Pod
from requests import Response
from structlog.stdlib import get_logger
from urllib3.exceptions import HTTPError

from authentik import __version__
from authentik.outposts.apps import MANAGED_OUTPOST
from authentik.outposts.controllers.base import ControllerException
from authentik.outposts.controllers.k8s.triggers import NeedsRecreate, NeedsUpdate

if TYPE_CHECKING:
@ -40,23 +34,11 @@ class KubernetesObjectReconciler(Generic[T]):
        self.namespace = controller.outpost.config.kubernetes_namespace
        self.logger = get_logger().bind(type=self.__class__.__name__)

    def get_patch(self):
        """Get any patches that apply to this CRD"""
        patches = self.controller.outpost.config.kubernetes_json_patches
        if not patches:
            return None
        return patches.get(self.reconciler_name(), None)

    @property
    def is_embedded(self) -> bool:
        """Return true if the current outpost is embedded"""
        return self.controller.outpost.managed == MANAGED_OUTPOST

    @staticmethod
    def reconciler_name() -> str:
        """A name this reconciler is identified by in the configuration"""
        raise NotImplementedError

    @property
    def noop(self) -> bool:
        """Return true if this object should not be created/updated/deleted in this cluster"""
@ -73,32 +55,6 @@ class KubernetesObjectReconciler(Generic[T]):
            }
        ).lower()

    def get_patched_reference_object(self) -> T:
        """Get patched reference object"""
        reference = self.get_reference_object()
        patch = self.get_patch()
        try:
            json = ApiClient().sanitize_for_serialization(reference)
        # Custom objects will not be known to the clients openapi types
        except AttributeError:
            json = asdict(reference)
        try:
            ref = json
            if patch is not None:
                ref = apply_patch(json, patch)
        except (JsonPatchException, JsonPatchConflict, JsonPatchTestFailed) as exc:
            raise ControllerException(f"JSON Patch failed: {exc}") from exc
        mock_response = Response()
        mock_response.data = dumps(ref)

        try:
            result = ApiClient().deserialize(mock_response, reference.__class__.__name__)
        # Custom objects will not be known to the clients openapi types
        except AttributeError:
            result = from_dict(reference.__class__, data=ref)

        return result

    # pylint: disable=invalid-name
    def up(self):
        """Create object if it doesn't exist, update if needed or recreate if needed."""
@ -106,7 +62,7 @@ class KubernetesObjectReconciler(Generic[T]):
        if self.noop:
            self.logger.debug("Object is noop")
            return
        reference = self.get_patched_reference_object()
        reference = self.get_reference_object()
        try:
            try:
                current = self.retrieve()
@ -173,16 +129,6 @@ class KubernetesObjectReconciler(Generic[T]):
        if current.metadata.labels != reference.metadata.labels:
            raise NeedsUpdate()

        patch = self.get_patch()
        if patch is not None:
            current_json = ApiClient().sanitize_for_serialization(current)

            try:
                if apply_patch(current_json, patch) != current_json:
                    raise NeedsUpdate()
            except (JsonPatchException, JsonPatchConflict, JsonPatchTestFailed) as exc:
                raise ControllerException(f"JSON Patch failed: {exc}") from exc

    def create(self, reference: T):
        """API Wrapper to create object"""
        raise NotImplementedError
@ -43,10 +43,6 @@ class DeploymentReconciler(KubernetesObjectReconciler[V1Deployment]):
        self.api = AppsV1Api(controller.client)
        self.outpost = self.controller.outpost

    @staticmethod
    def reconciler_name() -> str:
        return "deployment"

    def reconcile(self, current: V1Deployment, reference: V1Deployment):
        compare_ports(
            current.spec.template.spec.containers[0].ports,
@ -24,10 +24,6 @@ class SecretReconciler(KubernetesObjectReconciler[V1Secret]):
        super().__init__(controller)
        self.api = CoreV1Api(controller.client)

    @staticmethod
    def reconciler_name() -> str:
        return "secret"

    def reconcile(self, current: V1Secret, reference: V1Secret):
        super().reconcile(current, reference)
        for key in reference.data.keys():
@ -20,10 +20,6 @@ class ServiceReconciler(KubernetesObjectReconciler[V1Service]):
        super().__init__(controller)
        self.api = CoreV1Api(controller.client)

    @staticmethod
    def reconciler_name() -> str:
        return "service"

    def reconcile(self, current: V1Service, reference: V1Service):
        compare_ports(current.spec.ports, reference.spec.ports)
        # run the base reconcile last, as that will probably raise NeedsUpdate
@ -71,10 +71,6 @@ class PrometheusServiceMonitorReconciler(KubernetesObjectReconciler[PrometheusSe
        self.api_ex = ApiextensionsV1Api(controller.client)
        self.api = CustomObjectsApi(controller.client)

    @staticmethod
    def reconciler_name() -> str:
        return "prometheus servicemonitor"

    @property
    def noop(self) -> bool:
        return (not self._crd_exists()) or (self.is_embedded)
@ -64,19 +64,12 @@ class KubernetesController(BaseController):
        super().__init__(outpost, connection)
        self.client = KubernetesClient(connection)
        self.reconcilers = {
            SecretReconciler.reconciler_name(): SecretReconciler,
            DeploymentReconciler.reconciler_name(): DeploymentReconciler,
            ServiceReconciler.reconciler_name(): ServiceReconciler,
            PrometheusServiceMonitorReconciler.reconciler_name(): (
                PrometheusServiceMonitorReconciler
            ),
            "secret": SecretReconciler,
            "deployment": DeploymentReconciler,
            "service": ServiceReconciler,
            "prometheus servicemonitor": PrometheusServiceMonitorReconciler,
        }
        self.reconcile_order = [
            SecretReconciler.reconciler_name(),
            DeploymentReconciler.reconciler_name(),
            ServiceReconciler.reconciler_name(),
            PrometheusServiceMonitorReconciler.reconciler_name(),
        ]
        self.reconcile_order = ["secret", "deployment", "service", "prometheus servicemonitor"]

    def up(self):
        try:
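
Context, not part of the diff: the newer branch keys this registry through each reconciler's reconciler_name() instead of repeating string literals, so the name lives in one place and doubles as the lookup key for the per-reconciler JSON patches removed above. A self-contained sketch of the pattern:

# Sketch only: one authoritative name per reconciler, reused as registry key.
class SecretReconciler:
    @staticmethod
    def reconciler_name() -> str:
        return "secret"

reconcilers = {SecretReconciler.reconciler_name(): SecretReconciler}
reconcile_order = [SecretReconciler.reconciler_name()]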
@ -1,7 +1,7 @@
"""Outpost models"""
from dataclasses import asdict, dataclass, field
from datetime import datetime
from typing import Any, Iterable, Optional
from typing import Iterable, Optional
from uuid import uuid4

from dacite.core import from_dict
@ -20,12 +20,13 @@ from structlog.stdlib import get_logger
from authentik import __version__, get_build_hash
from authentik.blueprints.models import ManagedModel
from authentik.core.models import (
    USER_ATTRIBUTE_CAN_OVERRIDE_IP,
    USER_ATTRIBUTE_SA,
    USER_PATH_SYSTEM_PREFIX,
    Provider,
    Token,
    TokenIntents,
    User,
    UserTypes,
)
from authentik.crypto.models import CertificateKeyPair
from authentik.events.models import Event, EventAction
@ -58,7 +59,7 @@ class OutpostConfig:
    authentik_host_insecure: bool = False
    authentik_host_browser: str = ""

    log_level: str = CONFIG.get("log_level")
    log_level: str = CONFIG.y("log_level")
    object_naming_template: str = field(default="ak-outpost-%(name)s")

    container_image: Optional[str] = field(default=None)
@ -75,7 +76,6 @@ class OutpostConfig:
    kubernetes_service_type: str = field(default="ClusterIP")
    kubernetes_disabled_components: list[str] = field(default_factory=list)
    kubernetes_image_pull_secrets: list[str] = field(default_factory=list)
    kubernetes_json_patches: Optional[dict[str, list[dict[str, Any]]]] = field(default=None)


class OutpostModel(Model):
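
For reference, not part of the diff: the kubernetes_json_patches field removed here is keyed by reconciler name and holds RFC 6902 (JSON Patch) operations, which get_patch()/get_patched_reference_object() apply per object. A plausible value, with purely illustrative paths:

# Hypothetical outpost config for the removed field; keys match reconciler_name(),
# the patch operations and paths are examples only.
kubernetes_json_patches = {
    "deployment": [
        {"op": "add", "path": "/spec/template/spec/nodeSelector", "value": {"env": "prod"}},
    ],
    "service": [
        {"op": "replace", "path": "/spec/type", "value": "LoadBalancer"},
    ],
}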
@ -346,7 +346,8 @@ class Outpost(SerializerModel, ManagedModel):
            user: User = User.objects.create(username=self.user_identifier)
            user.set_unusable_password()
            user_created = True
        user.type = UserTypes.INTERNAL_SERVICE_ACCOUNT
        user.attributes[USER_ATTRIBUTE_SA] = True
        user.attributes[USER_ATTRIBUTE_CAN_OVERRIDE_IP] = True
        user.name = f"Outpost {self.name} Service-Account"
        user.path = USER_PATH_OUTPOSTS
        user.save()
@ -256,7 +256,7 @@ def _outpost_single_update(outpost: Outpost, layer=None):
def outpost_connection_discovery(self: MonitoredTask):
    """Checks the local environment and create Service connections."""
    status = TaskResult(TaskResultStatus.SUCCESSFUL)
    if not CONFIG.get_bool("outposts.discover"):
    if not CONFIG.y_bool("outposts.discover"):
        status.messages.append("Outpost integration discovery is disabled")
        self.set_status(status)
        return
@ -64,7 +64,7 @@ class PolicyEngine:
        self.use_cache = True
        self.__expected_result_count = 0

    def iterate_bindings(self) -> Iterator[PolicyBinding]:
    def _iter_bindings(self) -> Iterator[PolicyBinding]:
        """Make sure all Policies are their respective classes"""
        return (
            PolicyBinding.objects.filter(target=self.__pbm, enabled=True)
@ -88,7 +88,7 @@ class PolicyEngine:
            span: Span
            span.set_data("pbm", self.__pbm)
            span.set_data("request", self.request)
            for binding in self.iterate_bindings():
            for binding in self._iter_bindings():
                self.__expected_result_count += 1

                self._check_policy_type(binding)
@ -19,7 +19,7 @@ from authentik.policies.types import CACHE_PREFIX, PolicyRequest, PolicyResult
LOGGER = get_logger()

FORK_CTX = get_context("fork")
CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_policies"))
CACHE_TIMEOUT = int(CONFIG.y("redis.cache_timeout_policies"))
PROCESS_CLASS = FORK_CTX.Process
@ -13,7 +13,7 @@ from authentik.policies.reputation.tasks import save_reputation
from authentik.stages.identification.signals import identification_failed

LOGGER = get_logger()
CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_reputation"))
CACHE_TIMEOUT = int(CONFIG.y("redis.cache_timeout_reputation"))


def update_score(request: HttpRequest, identifier: str, amount: int):
@ -6,12 +6,11 @@ from django.urls import reverse
from jwt import decode

from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import Application, Group, Token, TokenIntents, UserTypes
from authentik.core.models import USER_ATTRIBUTE_SA, Application, Group, Token, TokenIntents
from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
from authentik.policies.models import PolicyBinding
from authentik.providers.oauth2.constants import (
    GRANT_TYPE_CLIENT_CREDENTIALS,
    GRANT_TYPE_PASSWORD,
    SCOPE_OPENID,
    SCOPE_OPENID_EMAIL,
    SCOPE_OPENID_PROFILE,
@ -38,7 +37,7 @@ class TestTokenClientCredentials(OAuthTestCase):
        self.provider.property_mappings.set(ScopeMapping.objects.all())
        self.app = Application.objects.create(name="test", slug="test", provider=self.provider)
        self.user = create_test_admin_user("sa")
        self.user.type = UserTypes.SERVICE_ACCOUNT
        self.user.attributes[USER_ATTRIBUTE_SA] = True
        self.user.save()
        self.token = Token.objects.create(
            identifier="sa-token",
@ -151,28 +150,3 @@ class TestTokenClientCredentials(OAuthTestCase):
        )
        self.assertEqual(jwt["given_name"], self.user.name)
        self.assertEqual(jwt["preferred_username"], self.user.username)

    def test_successful_password(self):
        """test successful (password grant)"""
        response = self.client.post(
            reverse("authentik_providers_oauth2:token"),
            {
                "grant_type": GRANT_TYPE_PASSWORD,
                "scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE}",
                "client_id": self.provider.client_id,
                "username": "sa",
                "password": self.token.key,
            },
        )
        self.assertEqual(response.status_code, 200)
        body = loads(response.content.decode())
        self.assertEqual(body["token_type"], TOKEN_TYPE)
        _, alg = self.provider.jwt_key
        jwt = decode(
            body["access_token"],
            key=self.provider.signing_key.public_key,
            algorithms=[alg],
            audience=self.provider.client_id,
        )
        self.assertEqual(jwt["given_name"], self.user.name)
        self.assertEqual(jwt["preferred_username"], self.user.username)
@ -46,7 +46,7 @@ class DeviceView(View):

    def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
        throttle = AnonRateThrottle()
        throttle.rate = CONFIG.get("throttle.providers.oauth2.device", "20/hour")
        throttle.rate = CONFIG.y("throttle.providers.oauth2.device", "20/hour")
        throttle.num_requests, throttle.duration = throttle.parse_rate(throttle.rate)
        if not throttle.allow_request(request, self):
            return HttpResponse(status=429)
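
As context, not part of the diff: DRF's SimpleRateThrottle.parse_rate turns the configured string into a request count and a window in seconds, so both spellings of the config lookup feed the same arithmetic. A sketch (assumes a configured Django/DRF environment, as in the view above):

from rest_framework.throttling import AnonRateThrottle

throttle = AnonRateThrottle()
throttle.rate = "20/hour"  # same shape as the configured default above
throttle.num_requests, throttle.duration = throttle.parse_rate(throttle.rate)
print(throttle.num_requests, throttle.duration)  # 20 3600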
@ -459,13 +459,13 @@ class TokenView(View):
            if self.params.grant_type == GRANT_TYPE_REFRESH_TOKEN:
                LOGGER.debug("Refreshing refresh token")
                return TokenResponse(self.create_refresh_response())
            if self.params.grant_type in [GRANT_TYPE_CLIENT_CREDENTIALS, GRANT_TYPE_PASSWORD]:
                LOGGER.debug("Client credentials/password grant")
            if self.params.grant_type == GRANT_TYPE_CLIENT_CREDENTIALS:
                LOGGER.debug("Client credentials grant")
                return TokenResponse(self.create_client_credentials_response())
            if self.params.grant_type == GRANT_TYPE_DEVICE_CODE:
                LOGGER.debug("Device code grant")
                return TokenResponse(self.create_device_code_response())
            raise TokenError("unsupported_grant_type")
            raise ValueError(f"Invalid grant_type: {self.params.grant_type}")
        except (TokenError, DeviceCodeError) as error:
            return TokenResponse(error.create_dict(), status=400)
        except UserAuthError as error:
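
An illustrative client call, not part of the diff, shaped after the test_successful_password test removed above; the URL and credentials are placeholders:

# Hypothetical request for the password grant this compare removes; field names
# mirror the deleted test, everything else is a placeholder.
import requests

response = requests.post(
    "https://authentik.example.com/application/o/token/",
    data={
        "grant_type": "password",
        "scope": "openid email profile",
        "client_id": "<client-id>",
        "username": "sa",
        "password": "<service-account-token-key>",
    },
)
print(response.json().get("token_type"))  # "Bearer" on success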
@ -31,10 +31,6 @@ class IngressReconciler(KubernetesObjectReconciler[V1Ingress]):
        super().__init__(controller)
        self.api = NetworkingV1Api(controller.client)

    @staticmethod
    def reconciler_name() -> str:
        return "ingress"

    def _check_annotations(self, reference: V1Ingress):
        """Check that all annotations *we* set are correct"""
        for key, value in self.get_ingress_annotations().items():
@ -17,28 +17,24 @@ class TraefikMiddlewareReconciler(KubernetesObjectReconciler):
        if not self.reconciler.crd_exists():
            self.reconciler = Traefik2MiddlewareReconciler(controller)

    @staticmethod
    def reconciler_name() -> str:
        return "traefik middleware"

    @property
    def noop(self) -> bool:
        return self.reconciler.noop

    def reconcile(self, current: TraefikMiddleware, reference: TraefikMiddleware):
        return self.reconciler.reconcile(current, reference)
        return self.reconcile(current, reference)

    def get_reference_object(self) -> TraefikMiddleware:
        return self.reconciler.get_reference_object()
        return self.get_reference_object()

    def create(self, reference: TraefikMiddleware):
        return self.reconciler.create(reference)
        return self.create(reference)

    def delete(self, reference: TraefikMiddleware):
        return self.reconciler.delete(reference)
        return self.delete(reference)

    def retrieve(self) -> TraefikMiddleware:
        return self.reconciler.retrieve()
        return self.retrieve()

    def update(self, current: TraefikMiddleware, reference: TraefikMiddleware):
        return self.reconciler.update(current, reference)
        return self.update(current, reference)
@ -67,10 +67,6 @@ class Traefik3MiddlewareReconciler(KubernetesObjectReconciler[TraefikMiddleware]
        self.crd_version = "v1alpha1"
        self.crd_plural = "middlewares"

    @staticmethod
    def reconciler_name() -> str:
        return "traefik middleware"

    @property
    def noop(self) -> bool:
        if not ProxyProvider.objects.filter(
@ -16,9 +16,7 @@ class ProxyKubernetesController(KubernetesController):
            DeploymentPort(9300, "http-metrics", "tcp"),
            DeploymentPort(9443, "https", "tcp"),
        ]
        self.reconcilers[IngressReconciler.reconciler_name()] = IngressReconciler
        self.reconcilers[
            TraefikMiddlewareReconciler.reconciler_name()
        ] = TraefikMiddlewareReconciler
        self.reconcile_order.append(IngressReconciler.reconciler_name())
        self.reconcile_order.append(TraefikMiddlewareReconciler.reconciler_name())
        self.reconcilers["ingress"] = IngressReconciler
        self.reconcilers["traefik middleware"] = TraefikMiddlewareReconciler
        self.reconcile_order.append("ingress")
        self.reconcile_order.append("traefik middleware")
@ -1,11 +1,17 @@
"""SCIM Provider models"""
from django.db import models
from django.db.models import QuerySet
from django.db.models import Q, QuerySet
from django.utils.translation import gettext_lazy as _
from guardian.shortcuts import get_anonymous_user
from rest_framework.serializers import Serializer

from authentik.core.models import BackchannelProvider, Group, PropertyMapping, User, UserTypes
from authentik.core.models import (
    USER_ATTRIBUTE_SA,
    BackchannelProvider,
    Group,
    PropertyMapping,
    User,
)


class SCIMProvider(BackchannelProvider):
@ -32,8 +38,17 @@ class SCIMProvider(BackchannelProvider):
        according to the provider's settings"""
        base = User.objects.all().exclude(pk=get_anonymous_user().pk)
        if self.exclude_users_service_account:
            base = base.exclude(type=UserTypes.SERVICE_ACCOUNT).exclude(
                type=UserTypes.INTERNAL_SERVICE_ACCOUNT
            base = base.filter(
                Q(
                    **{
                        f"attributes__{USER_ATTRIBUTE_SA}__isnull": True,
                    }
                )
                | Q(
                    **{
                        f"attributes__{USER_ATTRIBUTE_SA}": False,
                    }
                )
            )
        if self.filter_group:
            base = base.filter(ak_groups__in=[self.filter_group])
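
Reading aid, not part of the diff: the attribute-based filter keeps users whose service-account attribute is absent or explicitly false, which is the older spelling of the UserTypes exclusion it replaces. Written out on its own:

# Equivalent standalone queryset, using only names that appear in the hunk above.
from django.db.models import Q

from authentik.core.models import USER_ATTRIBUTE_SA, User

humans = User.objects.filter(
    Q(**{f"attributes__{USER_ATTRIBUTE_SA}__isnull": True})
    | Q(**{f"attributes__{USER_ATTRIBUTE_SA}": False})
)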
@ -1,15 +0,0 @@
"""authentik database backend"""
from django_prometheus.db.backends.postgresql.base import DatabaseWrapper as BaseDatabaseWrapper

from authentik.lib.config import CONFIG


class DatabaseWrapper(BaseDatabaseWrapper):
    """database backend which supports rotating credentials"""

    def get_connection_params(self):
        CONFIG.refresh("postgresql.password")
        conn_params = super().get_connection_params()
        conn_params["user"] = CONFIG.get("postgresql.user")
        conn_params["password"] = CONFIG.get("postgresql.password")
        return conn_params
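
Context, not part of the diff: this deleted backend re-reads the database password on every new connection so that rotated credentials are picked up without a restart. It leans on the file:// handling in lib/config.py; a sketch of the interplay, with an assumed secret path:

# Assumes the password is supplied as a file:// URI (the path is an example):
#   AUTHENTIK_POSTGRESQL__PASSWORD=file:///run/secrets/postgres_password
# refresh() re-runs parse_uri() on the stored source, so a secret rotated on
# disk is seen the next time Django opens a connection.
from authentik.lib.config import CONFIG

CONFIG.refresh("postgresql.password")
password = CONFIG.get("postgresql.password")  # freshly read file contents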
@ -26,15 +26,15 @@ def get_install_id_raw():
    """Get install_id without django loaded, this is required for the startup when we get
    the install_id but django isn't loaded yet and we can't use the function above."""
    conn = connect(
        dbname=CONFIG.get("postgresql.name"),
        user=CONFIG.get("postgresql.user"),
        password=CONFIG.get("postgresql.password"),
        host=CONFIG.get("postgresql.host"),
        port=int(CONFIG.get("postgresql.port")),
        sslmode=CONFIG.get("postgresql.sslmode"),
        sslrootcert=CONFIG.get("postgresql.sslrootcert"),
        sslcert=CONFIG.get("postgresql.sslcert"),
        sslkey=CONFIG.get("postgresql.sslkey"),
        dbname=CONFIG.y("postgresql.name"),
        user=CONFIG.y("postgresql.user"),
        password=CONFIG.y("postgresql.password"),
        host=CONFIG.y("postgresql.host"),
        port=int(CONFIG.y("postgresql.port")),
        sslmode=CONFIG.y("postgresql.sslmode"),
        sslrootcert=CONFIG.y("postgresql.sslrootcert"),
        sslcert=CONFIG.y("postgresql.sslcert"),
        sslkey=CONFIG.y("postgresql.sslkey"),
    )
    cursor = conn.cursor()
    cursor.execute("SELECT id FROM authentik_install_id LIMIT 1;")
@ -10,8 +10,6 @@ from django.contrib.sessions.exceptions import SessionInterrupted
from django.contrib.sessions.middleware import SessionMiddleware as UpstreamSessionMiddleware
from django.http.request import HttpRequest
from django.http.response import HttpResponse
from django.middleware.csrf import CSRF_SESSION_KEY
from django.middleware.csrf import CsrfViewMiddleware as UpstreamCsrfViewMiddleware
from django.utils.cache import patch_vary_headers
from django.utils.http import http_date
from jwt import PyJWTError, decode, encode
@ -133,29 +131,6 @@ class SessionMiddleware(UpstreamSessionMiddleware):
        return response


class CsrfViewMiddleware(UpstreamCsrfViewMiddleware):
    """Dynamically set secure depending if the upstream connection is TLS or not"""

    def _set_csrf_cookie(self, request: HttpRequest, response: HttpResponse):
        if settings.CSRF_USE_SESSIONS:
            if request.session.get(CSRF_SESSION_KEY) != request.META["CSRF_COOKIE"]:
                request.session[CSRF_SESSION_KEY] = request.META["CSRF_COOKIE"]
        else:
            secure = SessionMiddleware.is_secure(request)
            response.set_cookie(
                settings.CSRF_COOKIE_NAME,
                request.META["CSRF_COOKIE"],
                max_age=settings.CSRF_COOKIE_AGE,
                domain=settings.CSRF_COOKIE_DOMAIN,
                path=settings.CSRF_COOKIE_PATH,
                secure=secure,
                httponly=settings.CSRF_COOKIE_HTTPONLY,
                samesite=settings.CSRF_COOKIE_SAMESITE,
            )
            # Set the Vary header since content varies with the CSRF cookie.
            patch_vary_headers(response, ("Cookie",))


class ChannelsLoggingMiddleware:
    """Logging middleware for channels"""
@ -24,8 +24,8 @@ BASE_DIR = Path(__file__).absolute().parent.parent.parent
STATICFILES_DIRS = [BASE_DIR / Path("web")]
MEDIA_ROOT = BASE_DIR / Path("media")

DEBUG = CONFIG.get_bool("debug")
SECRET_KEY = CONFIG.get("secret_key")
DEBUG = CONFIG.y_bool("debug")
SECRET_KEY = CONFIG.y("secret_key")

INTERNAL_IPS = ["127.0.0.1"]
ALLOWED_HOSTS = ["*"]
@ -40,7 +40,7 @@ CSRF_COOKIE_NAME = "authentik_csrf"
CSRF_HEADER_NAME = "HTTP_X_AUTHENTIK_CSRF"
LANGUAGE_COOKIE_NAME = "authentik_language"
SESSION_COOKIE_NAME = "authentik_session"
SESSION_COOKIE_DOMAIN = CONFIG.get("cookie_domain", None)
SESSION_COOKIE_DOMAIN = CONFIG.y("cookie_domain", None)

AUTHENTICATION_BACKENDS = [
    "django.contrib.auth.backends.ModelBackend",
@ -66,6 +66,7 @@ INSTALLED_APPS = [
    "authentik.crypto",
    "authentik.events",
    "authentik.flows",
    "authentik.lib",
    "authentik.outposts",
    "authentik.policies.dummy",
    "authentik.policies.event_matcher",
@ -145,7 +146,6 @@ SPECTACULAR_SETTINGS = {
        "PromptTypeEnum": "authentik.stages.prompt.models.FieldTypes",
        "LDAPAPIAccessMode": "authentik.providers.ldap.models.APIAccessMode",
        "UserVerificationEnum": "authentik.stages.authenticator_webauthn.models.UserVerification",
        "UserTypeEnum": "authentik.core.models.UserTypes",
    },
    "ENUM_ADD_EXPLICIT_BLANK_NULL_CHOICE": False,
    "POSTPROCESSING_HOOKS": [
@ -178,26 +178,26 @@ REST_FRAMEWORK = {
    "TEST_REQUEST_DEFAULT_FORMAT": "json",
    "DEFAULT_THROTTLE_CLASSES": ["rest_framework.throttling.AnonRateThrottle"],
    "DEFAULT_THROTTLE_RATES": {
        "anon": CONFIG.get("throttle.default"),
        "anon": CONFIG.y("throttle.default"),
    },
}

_redis_protocol_prefix = "redis://"
_redis_celery_tls_requirements = ""
if CONFIG.get_bool("redis.tls", False):
if CONFIG.y_bool("redis.tls", False):
    _redis_protocol_prefix = "rediss://"
    _redis_celery_tls_requirements = f"?ssl_cert_reqs={CONFIG.get('redis.tls_reqs')}"
    _redis_celery_tls_requirements = f"?ssl_cert_reqs={CONFIG.y('redis.tls_reqs')}"
_redis_url = (
    f"{_redis_protocol_prefix}:"
    f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:"
    f"{int(CONFIG.get('redis.port'))}"
    f"{quote_plus(CONFIG.y('redis.password'))}@{quote_plus(CONFIG.y('redis.host'))}:"
    f"{int(CONFIG.y('redis.port'))}"
)

CACHES = {
    "default": {
        "BACKEND": "django_redis.cache.RedisCache",
        "LOCATION": f"{_redis_url}/{CONFIG.get('redis.db')}",
        "TIMEOUT": int(CONFIG.get("redis.cache_timeout", 300)),
        "LOCATION": f"{_redis_url}/{CONFIG.y('redis.db')}",
        "TIMEOUT": int(CONFIG.y("redis.cache_timeout", 300)),
        "OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"},
        "KEY_PREFIX": "authentik_cache",
    }
@ -225,7 +225,7 @@ MIDDLEWARE = [
    "authentik.events.middleware.AuditMiddleware",
    "django.middleware.security.SecurityMiddleware",
    "django.middleware.common.CommonMiddleware",
    "authentik.root.middleware.CsrfViewMiddleware",
    "django.middleware.csrf.CsrfViewMiddleware",
    "django.contrib.messages.middleware.MessageMiddleware",
    "django.middleware.clickjacking.XFrameOptionsMiddleware",
    "authentik.core.middleware.ImpersonateMiddleware",
@ -237,7 +237,7 @@ ROOT_URLCONF = "authentik.root.urls"
TEMPLATES = [
    {
        "BACKEND": "django.template.backends.django.DjangoTemplates",
        "DIRS": [CONFIG.get("email.template_dir")],
        "DIRS": [CONFIG.y("email.template_dir")],
        "APP_DIRS": True,
        "OPTIONS": {
            "context_processors": [
@ -257,7 +257,7 @@ CHANNEL_LAYERS = {
    "default": {
        "BACKEND": "channels_redis.core.RedisChannelLayer",
        "CONFIG": {
            "hosts": [f"{_redis_url}/{CONFIG.get('redis.db')}"],
            "hosts": [f"{_redis_url}/{CONFIG.y('redis.db')}"],
            "prefix": "authentik_channels",
        },
    },
@ -269,37 +269,34 @@ CHANNEL_LAYERS = {

DATABASES = {
    "default": {
        "ENGINE": "authentik.root.db",
        "HOST": CONFIG.get("postgresql.host"),
        "NAME": CONFIG.get("postgresql.name"),
        "USER": CONFIG.get("postgresql.user"),
        "PASSWORD": CONFIG.get("postgresql.password"),
        "PORT": int(CONFIG.get("postgresql.port")),
        "SSLMODE": CONFIG.get("postgresql.sslmode"),
        "SSLROOTCERT": CONFIG.get("postgresql.sslrootcert"),
        "SSLCERT": CONFIG.get("postgresql.sslcert"),
        "SSLKEY": CONFIG.get("postgresql.sslkey"),
        "ENGINE": "django_prometheus.db.backends.postgresql",
        "HOST": CONFIG.y("postgresql.host"),
        "NAME": CONFIG.y("postgresql.name"),
        "USER": CONFIG.y("postgresql.user"),
        "PASSWORD": CONFIG.y("postgresql.password"),
        "PORT": int(CONFIG.y("postgresql.port")),
        "SSLMODE": CONFIG.y("postgresql.sslmode"),
        "SSLROOTCERT": CONFIG.y("postgresql.sslrootcert"),
        "SSLCERT": CONFIG.y("postgresql.sslcert"),
        "SSLKEY": CONFIG.y("postgresql.sslkey"),
    }
}

if CONFIG.get_bool("postgresql.use_pgbouncer", False):
if CONFIG.y_bool("postgresql.use_pgbouncer", False):
    # https://docs.djangoproject.com/en/4.0/ref/databases/#transaction-pooling-server-side-cursors
    DATABASES["default"]["DISABLE_SERVER_SIDE_CURSORS"] = True
    # https://docs.djangoproject.com/en/4.0/ref/databases/#persistent-connections
    DATABASES["default"]["CONN_MAX_AGE"] = None  # persistent

# Email
# These values should never actually be used, emails are only sent from email stages, which
# loads the config directly from CONFIG
# See authentik/stages/email/models.py, line 105
EMAIL_HOST = CONFIG.get("email.host")
EMAIL_PORT = int(CONFIG.get("email.port"))
EMAIL_HOST_USER = CONFIG.get("email.username")
EMAIL_HOST_PASSWORD = CONFIG.get("email.password")
EMAIL_USE_TLS = CONFIG.get_bool("email.use_tls", False)
EMAIL_USE_SSL = CONFIG.get_bool("email.use_ssl", False)
EMAIL_TIMEOUT = int(CONFIG.get("email.timeout"))
DEFAULT_FROM_EMAIL = CONFIG.get("email.from")
EMAIL_HOST = CONFIG.y("email.host")
EMAIL_PORT = int(CONFIG.y("email.port"))
EMAIL_HOST_USER = CONFIG.y("email.username")
EMAIL_HOST_PASSWORD = CONFIG.y("email.password")
EMAIL_USE_TLS = CONFIG.y_bool("email.use_tls", False)
EMAIL_USE_SSL = CONFIG.y_bool("email.use_ssl", False)
EMAIL_TIMEOUT = int(CONFIG.y("email.timeout"))
DEFAULT_FROM_EMAIL = CONFIG.y("email.from")
SERVER_EMAIL = DEFAULT_FROM_EMAIL
EMAIL_SUBJECT_PREFIX = "[authentik] "

@ -347,15 +344,15 @@ CELERY = {
    },
    "task_create_missing_queues": True,
    "task_default_queue": "authentik",
    "broker_url": f"{_redis_url}/{CONFIG.get('redis.db')}{_redis_celery_tls_requirements}",
    "result_backend": f"{_redis_url}/{CONFIG.get('redis.db')}{_redis_celery_tls_requirements}",
    "broker_url": f"{_redis_url}/{CONFIG.y('redis.db')}{_redis_celery_tls_requirements}",
    "result_backend": f"{_redis_url}/{CONFIG.y('redis.db')}{_redis_celery_tls_requirements}",
}

# Sentry integration
env = get_env()
_ERROR_REPORTING = CONFIG.get_bool("error_reporting.enabled", False)
_ERROR_REPORTING = CONFIG.y_bool("error_reporting.enabled", False)
if _ERROR_REPORTING:
    sentry_env = CONFIG.get("error_reporting.environment", "customer")
    sentry_env = CONFIG.y("error_reporting.environment", "customer")
    sentry_init()
    set_tag("authentik.uuid", sha512(str(SECRET_KEY).encode("ascii")).hexdigest()[:16])

@ -369,7 +366,7 @@ MEDIA_URL = "/media/"
TEST = False
TEST_RUNNER = "authentik.root.test_runner.PytestTestRunner"
# We can't check TEST here as its set later by the test runner
LOG_LEVEL = CONFIG.get("log_level").upper() if "TF_BUILD" not in os.environ else "DEBUG"
LOG_LEVEL = CONFIG.y("log_level").upper() if "TF_BUILD" not in os.environ else "DEBUG"
# We could add a custom level to stdlib logging and structlog, but it's not easy or clean
# https://stackoverflow.com/questions/54505487/custom-log-level-not-working-with-structlog
# Additionally, the entire code uses debug as highest level so that would have to be re-written too
@ -31,14 +31,14 @@ class PytestTestRunner: # pragma: no cover

        settings.TEST = True
        settings.CELERY["task_always_eager"] = True
        CONFIG.set("avatars", "none")
        CONFIG.set("geoip", "tests/GeoLite2-City-Test.mmdb")
        CONFIG.set("blueprints_dir", "./blueprints")
        CONFIG.set(
        CONFIG.y_set("avatars", "none")
        CONFIG.y_set("geoip", "tests/GeoLite2-City-Test.mmdb")
        CONFIG.y_set("blueprints_dir", "./blueprints")
        CONFIG.y_set(
            "outposts.container_image_base",
            f"ghcr.io/goauthentik/dev-%(type)s:{get_docker_tag()}",
        )
        CONFIG.set("error_reporting.sample_rate", 0)
        CONFIG.y_set("error_reporting.sample_rate", 0)
        sentry_init(
            environment="testing",
            send_default_pii=True,
@ -3,10 +3,7 @@ from django.core.management.base import BaseCommand
from structlog.stdlib import get_logger

from authentik.sources.ldap.models import LDAPSource
from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer
from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer
from authentik.sources.ldap.sync.users import UserLDAPSynchronizer
from authentik.sources.ldap.tasks import ldap_sync_paginator
from authentik.sources.ldap.tasks import ldap_sync_single

LOGGER = get_logger()

@ -23,10 +20,4 @@ class Command(BaseCommand):
            if not source:
                LOGGER.warning("Source does not exist", slug=source_slug)
                continue
            tasks = (
                ldap_sync_paginator(source, UserLDAPSynchronizer)
                + ldap_sync_paginator(source, GroupLDAPSynchronizer)
                + ldap_sync_paginator(source, MembershipLDAPSynchronizer)
            )
            for task in tasks:
                task()
            ldap_sync_single(source)
@ -136,7 +136,7 @@ class LDAPSource(Source):
            chmod(private_key_file, 0o600)
            tls_kwargs["local_private_key_file"] = private_key_file
            tls_kwargs["local_certificate_file"] = certificate_file
        if ciphers := CONFIG.get("ldap.tls.ciphers", None):
        if ciphers := CONFIG.y("ldap.tls.ciphers", None):
            tls_kwargs["ciphers"] = ciphers.strip()
        if self.sni:
            tls_kwargs["sni"] = self.server_uri.split(",", maxsplit=1)[0].strip()
@ -93,7 +93,7 @@ class BaseLDAPSynchronizer:
        types_only=False,
        get_operational_attributes=False,
        controls=None,
        paged_size=int(CONFIG.get("ldap.page_size", 50)),
        paged_size=int(CONFIG.y("ldap.page_size", 50)),
        paged_criticality=False,
    ):
        """Search in pages, returns each page"""
@ -49,7 +49,7 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer):
            uniq = self._flatten(attributes[self._source.object_uniqueness_field])
            try:
                defaults = self.build_user_properties(user_dn, **attributes)
                self._logger.debug("Writing user with attributes", **defaults)
                self._logger.debug("Creating user with attributes", **defaults)
                if "username" not in defaults:
                    raise IntegrityError("Username was not set by propertymappings")
                ak_user, created = self.update_or_create_attributes(
12
authentik/sources/ldap/sync/vendor/freeipa.py
vendored
@ -20,7 +20,6 @@ class FreeIPA(BaseLDAPSynchronizer):

    def sync(self, attributes: dict[str, Any], user: User, created: bool):
        self.check_pwd_last_set(attributes, user, created)
        self.check_nsaccountlock(attributes, user)

    def check_pwd_last_set(self, attributes: dict[str, Any], user: User, created: bool):
        """Check krbLastPwdChange"""
@ -38,14 +37,3 @@ class FreeIPA(BaseLDAPSynchronizer):
            )
            user.set_unusable_password()
            user.save()

    def check_nsaccountlock(self, attributes: dict[str, Any], user: User):
        """https://www.port389.org/docs/389ds/howto/howto-account-inactivation.html"""
        # This is more of a 389-ds quirk rather than FreeIPA, but FreeIPA uses
        # 389-ds and this will trigger regardless
        if "nsaccountlock" not in attributes:
            return
        is_active = attributes.get("nsaccountlock", False)
        if is_active != user.is_active:
            user.is_active = is_active
            user.save()
6
authentik/sources/ldap/sync/vendor/ms_ad.py
vendored
@ -78,7 +78,5 @@ class MicrosoftActiveDirectory(BaseLDAPSynchronizer):
        # /useraccountcontrol-manipulate-account-properties
        uac_bit = attributes.get("userAccountControl", 512)
        uac = UserAccountControl(uac_bit)
        is_active = UserAccountControl.ACCOUNTDISABLE not in uac
        if is_active != user.is_active:
            user.is_active = is_active
            user.save()
        user.is_active = UserAccountControl.ACCOUNTDISABLE not in uac
        user.save()
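
Background, not part of the diff: userAccountControl is a bit field; 512 is a normal enabled account, and bit 0x2 (ACCOUNTDISABLE) marks it disabled, which is what both versions of this hunk test. A minimal standalone sketch:

# Standalone model of the check; authentik's UserAccountControl is used the
# same way in the hunk above.
from enum import IntFlag

class UserAccountControl(IntFlag):
    ACCOUNTDISABLE = 0x0002
    NORMAL_ACCOUNT = 0x0200  # 512, the default assumed above

uac = UserAccountControl(514)  # 512 | 2: a normal account that is disabled
print(UserAccountControl.ACCOUNTDISABLE in uac)  # True -> user.is_active = False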
@ -59,7 +59,7 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) ->
    signatures = []
    for page in sync_inst.get_objects():
        page_cache_key = CACHE_KEY_PREFIX + str(uuid4())
        cache.set(page_cache_key, page, 60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")))
        cache.set(page_cache_key, page)
        page_sync = ldap_sync.si(source.pk, class_to_path(sync), page_cache_key)
        signatures.append(page_sync)
    return signatures
@ -68,12 +68,12 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) ->
@CELERY_APP.task(
    bind=True,
    base=MonitoredTask,
    soft_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")),
    task_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")),
    soft_time_limit=60 * 60 * int(CONFIG.y("ldap.task_timeout_hours")),
    task_time_limit=60 * 60 * int(CONFIG.y("ldap.task_timeout_hours")),
)
def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str):
    """Synchronization of an LDAP Source"""
    self.result_timeout_hours = int(CONFIG.get("ldap.task_timeout_hours"))
    self.result_timeout_hours = int(CONFIG.y("ldap.task_timeout_hours"))
    source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first()
    if not source:
        # Because the source couldn't be found, we don't have a UID
@ -86,12 +86,6 @@ def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_k
    sync_inst: BaseLDAPSynchronizer = sync(source)
    page = cache.get(page_cache_key)
    if not page:
        error_message = (
            f"Could not find page in cache: {page_cache_key}. "
            + "Try increasing ldap.task_timeout_hours"
        )
        LOGGER.warning(error_message)
        self.set_status(TaskResult(TaskResultStatus.ERROR, [error_message]))
        return
    cache.touch(page_cache_key)
    count = sync_inst.sync(page)
@ -8,14 +8,12 @@ from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import Group, User
from authentik.core.tests.utils import create_test_admin_user
from authentik.events.models import Event, EventAction
from authentik.events.monitored_tasks import TaskInfo, TaskResultStatus
from authentik.lib.generators import generate_key
from authentik.lib.utils.reflection import class_to_path
from authentik.sources.ldap.models import LDAPPropertyMapping, LDAPSource
from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer
from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer
from authentik.sources.ldap.sync.users import UserLDAPSynchronizer
from authentik.sources.ldap.tasks import ldap_sync, ldap_sync_all
from authentik.sources.ldap.tasks import ldap_sync_all
from authentik.sources.ldap.tests.mock_ad import mock_ad_connection
from authentik.sources.ldap.tests.mock_slapd import mock_slapd_connection

@ -35,14 +33,6 @@ class LDAPSyncTests(TestCase):
            additional_group_dn="ou=groups",
        )

    def test_sync_missing_page(self):
        """Test sync with missing page"""
        connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
        with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
            ldap_sync.delay(self.source.pk, class_to_path(UserLDAPSynchronizer), "foo").get()
        status = TaskInfo.by_name("ldap_sync:ldap:users:foo")
        self.assertEqual(status.result.status, TaskResultStatus.ERROR)

    def test_sync_error(self):
        """Test user sync"""
        self.source.property_mappings.set(
@ -13,7 +13,6 @@ from rest_framework.serializers import BaseSerializer
from structlog.stdlib import get_logger

from authentik.flows.models import Stage
from authentik.lib.config import CONFIG

LOGGER = get_logger()

@ -105,16 +104,7 @@ class EmailStage(Stage):
    def backend(self) -> BaseEmailBackend:
        """Get fully configured Email Backend instance"""
        if self.use_global_settings:
            CONFIG.refresh("email.password")
            return self.backend_class(
                host=CONFIG.get("email.host"),
                port=int(CONFIG.get("email.port")),
                username=CONFIG.get("email.username"),
                password=CONFIG.get("email.password"),
                use_tls=CONFIG.get_bool("email.use_tls", False),
                use_ssl=CONFIG.get_bool("email.use_ssl", False),
                timeout=int(CONFIG.get("email.timeout")),
            )
            return self.backend_class()
        return self.backend_class(
            host=self.host,
            port=self.port,
@@ -12,7 +12,7 @@ from rest_framework.fields import CharField
from rest_framework.serializers import ValidationError

from authentik.flows.challenge import Challenge, ChallengeResponse, ChallengeTypes
from authentik.flows.models import FlowToken
from authentik.flows.models import FlowDesignation, FlowToken
from authentik.flows.planner import PLAN_CONTEXT_IS_RESTORED, PLAN_CONTEXT_PENDING_USER
from authentik.flows.stage import ChallengeStageView
from authentik.flows.views.executor import QS_KEY_TOKEN
@@ -82,6 +82,11 @@ class EmailStageView(ChallengeStageView):
        """Helper function that sends the actual email. Implies that you've
        already checked that there is a pending user."""
        pending_user = self.get_pending_user()
        if not pending_user.pk and self.executor.flow.designation == FlowDesignation.RECOVERY:
            # Pending user does not have a primary key, and we're in a recovery flow,
            # which means the user entered an invalid identifier, so we pretend to send the
            # email, to not disclose if the user exists
            return
        email = self.executor.plan.context.get(PLAN_CONTEXT_EMAIL_OVERRIDE, None)
        if not email:
            email = pending_user.email
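
The guard above is a standard anti-enumeration pattern: respond identically whether or not the account exists. A minimal standalone sketch of the same idea, with hypothetical find_user/send_recovery_email helpers:

from typing import Optional

USERS = {"alice@example.com"}  # hypothetical user store

def find_user(identifier: str) -> Optional[str]:
    return identifier if identifier in USERS else None

def send_recovery_email(user: str) -> None:
    print(f"sending recovery mail to {user}")  # stand-in for a real mailer

def request_recovery(identifier: str) -> str:
    user = find_user(identifier)
    if user is not None:
        send_recovery_email(user)
    # Deliberately identical response on both paths, so callers
    # cannot probe which identifiers have accounts
    return "If an account exists for this identifier, an email was sent."
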
@@ -5,18 +5,20 @@ from unittest.mock import MagicMock, PropertyMock, patch
from django.core import mail
from django.core.mail.backends.locmem import EmailBackend
from django.urls import reverse
from rest_framework.test import APITestCase

from authentik.core.models import User
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.events.models import Event, EventAction
from authentik.flows.markers import StageMarker
from authentik.flows.models import FlowDesignation, FlowStageBinding
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
from authentik.flows.tests import FlowTestCase
from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.lib.generators import generate_id
from authentik.stages.email.models import EmailStage


class TestEmailStageSending(APITestCase):
class TestEmailStageSending(FlowTestCase):
    """Email tests"""

    def setUp(self):
@@ -44,6 +46,13 @@ class TestEmailStageSending(APITestCase):
        ):
            response = self.client.post(url)
            self.assertEqual(response.status_code, 200)
            self.assertStageResponse(
                response,
                self.flow,
                response_errors={
                    "non_field_errors": [{"string": "email-sent", "code": "email-sent"}]
                },
            )
        self.assertEqual(len(mail.outbox), 1)
        self.assertEqual(mail.outbox[0].subject, "authentik")
        events = Event.objects.filter(action=EventAction.EMAIL_SENT)
@@ -54,6 +63,32 @@
        self.assertEqual(event.context["to_email"], [self.user.email])
        self.assertEqual(event.context["from_email"], "system@authentik.local")

    def test_pending_fake_user(self):
        """Test with pending (fake) user"""
        self.flow.designation = FlowDesignation.RECOVERY
        self.flow.save()
        plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
        plan.context[PLAN_CONTEXT_PENDING_USER] = User(username=generate_id())
        session = self.client.session
        session[SESSION_KEY_PLAN] = plan
        session.save()

        url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
        with patch(
            "authentik.stages.email.models.EmailStage.backend_class",
            PropertyMock(return_value=EmailBackend),
        ):
            response = self.client.post(url)
            self.assertEqual(response.status_code, 200)
            self.assertStageResponse(
                response,
                self.flow,
                response_errors={
                    "non_field_errors": [{"string": "email-sent", "code": "email-sent"}]
                },
            )
        self.assertEqual(len(mail.outbox), 0)

    def test_send_error(self):
        """Test error during sending (sending will be retried)"""
        plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
@@ -13,7 +13,6 @@ from authentik.flows.models import FlowDesignation, FlowStageBinding, FlowToken
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
from authentik.flows.tests import FlowTestCase
from authentik.flows.views.executor import QS_KEY_TOKEN, SESSION_KEY_PLAN
from authentik.lib.config import CONFIG
from authentik.stages.email.models import EmailStage
from authentik.stages.email.stage import PLAN_CONTEXT_EMAIL_OVERRIDE
@@ -121,7 +120,7 @@ class TestEmailStage(FlowTestCase):
    def test_use_global_settings(self):
        """Test use_global_settings"""
        host = "some-unique-string"
        with CONFIG.patch("email.host", host):
        with self.settings(EMAIL_HOST=host):
            self.assertEqual(EmailStage(use_global_settings=True).backend.host, host)

    def test_token(self):
@@ -118,8 +118,12 @@ class IdentificationChallengeResponse(ChallengeResponse):
                username=uid_field,
                email=uid_field,
            )
            self.pre_user = self.stage.executor.plan.context[PLAN_CONTEXT_PENDING_USER]
            if not current_stage.show_matched_user:
                self.stage.executor.plan.context[PLAN_CONTEXT_PENDING_USER_IDENTIFIER] = uid_field
            if self.stage.executor.flow.designation == FlowDesignation.RECOVERY:
                # When used in a recovery flow, always continue to not disclose if a user exists
                return attrs
            raise ValidationError("Failed to authenticate.")
        self.pre_user = pre_user
        if not current_stage.password_stage:
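
The same no-disclosure rule applied at validation time: instead of raising on an unknown identifier, a recovery flow substitutes an unsaved placeholder user and continues. A rough standalone sketch of that control flow, with all names hypothetical:

from dataclasses import dataclass
from typing import Optional

@dataclass
class PlaceholderUser:
    """Stand-in for an unsaved User(username=..., email=...)."""
    username: str
    pk: Optional[int] = None  # a missing primary key marks the user as fake

def resolve_user(identifier: str, is_recovery: bool, users: dict) -> PlaceholderUser:
    user = users.get(identifier)
    if user is not None:
        return user
    if is_recovery:
        # Recovery flows continue with a placeholder, so the response is
        # indistinguishable from the real-user case
        return PlaceholderUser(username=identifier)
    raise ValueError("Failed to authenticate.")
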
@@ -188,7 +188,7 @@ class TestIdentificationStage(FlowTestCase):
            ],
        )

    def test_recovery_flow(self):
    def test_link_recovery_flow(self):
        """Test that recovery flow is linked correctly"""
        flow = create_test_flow()
        self.stage.recovery_flow = flow
@@ -226,6 +226,38 @@
            ],
        )

    def test_recovery_flow_invalid_user(self):
        """Test that an invalid user can proceed in a recovery flow"""
        self.flow.designation = FlowDesignation.RECOVERY
        self.flow.save()
        response = self.client.get(
            reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
        )
        self.assertStageResponse(
            response,
            self.flow,
            component="ak-stage-identification",
            user_fields=["email"],
            password_fields=False,
            show_source_labels=False,
            primary_action="Continue",
            sources=[
                {
                    "challenge": {
                        "component": "xak-flow-redirect",
                        "to": "/source/oauth/login/test/",
                        "type": ChallengeTypes.REDIRECT.value,
                    },
                    "icon_url": "/static/authentik/sources/default.svg",
                    "name": "test",
                }
            ],
        )
        form_data = {"uid_field": generate_id()}
        url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
        response = self.client.post(url, form_data)
        self.assertEqual(response.status_code, 200)

    def test_api_validate(self):
        """Test API validation"""
        self.assertTrue(
@@ -179,7 +179,7 @@ class ListPolicyEngine(PolicyEngine):
        self.__list = policies
        self.use_cache = False

    def iterate_bindings(self) -> Iterator[PolicyBinding]:
    def _iter_bindings(self) -> Iterator[PolicyBinding]:
        for policy in self.__list:
            yield PolicyBinding(
                policy=policy,
@@ -78,7 +78,7 @@ class CurrentTenantSerializer(PassiveSerializer):
    ui_footer_links = ListField(
        child=FooterLinkSerializer(),
        read_only=True,
        default=CONFIG.get("footer_links", []),
        default=CONFIG.y("footer_links", []),
    )
    ui_theme = ChoiceField(
        choices=Themes.choices,
@@ -24,7 +24,7 @@ class TestTenants(APITestCase):
                "branding_favicon": "/static/dist/assets/icons/icon.png",
                "branding_title": "authentik",
                "matched_domain": tenant.domain,
                "ui_footer_links": CONFIG.get("footer_links"),
                "ui_footer_links": CONFIG.y("footer_links"),
                "ui_theme": Themes.AUTOMATIC,
                "default_locale": "",
            },
@@ -43,7 +43,7 @@ class TestTenants(APITestCase):
                "branding_favicon": "/static/dist/assets/icons/icon.png",
                "branding_title": "custom",
                "matched_domain": "bar.baz",
                "ui_footer_links": CONFIG.get("footer_links"),
                "ui_footer_links": CONFIG.y("footer_links"),
                "ui_theme": Themes.AUTOMATIC,
                "default_locale": "",
            },
@@ -59,7 +59,7 @@ class TestTenants(APITestCase):
                "branding_favicon": "/static/dist/assets/icons/icon.png",
                "branding_title": "authentik",
                "matched_domain": "fallback",
                "ui_footer_links": CONFIG.get("footer_links"),
                "ui_footer_links": CONFIG.y("footer_links"),
                "ui_theme": Themes.AUTOMATIC,
                "default_locale": "",
            },
@@ -36,7 +36,7 @@ def context_processor(request: HttpRequest) -> dict[str, Any]:
    trace = span.to_traceparent()
    return {
        "tenant": tenant,
        "footer_links": CONFIG.get("footer_links"),
        "footer_links": CONFIG.y("footer_links"),
        "sentry_trace": trace,
        "version": get_full_version(),
    }
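
Several hunks here swap CONFIG.get(...)/CONFIG.get_bool(...) back to the older CONFIG.y(...)/CONFIG.y_bool(...) accessors. Both read dot-separated keys from authentik's layered configuration; a small sketch of the general pattern, using an illustrative standalone class rather than authentik's actual ConfigLoader:

from typing import Any

class DottedConfig:
    """Illustrative config store resolving dot-separated keys."""

    def __init__(self, data: dict[str, Any]):
        self._data = data

    def get(self, path: str, default: Any = None) -> Any:
        node: Any = self._data
        for part in path.split("."):
            if not isinstance(node, dict) or part not in node:
                return default
            node = node[part]
        return node

    def get_bool(self, path: str, default: bool = False) -> bool:
        return str(self.get(path, default)).lower() in ("true", "1", "yes")

config = DottedConfig({"email": {"use_tls": "true"}, "footer_links": []})
assert config.get("footer_links", []) == []
assert config.get_bool("email.use_tls", False) is True
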
@@ -94,21 +94,21 @@ entries:
    prompt_data = request.context.get("prompt_data")

    if not request.user.group_attributes(request.http_request).get(
        USER_ATTRIBUTE_CHANGE_EMAIL, CONFIG.get_bool("default_user_change_email", True)
        USER_ATTRIBUTE_CHANGE_EMAIL, CONFIG.y_bool("default_user_change_email", True)
    ):
        if prompt_data.get("email") != request.user.email:
            ak_message("Not allowed to change email address.")
            return False

    if not request.user.group_attributes(request.http_request).get(
        USER_ATTRIBUTE_CHANGE_NAME, CONFIG.get_bool("default_user_change_name", True)
        USER_ATTRIBUTE_CHANGE_NAME, CONFIG.y_bool("default_user_change_name", True)
    ):
        if prompt_data.get("name") != request.user.name:
            ak_message("Not allowed to change name.")
            return False

    if not request.user.group_attributes(request.http_request).get(
        USER_ATTRIBUTE_CHANGE_USERNAME, CONFIG.get_bool("default_user_change_username", True)
        USER_ATTRIBUTE_CHANGE_USERNAME, CONFIG.y_bool("default_user_change_username", True)
    ):
        if prompt_data.get("username") != request.user.username:
            ak_message("Not allowed to change username.")
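
The policy above gives a per-user or per-group attribute precedence over the deployment-wide default. A compact sketch of that precedence rule; the attribute key below is illustrative, not authentik's actual constant:

def may_change(user_attrs: dict, attr_key: str, config_default: bool) -> bool:
    """The user/group attribute wins when set; otherwise the config default applies."""
    return bool(user_attrs.get(attr_key, config_default))

# Example: the deployment default allows email changes, but this user's
# group attribute forbids them
attrs = {"can-change-email": False}
assert may_change(attrs, "can-change-email", True) is False
assert may_change({}, "can-change-email", True) is True
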
@@ -3213,6 +3213,7 @@
    "authentik.crypto",
    "authentik.events",
    "authentik.flows",
    "authentik.lib",
    "authentik.outposts",
    "authentik.policies.dummy",
    "authentik.policies.event_matcher",
@@ -3979,16 +3980,6 @@
          "type": "string",
          "minLength": 1,
          "title": "Path"
        },
        "type": {
          "type": "string",
          "enum": [
            "internal",
            "external",
            "service_account",
            "internal_service_account"
          ],
          "title": "Type"
        }
      },
      "required": [
@@ -4180,16 +4171,6 @@
          "type": "string",
          "minLength": 1,
          "title": "Path"
        },
        "type": {
          "type": "string",
          "enum": [
            "internal",
            "external",
            "service_account",
            "internal_service_account"
          ],
          "title": "Type"
        }
      },
      "required": [
@@ -4385,16 +4366,6 @@
          "type": "string",
          "minLength": 1,
          "title": "Path"
        },
        "type": {
          "type": "string",
          "enum": [
            "internal",
            "external",
            "service_account",
            "internal_service_account"
          ],
          "title": "Type"
        }
      },
      "required": [
@@ -6551,16 +6522,6 @@
          "type": "string",
          "minLength": 1,
          "title": "Path"
        },
        "type": {
          "type": "string",
          "enum": [
            "internal",
            "external",
            "service_account",
            "internal_service_account"
          ],
          "title": "Type"
        }
      },
      "required": [
@@ -7296,16 +7257,6 @@
          "type": "string",
          "minLength": 1,
          "title": "Path"
        },
        "type": {
          "type": "string",
          "enum": [
            "internal",
            "external",
            "service_account",
            "internal_service_account"
          ],
          "title": "Type"
        }
      },
      "required": [
@@ -8383,16 +8334,6 @@
          "minLength": 1,
          "title": "Path"
        },
        "type": {
          "type": "string",
          "enum": [
            "internal",
            "external",
            "service_account",
            "internal_service_account"
          ],
          "title": "Type"
        },
        "password": {
          "type": "string",
          "minLength": 1,
36  docker-compose.override.yml  Normal file
@@ -0,0 +1,36 @@
# This file is used for development and debugging, and should not be used for production instances

version: '3.5'

services:
  flower:
    image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2023.5.4}
    restart: unless-stopped
    command: worker-status
    environment:
      AUTHENTIK_REDIS__HOST: redis
      AUTHENTIK_POSTGRESQL__HOST: postgresql
      AUTHENTIK_POSTGRESQL__USER: ${PG_USER:-authentik}
      AUTHENTIK_POSTGRESQL__NAME: ${PG_DB:-authentik}
      AUTHENTIK_POSTGRESQL__PASSWORD: ${PG_PASS}
    env_file:
      - .env
    ports:
      - "9001:9000"
    depends_on:
      - postgresql
      - redis
  server:
    environment:
      AUTHENTIK_REMOTE_DEBUG: "true"
      PYDEVD_THREAD_DUMP_ON_WARN_EVALUATION_TIMEOUT: "true"
    ports:
      - 6800:6800
  worker:
    environment:
      CELERY_RDB_HOST: "0.0.0.0"
      CELERY_RDBSIG: "1"
      AUTHENTIK_REMOTE_DEBUG: "true"
      PYDEVD_THREAD_DUMP_ON_WARN_EVALUATION_TIMEOUT: "true"
    ports:
      - 6900:6900
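
Compose loads docker-compose.override.yml automatically on top of docker-compose.yml, which is why this file can add a flower service and extra ports without touching the base file. A simplified (not spec-complete) sketch of the merge rule, as a Python function: mappings merge recursively, scalars from the override win, and list-valued keys such as ports are concatenated.

def merge_compose(base: dict, override: dict) -> dict:
    """Simplified illustration of compose file merging; the real rules
    have more special cases (see the Compose specification)."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_compose(merged[key], value)
        elif isinstance(value, list) and isinstance(merged.get(key), list):
            merged[key] = merged[key] + value  # e.g. extra ports are appended
        else:
            merged[key] = value  # override scalar wins; new keys are added
    return merged

base = {"services": {"server": {"command": "server", "ports": ["9000:9000"]}}}
override = {"services": {"server": {"ports": ["6800:6800"]}}}
assert merge_compose(base, override)["services"]["server"]["ports"] == [
    "9000:9000",
    "6800:6800",
]
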
@@ -32,7 +32,7 @@ services:
    volumes:
      - redis:/data
  server:
    image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2023.6.1}
    image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2023.6.2}
    restart: unless-stopped
    command: server
    environment:
@@ -53,7 +53,7 @@ services:
      - postgresql
      - redis
  worker:
    image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2023.6.1}
    image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2023.6.2}
    restart: unless-stopped
    command: worker
    environment:
2  go.mod
@@ -26,7 +26,7 @@ require (
	github.com/sirupsen/logrus v1.9.3
	github.com/spf13/cobra v1.7.0
	github.com/stretchr/testify v1.8.4
	goauthentik.io/api/v3 v3.2023061.6
	goauthentik.io/api/v3 v3.2023054.4
	golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab
	golang.org/x/oauth2 v0.10.0
	golang.org/x/sync v0.3.0
4  go.sum
@@ -1070,8 +1070,8 @@ go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqe
go.opentelemetry.io/proto/otlp v0.15.0/go.mod h1:H7XAot3MsfNsj7EXtrA2q5xSNQ10UqI405h3+duxN4U=
go.opentelemetry.io/proto/otlp v0.19.0/go.mod h1:H7XAot3MsfNsj7EXtrA2q5xSNQ10UqI405h3+duxN4U=
go.uber.org/goleak v1.1.10 h1:z+mqJhf6ss6BSfSM671tgKyZBFPTTJM+HLxnhPC3wu0=
goauthentik.io/api/v3 v3.2023061.6 h1:4zbo0Dtx42HLYObizIlTWAk7iBvCv9kmCvzBxMElkIk=
goauthentik.io/api/v3 v3.2023061.6/go.mod h1:tC7qK9VSP0zJah5p5xHFnjZt/4dAkXVwcrWyZNGYhwQ=
goauthentik.io/api/v3 v3.2023054.4 h1:wnONALlxADR42TpW5xKKsGkJ/G8oNDQsWiwdlMsG2Ig=
goauthentik.io/api/v3 v3.2023054.4/go.mod h1:tC7qK9VSP0zJah5p5xHFnjZt/4dAkXVwcrWyZNGYhwQ=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190422162423-af44ce270edf/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
Some files were not shown because too many files have changed in this diff.