root: Multi-tenancy (#7590)

* tenants -> brands, init new tenant model, migrate some config to tenants

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* setup logging for tenants

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* configure celery and cache

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* small fixes, runs

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* task fixes, creation of tenant now works by cloning a template schema, some other small stuff

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* lint

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix-tests

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* upstream fixes

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix-pylint

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* lint

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix tests

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix avatar tests

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* migrate config reputation_expiry as well

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix web rebase

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* lint

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix migrations for template schema

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix migrations for template schema

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix migrations for template schema 3

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* revert reputation expiry migration

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix type

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix some more tests

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* website: tenants -> brands

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* try fixing e2e tests

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* start frontend :help:

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* add ability to disable tenants api

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* delete embedded outpost if it is disabled

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* make sure embedded outpost is disabled when tenants are enabled

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* management commands: add --schema option where relevant

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* store files per-tenant

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix embedded outpost deletion

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* lint

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix files migration

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* add tenant api tests

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* add domain tests

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* add settings tests

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* make --schema-name default to public in mgmt commands

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* lint

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* sources/ldap: make sure lock is per-tenant

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix stuff I broke

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix remaining failing tests

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* lint

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* try fixing e2e tests

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* much better frontend, but save does not refresh form properly

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* update django-tenants with latest fixes

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* lint

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* i18n-extract

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* review comments

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* move event_retention from brands to tenants

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* wip

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* root: add support for storing media files in S3

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* use permissions for settings api

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* lint

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* blueprints: disable tenants management

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix tests

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix embedded outpost create/delete logic

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* make gen

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* make sure prometheus metrics are correctly served

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* makefile: don't delete the go api client when not regenerating it

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* tenants api: add recovery group and token creation endpoints

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix startup

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix prometheus metrics

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix tests

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* lint

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix web stuff

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix migrations from stable

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix oauth source type import

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* Revert "fix oauth source type import"

This reverts commit d015fd0244.

* try with setting_changed signal

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* try with connection_created signal

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix scim tests

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix web after merge

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix enterprise settings

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* Revert "try with connection_created signal"

This reverts commit 764a999db8.

* Revert "try with setting_changed signal"

This reverts commit 32b40a3bbb.

* lib/expression: refactor expression compilation

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix django version

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix web after merge

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* relock poetry

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix reconcile

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* try running tenant save in a transaction

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* black

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* test: export postgres logs for debugging and use failfast

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* test: fix container name for logs

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* do not copy tenant data

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* Revert "try running tenant save in a transaction"

This reverts commit da6dec5a61.

* Revert "do not copy tenant data"

This reverts commit d07ae9423672f068b0bd8be409ff9b58452a80f2.

* Revert "Revert "do not copy tenant data""

This reverts commit 4bffb19704.

* fix clone with nodata

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* why not

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* remove failfast

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* remove postgres query logging

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* update reconcile logic to clearly differentiate between tenant and global

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix reconcile app decorator

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* enable django checks

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* actually nodata was unnecessary as we're cloning from template and not from public

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* pylint

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* update django-tenants with sequence fix

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* actually update

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix e2e tests

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* add tests for settings api

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* add tests for recovery api

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* lint

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* recovery tests: do them on a new tenant

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* web: fix system status being degraded when embedded outpost is disabled

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix recovery tests

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix tenants tests

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* lint-fix

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* lint-fix

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* update UI

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add management command to create a tenant

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add docs

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* release notes

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* more docs

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* checklist

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* self review

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* spelling

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* make web after upgrading

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* remove extra xlif file

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* prettier

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* Revert "add management command to create a tenant"

This reverts commit 39d13c0447.

* split api into smaller files, only import urls when tenants is enabled

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* rewite some things on the release notes

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* root: make sure install_id comes from public schema

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* require a license to use tenants

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* lint

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix tenants tests

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fix files migration

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* release notes: add warning about user sessions being invalidated

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* remove api disabled test, we can't test for it

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

---------

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
Co-authored-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Marc 'risson' Schmitt
2024-01-23 14:28:06 +01:00
committed by GitHub
parent 73ddaf48be
commit abc0c2d2a2
227 changed files with 6554 additions and 2481 deletions

View File

@ -96,8 +96,14 @@ dev-reset: dev-drop-db dev-create-db migrate ## Drop and restore the Authentik
######################### #########################
gen-build: ## Extract the schema from the database gen-build: ## Extract the schema from the database
AUTHENTIK_DEBUG=true ak make_blueprint_schema > blueprints/schema.json AUTHENTIK_DEBUG=true \
AUTHENTIK_DEBUG=true ak spectacular --file schema.yml AUTHENTIK_TENANTS__ENABLED=true \
AUTHENTIK_OUTPOSTS__DISABLE_EMBEDDED_OUTPOST=true \
ak make_blueprint_schema > blueprints/schema.json
AUTHENTIK_DEBUG=true \
AUTHENTIK_TENANTS__ENABLED=true \
AUTHENTIK_OUTPOSTS__DISABLE_EMBEDDED_OUTPOST=true \
ak spectacular --file schema.yml
gen-changelog: ## (Release) generate the changelog based from the commits since the last tag gen-changelog: ## (Release) generate the changelog based from the commits since the last tag
git log --pretty=format:" - %s" $(shell git describe --tags $(shell git rev-list --tags --max-count=1))...$(shell git branch --show-current) | sort > changelog.md git log --pretty=format:" - %s" $(shell git describe --tags $(shell git rev-list --tags --max-count=1))...$(shell git branch --show-current) | sort > changelog.md
@ -116,12 +122,16 @@ gen-diff: ## (Release) generate the changelog diff between the current schema a
sed -i 's/}/&#125;/g' diff.md sed -i 's/}/&#125;/g' diff.md
npx prettier --write diff.md npx prettier --write diff.md
gen-clean: gen-clean-ts: ## Remove generated API client for Typescript
rm -rf gen-go-api/
rm -rf gen-ts-api/ rm -rf gen-ts-api/
rm -rf web/node_modules/@goauthentik/api/ rm -rf web/node_modules/@goauthentik/api/
gen-client-ts: ## Build and install the authentik API for Typescript into the authentik UI Application gen-clean-go: ## Remove generated API client for Go
rm -rf gen-go-api/
gen-clean: gen-clean-ts gen-clean-go ## Remove generated API clients
gen-client-ts: gen-clean-ts ## Build and install the authentik API for Typescript into the authentik UI Application
docker run \ docker run \
--rm -v ${PWD}:/local \ --rm -v ${PWD}:/local \
--user ${UID}:${GID} \ --user ${UID}:${GID} \
@ -137,7 +147,7 @@ gen-client-ts: ## Build and install the authentik API for Typescript into the a
cd gen-ts-api && npm i cd gen-ts-api && npm i
\cp -rfv gen-ts-api/* web/node_modules/@goauthentik/api \cp -rfv gen-ts-api/* web/node_modules/@goauthentik/api
gen-client-go: ## Build and install the authentik API for Golang gen-client-go: gen-clean-go ## Build and install the authentik API for Golang
mkdir -p ./gen-go-api ./gen-go-api/templates mkdir -p ./gen-go-api ./gen-go-api/templates
wget https://raw.githubusercontent.com/goauthentik/client-go/main/config.yaml -O ./gen-go-api/config.yaml wget https://raw.githubusercontent.com/goauthentik/client-go/main/config.yaml -O ./gen-go-api/config.yaml
wget https://raw.githubusercontent.com/goauthentik/client-go/main/templates/README.mustache -O ./gen-go-api/templates/README.mustache wget https://raw.githubusercontent.com/goauthentik/client-go/main/templates/README.mustache -O ./gen-go-api/templates/README.mustache
@ -157,7 +167,7 @@ gen-client-go: ## Build and install the authentik API for Golang
gen-dev-config: ## Generate a local development config file gen-dev-config: ## Generate a local development config file
python -m scripts.generate_config python -m scripts.generate_config
gen: gen-build gen-clean gen-client-ts gen: gen-build gen-client-ts
######################### #########################
## Web ## Web

View File

@ -13,6 +13,7 @@ from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
from authentik.core.api.utils import PassiveSerializer from authentik.core.api.utils import PassiveSerializer
from authentik.lib.config import CONFIG
from authentik.lib.utils.reflection import get_env from authentik.lib.utils.reflection import get_env
from authentik.outposts.apps import MANAGED_OUTPOST from authentik.outposts.apps import MANAGED_OUTPOST
from authentik.outposts.models import Outpost from authentik.outposts.models import Outpost
@ -37,8 +38,9 @@ class SystemInfoSerializer(PassiveSerializer):
http_host = SerializerMethodField() http_host = SerializerMethodField()
http_is_secure = SerializerMethodField() http_is_secure = SerializerMethodField()
runtime = SerializerMethodField() runtime = SerializerMethodField()
tenant = SerializerMethodField() brand = SerializerMethodField()
server_time = SerializerMethodField() server_time = SerializerMethodField()
embedded_outpost_disabled = SerializerMethodField()
embedded_outpost_host = SerializerMethodField() embedded_outpost_host = SerializerMethodField()
def get_http_headers(self, request: Request) -> dict[str, str]: def get_http_headers(self, request: Request) -> dict[str, str]:
@ -69,14 +71,18 @@ class SystemInfoSerializer(PassiveSerializer):
"uname": " ".join(platform.uname()), "uname": " ".join(platform.uname()),
} }
def get_tenant(self, request: Request) -> str: def get_brand(self, request: Request) -> str:
"""Currently active tenant""" """Currently active brand"""
return str(request._request.tenant) return str(request._request.brand)
def get_server_time(self, request: Request) -> datetime: def get_server_time(self, request: Request) -> datetime:
"""Current server time""" """Current server time"""
return now() return now()
def get_embedded_outpost_disabled(self, request: Request) -> bool:
"""Whether the embedded outpost is disabled"""
return CONFIG.get_bool("outposts.disable_embedded_outpost", False)
def get_embedded_outpost_host(self, request: Request) -> str: def get_embedded_outpost_host(self, request: Request) -> str:
"""Get the FQDN configured on the embedded outpost""" """Get the FQDN configured on the embedded outpost"""
outposts = Outpost.objects.filter(managed=MANAGED_OUTPOST) outposts = Outpost.objects.filter(managed=MANAGED_OUTPOST)

View File

@ -15,6 +15,6 @@ class AuthentikAdminConfig(ManagedAppConfig):
verbose_name = "authentik Admin" verbose_name = "authentik Admin"
default = True default = True
def reconcile_load_admin_signals(self): def reconcile_global_load_admin_signals(self):
"""Load admin signals""" """Load admin signals"""
self.import_module("authentik.admin.signals") self.import_module("authentik.admin.signals")

View File

@ -3,7 +3,7 @@
{% load static %} {% load static %}
{% block title %} {% block title %}
API Browser - {{ tenant.branding_title }} API Browser - {{ brand.branding_title }}
{% endblock %} {% endblock %}
{% block head %} {% block head %}

View File

@ -72,7 +72,7 @@ class ConfigView(APIView):
for processor in get_context_processors(): for processor in get_context_processors():
if cap := processor.capability(): if cap := processor.capability():
caps.append(cap) caps.append(cap)
if CONFIG.get_bool("impersonation"): if self.request.tenant.impersonation:
caps.append(Capabilities.CAN_IMPERSONATE) caps.append(Capabilities.CAN_IMPERSONATE)
if settings.DEBUG: # pragma: no cover if settings.DEBUG: # pragma: no cover
caps.append(Capabilities.CAN_DEBUG) caps.append(Capabilities.CAN_DEBUG)

View File

@ -13,21 +13,23 @@ class ManagedAppConfig(AppConfig):
_logger: BoundLogger _logger: BoundLogger
RECONCILE_GLOBAL_PREFIX: str = "reconcile_global_"
RECONCILE_TENANT_PREFIX: str = "reconcile_tenant_"
def __init__(self, app_name: str, *args, **kwargs) -> None: def __init__(self, app_name: str, *args, **kwargs) -> None:
super().__init__(app_name, *args, **kwargs) super().__init__(app_name, *args, **kwargs)
self._logger = get_logger().bind(app_name=app_name) self._logger = get_logger().bind(app_name=app_name)
def ready(self) -> None: def ready(self) -> None:
self.reconcile() self.reconcile_global()
self.reconcile_tenant()
return super().ready() return super().ready()
def import_module(self, path: str): def import_module(self, path: str):
"""Load module""" """Load module"""
import_module(path) import_module(path)
def reconcile(self) -> None: def _reconcile(self, prefix: str) -> None:
"""reconcile ourselves"""
prefix = "reconcile_"
for meth_name in dir(self): for meth_name in dir(self):
meth = getattr(self, meth_name) meth = getattr(self, meth_name)
if not ismethod(meth): if not ismethod(meth):
@ -42,6 +44,29 @@ class ManagedAppConfig(AppConfig):
except (DatabaseError, ProgrammingError, InternalError) as exc: except (DatabaseError, ProgrammingError, InternalError) as exc:
self._logger.warning("Failed to run reconcile", name=name, exc=exc) self._logger.warning("Failed to run reconcile", name=name, exc=exc)
def reconcile_tenant(self) -> None:
"""reconcile ourselves for tenanted methods"""
from authentik.tenants.models import Tenant
try:
tenants = list(Tenant.objects.filter(ready=True))
except (DatabaseError, ProgrammingError, InternalError) as exc:
self._logger.debug("Failed to get tenants to run reconcile", exc=exc)
return
for tenant in tenants:
with tenant:
self._reconcile(self.RECONCILE_TENANT_PREFIX)
def reconcile_global(self) -> None:
"""
reconcile ourselves for global methods.
Used for signals, tasks, etc. Database queries should not be made in here.
"""
from django_tenants.utils import get_public_schema_name, schema_context
with schema_context(get_public_schema_name()):
self._reconcile(self.RECONCILE_GLOBAL_PREFIX)
class AuthentikBlueprintsConfig(ManagedAppConfig): class AuthentikBlueprintsConfig(ManagedAppConfig):
"""authentik Blueprints app""" """authentik Blueprints app"""
@ -51,11 +76,11 @@ class AuthentikBlueprintsConfig(ManagedAppConfig):
verbose_name = "authentik Blueprints" verbose_name = "authentik Blueprints"
default = True default = True
def reconcile_load_blueprints_v1_tasks(self): def reconcile_global_load_blueprints_v1_tasks(self):
"""Load v1 tasks""" """Load v1 tasks"""
self.import_module("authentik.blueprints.v1.tasks") self.import_module("authentik.blueprints.v1.tasks")
def reconcile_blueprints_discovery(self): def reconcile_tenant_blueprints_discovery(self):
"""Run blueprint discovery""" """Run blueprint discovery"""
from authentik.blueprints.v1.tasks import blueprints_discovery, clear_failed_blueprints from authentik.blueprints.v1.tasks import blueprints_discovery, clear_failed_blueprints

View File

@ -6,6 +6,7 @@ from structlog.stdlib import get_logger
from authentik.blueprints.models import BlueprintInstance from authentik.blueprints.models import BlueprintInstance
from authentik.blueprints.v1.importer import Importer from authentik.blueprints.v1.importer import Importer
from authentik.tenants.models import Tenant
LOGGER = get_logger() LOGGER = get_logger()
@ -16,6 +17,8 @@ class Command(BaseCommand):
@no_translations @no_translations
def handle(self, *args, **options): def handle(self, *args, **options):
"""Apply all blueprints in order, abort when one fails to import""" """Apply all blueprints in order, abort when one fails to import"""
for tenant in Tenant.objects.filter(ready=True):
with tenant:
for blueprint_path in options.get("blueprints", []): for blueprint_path in options.get("blueprints", []):
content = BlueprintInstance(path=blueprint_path).retrieve() content = BlueprintInstance(path=blueprint_path).retrieve()
importer = Importer.from_string(content) importer = Importer.from_string(content)

View File

@ -1,17 +1,18 @@
"""Export blueprint of current authentik install""" """Export blueprint of current authentik install"""
from django.core.management.base import BaseCommand, no_translations from django.core.management.base import no_translations
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.blueprints.v1.exporter import Exporter from authentik.blueprints.v1.exporter import Exporter
from authentik.tenants.management import TenantCommand
LOGGER = get_logger() LOGGER = get_logger()
class Command(BaseCommand): class Command(TenantCommand):
"""Export blueprint of current authentik install""" """Export blueprint of current authentik install"""
@no_translations @no_translations
def handle(self, *args, **options): def handle_per_tenant(self, *args, **options):
"""Export blueprint of current authentik install""" """Export blueprint of current authentik install"""
exporter = Exporter() exporter = Exporter()
self.stdout.write(exporter.export_to_string()) self.stdout.write(exporter.export_to_string())

View File

@ -14,7 +14,7 @@ from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_SYSTEM
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
def check_blueprint_v1_file(BlueprintInstance: type, path: Path): def check_blueprint_v1_file(BlueprintInstance: type, db_alias, path: Path):
"""Check if blueprint should be imported""" """Check if blueprint should be imported"""
from authentik.blueprints.models import BlueprintInstanceStatus from authentik.blueprints.models import BlueprintInstanceStatus
from authentik.blueprints.v1.common import BlueprintLoader, BlueprintMetadata from authentik.blueprints.v1.common import BlueprintLoader, BlueprintMetadata
@ -29,7 +29,9 @@ def check_blueprint_v1_file(BlueprintInstance: type, path: Path):
if version != 1: if version != 1:
return return
blueprint_file.seek(0) blueprint_file.seek(0)
instance: BlueprintInstance = BlueprintInstance.objects.filter(path=path).first() instance: BlueprintInstance = (
BlueprintInstance.objects.using(db_alias).filter(path=path).first()
)
rel_path = path.relative_to(Path(CONFIG.get("blueprints_dir"))) rel_path = path.relative_to(Path(CONFIG.get("blueprints_dir")))
meta = None meta = None
if metadata: if metadata:
@ -37,7 +39,7 @@ def check_blueprint_v1_file(BlueprintInstance: type, path: Path):
if meta.labels.get(LABEL_AUTHENTIK_INSTANTIATE, "").lower() == "false": if meta.labels.get(LABEL_AUTHENTIK_INSTANTIATE, "").lower() == "false":
return return
if not instance: if not instance:
instance = BlueprintInstance( BlueprintInstance.objects.using(db_alias).create(
name=meta.name if meta else str(rel_path), name=meta.name if meta else str(rel_path),
path=str(rel_path), path=str(rel_path),
context={}, context={},
@ -47,7 +49,6 @@ def check_blueprint_v1_file(BlueprintInstance: type, path: Path):
last_applied_hash="", last_applied_hash="",
metadata=metadata or {}, metadata=metadata or {},
) )
instance.save()
def migration_blueprint_import(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): def migration_blueprint_import(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
@ -56,7 +57,7 @@ def migration_blueprint_import(apps: Apps, schema_editor: BaseDatabaseSchemaEdit
db_alias = schema_editor.connection.alias db_alias = schema_editor.connection.alias
for file in glob(f"{CONFIG.get('blueprints_dir')}/**/*.yaml", recursive=True): for file in glob(f"{CONFIG.get('blueprints_dir')}/**/*.yaml", recursive=True):
check_blueprint_v1_file(BlueprintInstance, Path(file)) check_blueprint_v1_file(BlueprintInstance, db_alias, Path(file))
for blueprint in BlueprintInstance.objects.using(db_alias).all(): for blueprint in BlueprintInstance.objects.using(db_alias).all():
# If we already have flows (and we should always run before flow migrations) # If we already have flows (and we should always run before flow migrations)

View File

@ -38,7 +38,7 @@ def reconcile_app(app_name: str):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
config = apps.get_app_config(app_name) config = apps.get_app_config(app_name)
if isinstance(config, ManagedAppConfig): if isinstance(config, ManagedAppConfig):
config.reconcile() config.ready()
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper return wrapper

View File

@ -7,16 +7,16 @@ from django.test import TransactionTestCase
from authentik.blueprints.models import BlueprintInstance from authentik.blueprints.models import BlueprintInstance
from authentik.blueprints.tests import apply_blueprint from authentik.blueprints.tests import apply_blueprint
from authentik.blueprints.v1.importer import Importer from authentik.blueprints.v1.importer import Importer
from authentik.tenants.models import Tenant from authentik.brands.models import Brand
class TestPackaged(TransactionTestCase): class TestPackaged(TransactionTestCase):
"""Empty class, test methods are added dynamically""" """Empty class, test methods are added dynamically"""
@apply_blueprint("default/default-tenant.yaml") @apply_blueprint("default/default-brand.yaml")
def test_decorator_static(self): def test_decorator_static(self):
"""Test @apply_blueprint decorator""" """Test @apply_blueprint decorator"""
self.assertTrue(Tenant.objects.filter(domain="authentik-default").exists()) self.assertTrue(Brand.objects.filter(domain="authentik-default").exists())
def blueprint_tester(file_name: Path) -> Callable: def blueprint_tester(file_name: Path) -> Callable:

View File

@ -43,6 +43,7 @@ from authentik.lib.sentry import SentryIgnoredException
from authentik.outposts.models import OutpostServiceConnection from authentik.outposts.models import OutpostServiceConnection
from authentik.policies.models import Policy, PolicyBindingModel from authentik.policies.models import Policy, PolicyBindingModel
from authentik.providers.scim.models import SCIMGroup, SCIMUser from authentik.providers.scim.models import SCIMGroup, SCIMUser
from authentik.tenants.models import Tenant
# Context set when the serializer is created in a blueprint context # Context set when the serializer is created in a blueprint context
# Update website/developer-docs/blueprints/v1/models.md when used # Update website/developer-docs/blueprints/v1/models.md when used
@ -57,6 +58,7 @@ def excluded_models() -> list[type[Model]]:
from django.contrib.auth.models import User as DjangoUser from django.contrib.auth.models import User as DjangoUser
return ( return (
Tenant,
DjangoUser, DjangoUser,
DjangoGroup, DjangoGroup,
# Base classes # Base classes

View File

@ -38,6 +38,7 @@ from authentik.events.monitored_tasks import (
from authentik.events.utils import sanitize_dict from authentik.events.utils import sanitize_dict
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
from authentik.root.celery import CELERY_APP from authentik.root.celery import CELERY_APP
from authentik.tenants.models import Tenant
LOGGER = get_logger() LOGGER = get_logger()
_file_watcher_started = False _file_watcher_started = False
@ -78,6 +79,11 @@ class BlueprintEventHandler(FileSystemEventHandler):
root = Path(CONFIG.get("blueprints_dir")).absolute() root = Path(CONFIG.get("blueprints_dir")).absolute()
path = Path(event.src_path).absolute() path = Path(event.src_path).absolute()
rel_path = str(path.relative_to(root)) rel_path = str(path.relative_to(root))
for tenant in Tenant.objects.filter(ready=True):
with tenant:
root = Path(CONFIG.get("blueprints_dir")).absolute()
path = Path(event.src_path).absolute()
rel_path = str(path.relative_to(root))
if isinstance(event, FileCreatedEvent): if isinstance(event, FileCreatedEvent):
LOGGER.debug("new blueprint file created, starting discovery", path=rel_path) LOGGER.debug("new blueprint file created, starting discovery", path=rel_path)
blueprints_discovery.delay(rel_path) blueprints_discovery.delay(rel_path)

View File

View File

@ -1,4 +1,4 @@
"""Serializer for tenant models""" """Serializer for brands models"""
from typing import Any from typing import Any
from django.db import models from django.db import models
@ -14,10 +14,10 @@ from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import ModelViewSet from rest_framework.viewsets import ModelViewSet
from authentik.api.authorization import SecretKeyFilter from authentik.api.authorization import SecretKeyFilter
from authentik.brands.models import Brand
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import PassiveSerializer from authentik.core.api.utils import PassiveSerializer
from authentik.lib.config import CONFIG from authentik.tenants.utils import get_current_tenant
from authentik.tenants.models import Tenant
class FooterLinkSerializer(PassiveSerializer): class FooterLinkSerializer(PassiveSerializer):
@ -27,22 +27,22 @@ class FooterLinkSerializer(PassiveSerializer):
name = CharField(read_only=True) name = CharField(read_only=True)
class TenantSerializer(ModelSerializer): class BrandSerializer(ModelSerializer):
"""Tenant Serializer""" """Brand Serializer"""
def validate(self, attrs: dict[str, Any]) -> dict[str, Any]: def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
if attrs.get("default", False): if attrs.get("default", False):
tenants = Tenant.objects.filter(default=True) brands = Brand.objects.filter(default=True)
if self.instance: if self.instance:
tenants = tenants.exclude(pk=self.instance.pk) brands = brands.exclude(pk=self.instance.pk)
if tenants.exists(): if brands.exists():
raise ValidationError({"default": "Only a single Tenant can be set as default."}) raise ValidationError({"default": "Only a single brand can be set as default."})
return super().validate(attrs) return super().validate(attrs)
class Meta: class Meta:
model = Tenant model = Brand
fields = [ fields = [
"tenant_uuid", "brand_uuid",
"domain", "domain",
"default", "default",
"branding_title", "branding_title",
@ -54,7 +54,6 @@ class TenantSerializer(ModelSerializer):
"flow_unenrollment", "flow_unenrollment",
"flow_user_settings", "flow_user_settings",
"flow_device_code", "flow_device_code",
"event_retention",
"web_certificate", "web_certificate",
"attributes", "attributes",
] ]
@ -68,8 +67,13 @@ class Themes(models.TextChoices):
DARK = "dark" DARK = "dark"
class CurrentTenantSerializer(PassiveSerializer): def get_default_ui_footer_links():
"""Partial tenant information for styling""" """Get default UI footer links based on current tenant settings"""
return get_current_tenant().footer_links
class CurrentBrandSerializer(PassiveSerializer):
"""Partial brand information for styling"""
matched_domain = CharField(source="domain") matched_domain = CharField(source="domain")
branding_title = CharField() branding_title = CharField()
@ -78,7 +82,7 @@ class CurrentTenantSerializer(PassiveSerializer):
ui_footer_links = ListField( ui_footer_links = ListField(
child=FooterLinkSerializer(), child=FooterLinkSerializer(),
read_only=True, read_only=True,
default=CONFIG.get("footer_links", []), default=get_default_ui_footer_links,
) )
ui_theme = ChoiceField( ui_theme = ChoiceField(
choices=Themes.choices, choices=Themes.choices,
@ -97,18 +101,18 @@ class CurrentTenantSerializer(PassiveSerializer):
default_locale = CharField(read_only=True) default_locale = CharField(read_only=True)
class TenantViewSet(UsedByMixin, ModelViewSet): class BrandViewSet(UsedByMixin, ModelViewSet):
"""Tenant Viewset""" """Brand Viewset"""
queryset = Tenant.objects.all() queryset = Brand.objects.all()
serializer_class = TenantSerializer serializer_class = BrandSerializer
search_fields = [ search_fields = [
"domain", "domain",
"branding_title", "branding_title",
"web_certificate__name", "web_certificate__name",
] ]
filterset_fields = [ filterset_fields = [
"tenant_uuid", "brand_uuid",
"domain", "domain",
"default", "default",
"branding_title", "branding_title",
@ -120,7 +124,6 @@ class TenantViewSet(UsedByMixin, ModelViewSet):
"flow_unenrollment", "flow_unenrollment",
"flow_user_settings", "flow_user_settings",
"flow_device_code", "flow_device_code",
"event_retention",
"web_certificate", "web_certificate",
] ]
ordering = ["domain"] ordering = ["domain"]
@ -128,10 +131,10 @@ class TenantViewSet(UsedByMixin, ModelViewSet):
filter_backends = [SecretKeyFilter, OrderingFilter, SearchFilter] filter_backends = [SecretKeyFilter, OrderingFilter, SearchFilter]
@extend_schema( @extend_schema(
responses=CurrentTenantSerializer(many=False), responses=CurrentBrandSerializer(many=False),
) )
@action(methods=["GET"], detail=False, permission_classes=[AllowAny]) @action(methods=["GET"], detail=False, permission_classes=[AllowAny])
def current(self, request: Request) -> Response: def current(self, request: Request) -> Response:
"""Get current tenant""" """Get current brand"""
tenant: Tenant = request._request.tenant brand: Brand = request._request.brand
return Response(CurrentTenantSerializer(tenant).data) return Response(CurrentBrandSerializer(brand).data)

10
authentik/brands/apps.py Normal file
View File

@ -0,0 +1,10 @@
"""authentik brands app"""
from django.apps import AppConfig
class AuthentikBrandsConfig(AppConfig):
"""authentik Brand app"""
name = "authentik.brands"
label = "authentik_brands"
verbose_name = "authentik Brands"

View File

@ -0,0 +1,26 @@
"""Inject brand into current request"""
from typing import Callable
from django.http.request import HttpRequest
from django.http.response import HttpResponse
from django.utils.translation import activate
from authentik.brands.utils import get_brand_for_request
class BrandMiddleware:
"""Add current brand to http request"""
get_response: Callable[[HttpRequest], HttpResponse]
def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]):
self.get_response = get_response
def __call__(self, request: HttpRequest) -> HttpResponse:
if not hasattr(request, "brand"):
brand = get_brand_for_request(request)
setattr(request, "brand", brand)
locale = brand.default_locale
if locale != "":
activate(locale)
return self.get_response(request)

View File

@ -10,11 +10,11 @@ import authentik.lib.utils.time
class Migration(migrations.Migration): class Migration(migrations.Migration):
replaces = [ replaces = [
("authentik_tenants", "0001_initial"), ("authentik_brands", "0001_initial"),
("authentik_tenants", "0002_default"), ("authentik_brands", "0002_default"),
("authentik_tenants", "0003_tenant_branding_favicon"), ("authentik_brands", "0003_tenant_branding_favicon"),
("authentik_tenants", "0004_tenant_event_retention"), ("authentik_brands", "0004_tenant_event_retention"),
("authentik_tenants", "0005_tenant_web_certificate"), ("authentik_brands", "0005_tenant_web_certificate"),
] ]
initial = True initial = True
@ -25,7 +25,7 @@ class Migration(migrations.Migration):
operations = [ operations = [
migrations.CreateModel( migrations.CreateModel(
name="Tenant", name="Brand",
fields=[ fields=[
( (
"tenant_uuid", "tenant_uuid",
@ -37,7 +37,7 @@ class Migration(migrations.Migration):
"domain", "domain",
models.TextField( models.TextField(
help_text=( help_text=(
"Domain that activates this tenant. Can be a superset, i.e. `a.b` for" "Domain that activates this brand. Can be a superset, i.e. `a.b` for"
" `aa.b` and `ba.b`" " `aa.b` and `ba.b`"
) )
), ),
@ -53,7 +53,7 @@ class Migration(migrations.Migration):
models.ForeignKey( models.ForeignKey(
null=True, null=True,
on_delete=django.db.models.deletion.SET_NULL, on_delete=django.db.models.deletion.SET_NULL,
related_name="tenant_authentication", related_name="brand_authentication",
to="authentik_flows.flow", to="authentik_flows.flow",
), ),
), ),
@ -62,7 +62,7 @@ class Migration(migrations.Migration):
models.ForeignKey( models.ForeignKey(
null=True, null=True,
on_delete=django.db.models.deletion.SET_NULL, on_delete=django.db.models.deletion.SET_NULL,
related_name="tenant_invalidation", related_name="brand_invalidation",
to="authentik_flows.flow", to="authentik_flows.flow",
), ),
), ),
@ -71,7 +71,7 @@ class Migration(migrations.Migration):
models.ForeignKey( models.ForeignKey(
null=True, null=True,
on_delete=django.db.models.deletion.SET_NULL, on_delete=django.db.models.deletion.SET_NULL,
related_name="tenant_recovery", related_name="brand_recovery",
to="authentik_flows.flow", to="authentik_flows.flow",
), ),
), ),
@ -80,23 +80,23 @@ class Migration(migrations.Migration):
models.ForeignKey( models.ForeignKey(
null=True, null=True,
on_delete=django.db.models.deletion.SET_NULL, on_delete=django.db.models.deletion.SET_NULL,
related_name="tenant_unenrollment", related_name="brand_unenrollment",
to="authentik_flows.flow", to="authentik_flows.flow",
), ),
), ),
], ],
options={ options={
"verbose_name": "Tenant", "verbose_name": "Brand",
"verbose_name_plural": "Tenants", "verbose_name_plural": "Brands",
}, },
), ),
migrations.AddField( migrations.AddField(
model_name="tenant", model_name="brand",
name="branding_favicon", name="branding_favicon",
field=models.TextField(default="/static/dist/assets/icons/icon.png"), field=models.TextField(default="/static/dist/assets/icons/icon.png"),
), ),
migrations.AddField( migrations.AddField(
model_name="tenant", model_name="brand",
name="event_retention", name="event_retention",
field=models.TextField( field=models.TextField(
default="days=365", default="days=365",
@ -108,7 +108,7 @@ class Migration(migrations.Migration):
), ),
), ),
migrations.AddField( migrations.AddField(
model_name="tenant", model_name="brand",
name="web_certificate", name="web_certificate",
field=models.ForeignKey( field=models.ForeignKey(
default=None, default=None,

View File

@ -8,17 +8,17 @@ class Migration(migrations.Migration):
dependencies = [ dependencies = [
("authentik_stages_prompt", "0007_prompt_placeholder_expression"), ("authentik_stages_prompt", "0007_prompt_placeholder_expression"),
("authentik_flows", "0021_auto_20211227_2103"), ("authentik_flows", "0021_auto_20211227_2103"),
("authentik_tenants", "0001_squashed_0005_tenant_web_certificate"), ("authentik_brands", "0001_squashed_0005_tenant_web_certificate"),
] ]
operations = [ operations = [
migrations.AddField( migrations.AddField(
model_name="tenant", model_name="brand",
name="flow_user_settings", name="flow_user_settings",
field=models.ForeignKey( field=models.ForeignKey(
null=True, null=True,
on_delete=django.db.models.deletion.SET_NULL, on_delete=django.db.models.deletion.SET_NULL,
related_name="tenant_user_settings", related_name="brand_user_settings",
to="authentik_flows.flow", to="authentik_flows.flow",
), ),
), ),

View File

@ -5,12 +5,12 @@ from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
("authentik_tenants", "0002_tenant_flow_user_settings"), ("authentik_brands", "0002_tenant_flow_user_settings"),
] ]
operations = [ operations = [
migrations.AddField( migrations.AddField(
model_name="tenant", model_name="brand",
name="attributes", name="attributes",
field=models.JSONField(blank=True, default=dict), field=models.JSONField(blank=True, default=dict),
), ),

View File

@ -7,17 +7,17 @@ from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
("authentik_flows", "0023_flow_denied_action"), ("authentik_flows", "0023_flow_denied_action"),
("authentik_tenants", "0003_tenant_attributes"), ("authentik_brands", "0003_tenant_attributes"),
] ]
operations = [ operations = [
migrations.AddField( migrations.AddField(
model_name="tenant", model_name="brand",
name="flow_device_code", name="flow_device_code",
field=models.ForeignKey( field=models.ForeignKey(
null=True, null=True,
on_delete=django.db.models.deletion.SET_NULL, on_delete=django.db.models.deletion.SET_NULL,
related_name="tenant_device_code", related_name="brand_device_code",
to="authentik_flows.flow", to="authentik_flows.flow",
), ),
), ),

View File

@ -0,0 +1,21 @@
# Generated by Django 4.2.7 on 2023-12-12 06:41
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("authentik_brands", "0004_tenant_flow_device_code"),
]
operations = [
migrations.RenameField(
model_name="brand",
old_name="tenant_uuid",
new_name="brand_uuid",
),
migrations.RemoveField(
model_name="brand",
name="event_retention",
),
]

View File

View File

@ -0,0 +1,85 @@
"""brand models"""
from uuid import uuid4
from django.db import models
from django.utils.translation import gettext_lazy as _
from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger
from authentik.crypto.models import CertificateKeyPair
from authentik.flows.models import Flow
from authentik.lib.models import SerializerModel
LOGGER = get_logger()
class Brand(SerializerModel):
"""Single brand"""
brand_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
domain = models.TextField(
help_text=_(
"Domain that activates this brand. Can be a superset, i.e. `a.b` for `aa.b` and `ba.b`"
)
)
default = models.BooleanField(
default=False,
)
branding_title = models.TextField(default="authentik")
branding_logo = models.TextField(default="/static/dist/assets/icons/icon_left_brand.svg")
branding_favicon = models.TextField(default="/static/dist/assets/icons/icon.png")
flow_authentication = models.ForeignKey(
Flow, null=True, on_delete=models.SET_NULL, related_name="brand_authentication"
)
flow_invalidation = models.ForeignKey(
Flow, null=True, on_delete=models.SET_NULL, related_name="brand_invalidation"
)
flow_recovery = models.ForeignKey(
Flow, null=True, on_delete=models.SET_NULL, related_name="brand_recovery"
)
flow_unenrollment = models.ForeignKey(
Flow, null=True, on_delete=models.SET_NULL, related_name="brand_unenrollment"
)
flow_user_settings = models.ForeignKey(
Flow, null=True, on_delete=models.SET_NULL, related_name="brand_user_settings"
)
flow_device_code = models.ForeignKey(
Flow, null=True, on_delete=models.SET_NULL, related_name="brand_device_code"
)
web_certificate = models.ForeignKey(
CertificateKeyPair,
null=True,
default=None,
on_delete=models.SET_DEFAULT,
help_text=_("Web Certificate used by the authentik Core webserver."),
)
attributes = models.JSONField(default=dict, blank=True)
@property
def serializer(self) -> Serializer:
from authentik.brands.api import BrandSerializer
return BrandSerializer
@property
def default_locale(self) -> str:
"""Get default locale"""
try:
return self.attributes.get("settings", {}).get("locale", "")
# pylint: disable=broad-except
except Exception as exc:
LOGGER.warning("Failed to get default locale", exc=exc)
return ""
def __str__(self) -> str:
if self.default:
return "Default brand"
return f"Brand {self.domain}"
class Meta:
verbose_name = _("Brand")
verbose_name_plural = _("Brands")

76
authentik/brands/tests.py Normal file
View File

@ -0,0 +1,76 @@
"""Test brands"""
from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.brands.api import Themes
from authentik.brands.models import Brand
from authentik.core.tests.utils import create_test_admin_user, create_test_brand
class TestBrands(APITestCase):
"""Test brands"""
def test_current_brand(self):
"""Test Current brand API"""
brand = create_test_brand()
self.assertJSONEqual(
self.client.get(reverse("authentik_api:brand-current")).content.decode(),
{
"branding_logo": "/static/dist/assets/icons/icon_left_brand.svg",
"branding_favicon": "/static/dist/assets/icons/icon.png",
"branding_title": "authentik",
"matched_domain": brand.domain,
"ui_footer_links": [],
"ui_theme": Themes.AUTOMATIC,
"default_locale": "",
},
)
def test_brand_subdomain(self):
"""Test Current brand API"""
Brand.objects.all().delete()
Brand.objects.create(domain="bar.baz", branding_title="custom")
self.assertJSONEqual(
self.client.get(
reverse("authentik_api:brand-current"), HTTP_HOST="foo.bar.baz"
).content.decode(),
{
"branding_logo": "/static/dist/assets/icons/icon_left_brand.svg",
"branding_favicon": "/static/dist/assets/icons/icon.png",
"branding_title": "custom",
"matched_domain": "bar.baz",
"ui_footer_links": [],
"ui_theme": Themes.AUTOMATIC,
"default_locale": "",
},
)
def test_fallback(self):
"""Test fallback brand"""
Brand.objects.all().delete()
self.assertJSONEqual(
self.client.get(reverse("authentik_api:brand-current")).content.decode(),
{
"branding_logo": "/static/dist/assets/icons/icon_left_brand.svg",
"branding_favicon": "/static/dist/assets/icons/icon.png",
"branding_title": "authentik",
"matched_domain": "fallback",
"ui_footer_links": [],
"ui_theme": Themes.AUTOMATIC,
"default_locale": "",
},
)
def test_create_default_multiple(self):
"""Test attempted creation of multiple default brands"""
Brand.objects.create(
domain="foo",
default=True,
branding_title="custom",
)
user = create_test_admin_user()
self.client.force_login(user)
response = self.client.post(
reverse("authentik_api:brand-list"), data={"domain": "bar", "default": True}
)
self.assertEqual(response.status_code, 400)

6
authentik/brands/urls.py Normal file
View File

@ -0,0 +1,6 @@
"""API URLs"""
from authentik.brands.api import BrandViewSet
api_urlpatterns = [
("core/brands", BrandViewSet),
]

42
authentik/brands/utils.py Normal file
View File

@ -0,0 +1,42 @@
"""Brand utilities"""
from typing import Any
from django.db.models import F, Q
from django.db.models import Value as V
from django.http.request import HttpRequest
from sentry_sdk.hub import Hub
from authentik import get_full_version
from authentik.brands.models import Brand
from authentik.tenants.utils import get_current_tenant
_q_default = Q(default=True)
DEFAULT_BRAND = Brand(domain="fallback")
def get_brand_for_request(request: HttpRequest) -> Brand:
"""Get brand object for current request"""
db_brands = (
Brand.objects.annotate(host_domain=V(request.get_host()))
.filter(Q(host_domain__iendswith=F("domain")) | _q_default)
.order_by("default")
)
brands = list(db_brands.all())
if len(brands) < 1:
return DEFAULT_BRAND
return brands[0]
def context_processor(request: HttpRequest) -> dict[str, Any]:
"""Context Processor that injects brand object into every template"""
brand = getattr(request, "brand", DEFAULT_BRAND)
trace = ""
span = Hub.current.scope.span
if span:
trace = span.to_traceparent()
return {
"brand": brand,
"footer_links": get_current_tenant().footer_links,
"sentry_trace": trace,
"version": get_full_version(),
}

View File

@ -50,6 +50,7 @@ from structlog.stdlib import get_logger
from authentik.admin.api.metrics import CoordinateSerializer from authentik.admin.api.metrics import CoordinateSerializer
from authentik.api.decorators import permission_required from authentik.api.decorators import permission_required
from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT
from authentik.brands.models import Brand
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import JSONDictField, LinkSerializer, PassiveSerializer from authentik.core.api.utils import JSONDictField, LinkSerializer, PassiveSerializer
from authentik.core.middleware import ( from authentik.core.middleware import (
@ -71,11 +72,9 @@ from authentik.flows.exceptions import FlowNonApplicableException
from authentik.flows.models import FlowToken from authentik.flows.models import FlowToken
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner
from authentik.flows.views.executor import QS_KEY_TOKEN from authentik.flows.views.executor import QS_KEY_TOKEN
from authentik.lib.config import CONFIG
from authentik.stages.email.models import EmailStage from authentik.stages.email.models import EmailStage
from authentik.stages.email.tasks import send_mails from authentik.stages.email.tasks import send_mails
from authentik.stages.email.utils import TemplateEmailMessage from authentik.stages.email.utils import TemplateEmailMessage
from authentik.tenants.models import Tenant
LOGGER = get_logger() LOGGER = get_logger()
@ -221,7 +220,7 @@ class UserSelfSerializer(ModelSerializer):
} }
def get_settings(self, user: User) -> dict[str, Any]: def get_settings(self, user: User) -> dict[str, Any]:
"""Get user settings with tenant and group settings applied""" """Get user settings with brand and group settings applied"""
return user.group_attributes(self._context["request"]).get("settings", {}) return user.group_attributes(self._context["request"]).get("settings", {})
def get_system_permissions(self, user: User) -> list[str]: def get_system_permissions(self, user: User) -> list[str]:
@ -382,11 +381,11 @@ class UserViewSet(UsedByMixin, ModelViewSet):
return User.objects.all().exclude(pk=get_anonymous_user().pk) return User.objects.all().exclude(pk=get_anonymous_user().pk)
def _create_recovery_link(self) -> tuple[Optional[str], Optional[Token]]: def _create_recovery_link(self) -> tuple[Optional[str], Optional[Token]]:
"""Create a recovery link (when the current tenant has a recovery flow set), """Create a recovery link (when the current brand has a recovery flow set),
that can either be shown to an admin or sent to the user directly""" that can either be shown to an admin or sent to the user directly"""
tenant: Tenant = self.request._request.tenant brand: Brand = self.request._request.brand
# Check that there is a recovery flow, if not return an error # Check that there is a recovery flow, if not return an error
flow = tenant.flow_recovery flow = brand.flow_recovery
if not flow: if not flow:
LOGGER.debug("No recovery flow set") LOGGER.debug("No recovery flow set")
return None, None return None, None
@ -618,7 +617,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
@action(detail=True, methods=["POST"]) @action(detail=True, methods=["POST"])
def impersonate(self, request: Request, pk: int) -> Response: def impersonate(self, request: Request, pk: int) -> Response:
"""Impersonate a user""" """Impersonate a user"""
if not CONFIG.get_bool("impersonation"): if not request.tenant.impersonation:
LOGGER.debug("User attempted to impersonate", user=request.user) LOGGER.debug("User attempted to impersonate", user=request.user)
return Response(status=401) return Response(status=401)
if not request.user.has_perm("impersonate"): if not request.user.has_perm("impersonate"):

View File

@ -13,18 +13,18 @@ class AuthentikCoreConfig(ManagedAppConfig):
mountpoint = "" mountpoint = ""
default = True default = True
def reconcile_load_core_signals(self): def reconcile_global_load_core_signals(self):
"""Load core signals""" """Load core signals"""
self.import_module("authentik.core.signals") self.import_module("authentik.core.signals")
def reconcile_debug_worker_hook(self): def reconcile_global_debug_worker_hook(self):
"""Dispatch startup tasks inline when debugging""" """Dispatch startup tasks inline when debugging"""
if settings.DEBUG: if settings.DEBUG:
from authentik.root.celery import worker_ready_hook from authentik.root.celery import worker_ready_hook
worker_ready_hook() worker_ready_hook()
def reconcile_source_inbuilt(self): def reconcile_tenant_source_inbuilt(self):
"""Reconcile inbuilt source""" """Reconcile inbuilt source"""
from authentik.core.models import Source from authentik.core.models import Source

View File

@ -1,13 +1,20 @@
"""Run bootstrap tasks""" """Run bootstrap tasks"""
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from django_tenants.utils import get_public_schema_name
from authentik.root.celery import _get_startup_tasks from authentik.root.celery import _get_startup_tasks_all_tenants, _get_startup_tasks_default_tenant
from authentik.tenants.models import Tenant
class Command(BaseCommand): class Command(BaseCommand):
"""Run bootstrap tasks to ensure certain objects are created""" """Run bootstrap tasks to ensure certain objects are created"""
def handle(self, **options): def handle(self, **options):
tasks = _get_startup_tasks() for task in _get_startup_tasks_default_tenant():
for task in tasks: with Tenant.objects.get(schema_name=get_public_schema_name()):
task()
for task in _get_startup_tasks_all_tenants():
for tenant in Tenant.objects.filter(ready=True):
with tenant:
task() task()

View File

@ -4,6 +4,8 @@ from django.contrib.auth.management import create_permissions
from django.core.management.base import BaseCommand, no_translations from django.core.management.base import BaseCommand, no_translations
from guardian.management import create_anonymous_user from guardian.management import create_anonymous_user
from authentik.tenants.models import Tenant
class Command(BaseCommand): class Command(BaseCommand):
"""Repair missing permissions""" """Repair missing permissions"""
@ -11,6 +13,8 @@ class Command(BaseCommand):
@no_translations @no_translations
def handle(self, *args, **options): def handle(self, *args, **options):
"""Check permissions for all apps""" """Check permissions for all apps"""
for tenant in Tenant.objects.filter(ready=True):
with tenant:
for app in apps.get_app_configs(): for app in apps.get_app_configs():
self.stdout.write(f"Checking app {app.name} ({app.label})\n") self.stdout.write(f"Checking app {app.name} ({app.label})\n")
create_permissions(app, verbosity=0) create_permissions(app, verbosity=0)

View File

@ -201,8 +201,8 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser):
"""Get a dictionary containing the attributes from all groups the user belongs to, """Get a dictionary containing the attributes from all groups the user belongs to,
including the users attributes""" including the users attributes"""
final_attributes = {} final_attributes = {}
if request and hasattr(request, "tenant"): if request and hasattr(request, "brand"):
always_merger.merge(final_attributes, request.tenant.attributes) always_merger.merge(final_attributes, request.brand.attributes)
for group in self.all_groups().order_by("name"): for group in self.all_groups().order_by("name"):
always_merger.merge(final_attributes, group.attributes) always_merger.merge(final_attributes, group.attributes)
always_merger.merge(final_attributes, self.attributes) always_merger.merge(final_attributes, self.attributes)
@ -261,7 +261,7 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser):
except Exception as exc: except Exception as exc:
LOGGER.warning("Failed to get default locale", exc=exc) LOGGER.warning("Failed to get default locale", exc=exc)
if request: if request:
return request.tenant.locale return request.brand.locale
return "" return ""
@property @property

View File

@ -5,7 +5,7 @@
window.authentik = { window.authentik = {
locale: "{{ LANGUAGE_CODE }}", locale: "{{ LANGUAGE_CODE }}",
config: JSON.parse('{{ config_json|escapejs }}'), config: JSON.parse('{{ config_json|escapejs }}'),
tenant: JSON.parse('{{ tenant_json|escapejs }}'), brand: JSON.parse('{{ brand_json|escapejs }}'),
versionFamily: "{{ version_family }}", versionFamily: "{{ version_family }}",
versionSubdomain: "{{ version_subdomain }}", versionSubdomain: "{{ version_subdomain }}",
build: "{{ build }}", build: "{{ build }}",

View File

@ -7,9 +7,9 @@
<head> <head>
<meta charset="UTF-8"> <meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1"> <meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1">
<title>{% block title %}{% trans title|default:tenant.branding_title %}{% endblock %}</title> <title>{% block title %}{% trans title|default:brand.branding_title %}{% endblock %}</title>
<link rel="icon" href="{{ tenant.branding_favicon }}"> <link rel="icon" href="{{ brand.branding_favicon }}">
<link rel="shortcut icon" href="{{ tenant.branding_favicon }}"> <link rel="shortcut icon" href="{{ brand.branding_favicon }}">
{% block head_before %} {% block head_before %}
{% endblock %} {% endblock %}
<link rel="stylesheet" type="text/css" href="{% static 'dist/authentik.css' %}"> <link rel="stylesheet" type="text/css" href="{% static 'dist/authentik.css' %}">

View File

@ -4,7 +4,7 @@
{% load i18n %} {% load i18n %}
{% block title %} {% block title %}
{% trans 'End session' %} - {{ tenant.branding_title }} {% trans 'End session' %} - {{ brand.branding_title }}
{% endblock %} {% endblock %}
{% block card_title %} {% block card_title %}
@ -16,7 +16,7 @@ You've logged out of {{ application }}.
{% block card %} {% block card %}
<form method="POST" class="pf-c-form"> <form method="POST" class="pf-c-form">
<p> <p>
{% blocktrans with application=application.name branding_title=tenant.branding_title %} {% blocktrans with application=application.name branding_title=brand.branding_title %}
You've logged out of {{ application }}. You can go back to the overview to launch another application, or log out of your {{ branding_title }} account. You've logged out of {{ application }}. You can go back to the overview to launch another application, or log out of your {{ branding_title }} account.
{% endblocktrans %} {% endblocktrans %}
</p> </p>
@ -26,7 +26,7 @@ You've logged out of {{ application }}.
</a> </a>
<a id="logout" href="{% url 'authentik_flows:default-invalidation' %}" class="pf-c-button pf-m-secondary"> <a id="logout" href="{% url 'authentik_flows:default-invalidation' %}" class="pf-c-button pf-m-secondary">
{% blocktrans with branding_title=tenant.branding_title %} {% blocktrans with branding_title=brand.branding_title %}
Log out of {{ branding_title }} Log out of {{ branding_title }}
{% endblocktrans %} {% endblocktrans %}
</a> </a>

View File

@ -4,7 +4,7 @@
{% load i18n %} {% load i18n %}
{% block title %} {% block title %}
{{ tenant.branding_title }} {{ brand.branding_title }}
{% endblock %} {% endblock %}
{% block card_title %} {% block card_title %}

View File

@ -50,7 +50,7 @@
<div class="ak-login-container"> <div class="ak-login-container">
<main class="pf-c-login__main"> <main class="pf-c-login__main">
<div class="pf-c-login__main-header pf-c-brand ak-brand"> <div class="pf-c-login__main-header pf-c-brand ak-brand">
<img src="{{ tenant.branding_logo }}" alt="authentik Logo" /> <img src="{{ brand.branding_logo }}" alt="authentik Logo" />
</div> </div>
<header class="pf-c-login__main-header"> <header class="pf-c-login__main-header">
<h1 class="pf-c-title pf-m-3xl"> <h1 class="pf-c-title pf-m-3xl">

View File

@ -3,10 +3,10 @@ from unittest.mock import MagicMock, patch
from django.urls import reverse from django.urls import reverse
from authentik.brands.models import Brand
from authentik.core.models import Application from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_flow, create_test_tenant from authentik.core.tests.utils import create_test_admin_user, create_test_brand, create_test_flow
from authentik.flows.tests import FlowTestCase from authentik.flows.tests import FlowTestCase
from authentik.tenants.models import Tenant
class TestApplicationsViews(FlowTestCase): class TestApplicationsViews(FlowTestCase):
@ -21,9 +21,9 @@ class TestApplicationsViews(FlowTestCase):
def test_check_redirect(self): def test_check_redirect(self):
"""Test redirect""" """Test redirect"""
empty_flow = create_test_flow() empty_flow = create_test_flow()
tenant: Tenant = create_test_tenant() brand: Brand = create_test_brand()
tenant.flow_authentication = empty_flow brand.flow_authentication = empty_flow
tenant.save() brand.save()
response = self.client.get( response = self.client.get(
reverse( reverse(
"authentik_core:application-launch", "authentik_core:application-launch",
@ -45,9 +45,9 @@ class TestApplicationsViews(FlowTestCase):
"""Test redirect""" """Test redirect"""
self.client.force_login(self.user) self.client.force_login(self.user)
empty_flow = create_test_flow() empty_flow = create_test_flow()
tenant: Tenant = create_test_tenant() brand: Brand = create_test_brand()
tenant.flow_authentication = empty_flow brand.flow_authentication = empty_flow
tenant.save() brand.save()
response = self.client.get( response = self.client.get(
reverse( reverse(
"authentik_core:application-launch", "authentik_core:application-launch",

View File

@ -6,7 +6,7 @@ from rest_framework.test import APITestCase
from authentik.core.models import User from authentik.core.models import User
from authentik.core.tests.utils import create_test_admin_user from authentik.core.tests.utils import create_test_admin_user
from authentik.lib.config import CONFIG from authentik.tenants.utils import get_current_tenant
class TestImpersonation(APITestCase): class TestImpersonation(APITestCase):
@ -56,9 +56,11 @@ class TestImpersonation(APITestCase):
response_body = loads(response.content.decode()) response_body = loads(response.content.decode())
self.assertEqual(response_body["user"]["username"], self.other_user.username) self.assertEqual(response_body["user"]["username"], self.other_user.username)
@CONFIG.patch("impersonation", False)
def test_impersonate_disabled(self): def test_impersonate_disabled(self):
"""test impersonation that is disabled""" """test impersonation that is disabled"""
tenant = get_current_tenant()
tenant.impersonation = False
tenant.save()
self.client.force_login(self.user) self.client.force_login(self.user)
response = self.client.post( response = self.client.post(

View File

@ -7,6 +7,7 @@ from django.core.cache import cache
from django.urls.base import reverse from django.urls.base import reverse
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from authentik.brands.models import Brand
from authentik.core.models import ( from authentik.core.models import (
USER_ATTRIBUTE_TOKEN_EXPIRING, USER_ATTRIBUTE_TOKEN_EXPIRING,
AuthenticatedSession, AuthenticatedSession,
@ -14,11 +15,10 @@ from authentik.core.models import (
User, User,
UserTypes, UserTypes,
) )
from authentik.core.tests.utils import create_test_admin_user, create_test_flow, create_test_tenant from authentik.core.tests.utils import create_test_admin_user, create_test_brand, create_test_flow
from authentik.flows.models import FlowDesignation from authentik.flows.models import FlowDesignation
from authentik.lib.generators import generate_id, generate_key from authentik.lib.generators import generate_id, generate_key
from authentik.stages.email.models import EmailStage from authentik.stages.email.models import EmailStage
from authentik.tenants.models import Tenant
class TestUsersAPI(APITestCase): class TestUsersAPI(APITestCase):
@ -80,9 +80,9 @@ class TestUsersAPI(APITestCase):
def test_recovery(self): def test_recovery(self):
"""Test user recovery link (no recovery flow set)""" """Test user recovery link (no recovery flow set)"""
flow = create_test_flow(FlowDesignation.RECOVERY) flow = create_test_flow(FlowDesignation.RECOVERY)
tenant: Tenant = create_test_tenant() brand: Brand = create_test_brand()
tenant.flow_recovery = flow brand.flow_recovery = flow
tenant.save() brand.save()
self.client.force_login(self.admin) self.client.force_login(self.admin)
response = self.client.get( response = self.client.get(
reverse("authentik_api:user-recovery", kwargs={"pk": self.user.pk}) reverse("authentik_api:user-recovery", kwargs={"pk": self.user.pk})
@ -108,9 +108,9 @@ class TestUsersAPI(APITestCase):
self.user.email = "foo@bar.baz" self.user.email = "foo@bar.baz"
self.user.save() self.user.save()
flow = create_test_flow(designation=FlowDesignation.RECOVERY) flow = create_test_flow(designation=FlowDesignation.RECOVERY)
tenant: Tenant = create_test_tenant() brand: Brand = create_test_brand()
tenant.flow_recovery = flow brand.flow_recovery = flow
tenant.save() brand.save()
self.client.force_login(self.admin) self.client.force_login(self.admin)
response = self.client.get( response = self.client.get(
reverse("authentik_api:user-recovery-email", kwargs={"pk": self.user.pk}) reverse("authentik_api:user-recovery-email", kwargs={"pk": self.user.pk})
@ -122,9 +122,9 @@ class TestUsersAPI(APITestCase):
self.user.email = "foo@bar.baz" self.user.email = "foo@bar.baz"
self.user.save() self.user.save()
flow = create_test_flow(FlowDesignation.RECOVERY) flow = create_test_flow(FlowDesignation.RECOVERY)
tenant: Tenant = create_test_tenant() brand: Brand = create_test_brand()
tenant.flow_recovery = flow brand.flow_recovery = flow
tenant.save() brand.save()
stage = EmailStage.objects.create(name="email") stage = EmailStage.objects.create(name="email")

View File

@ -8,6 +8,7 @@ from rest_framework.test import APITestCase
from authentik.core.models import User from authentik.core.models import User
from authentik.core.tests.utils import create_test_admin_user from authentik.core.tests.utils import create_test_admin_user
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
from authentik.tenants.utils import get_current_tenant
class TestUsersAvatars(APITestCase): class TestUsersAvatars(APITestCase):
@ -17,18 +18,25 @@ class TestUsersAvatars(APITestCase):
self.admin = create_test_admin_user() self.admin = create_test_admin_user()
self.user = User.objects.create(username="test-user") self.user = User.objects.create(username="test-user")
def set_avatar_mode(self, mode: str):
"""Set the avatar mode on the current tenant."""
tenant = get_current_tenant()
tenant.avatars = mode
tenant.save()
@CONFIG.patch("avatars", "none") @CONFIG.patch("avatars", "none")
def test_avatars_none(self): def test_avatars_none(self):
"""Test avatars none""" """Test avatars none"""
self.set_avatar_mode("none")
self.client.force_login(self.admin) self.client.force_login(self.admin)
response = self.client.get(reverse("authentik_api:user-me")) response = self.client.get(reverse("authentik_api:user-me"))
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
body = loads(response.content.decode()) body = loads(response.content.decode())
self.assertEqual(body["user"]["avatar"], "/static/dist/assets/images/user_default.png") self.assertEqual(body["user"]["avatar"], "/static/dist/assets/images/user_default.png")
@CONFIG.patch("avatars", "gravatar")
def test_avatars_gravatar(self): def test_avatars_gravatar(self):
"""Test avatars gravatar""" """Test avatars gravatar"""
self.set_avatar_mode("gravatar")
self.admin.email = "static@t.goauthentik.io" self.admin.email = "static@t.goauthentik.io"
self.admin.save() self.admin.save()
self.client.force_login(self.admin) self.client.force_login(self.admin)
@ -45,27 +53,27 @@ class TestUsersAvatars(APITestCase):
body = loads(response.content.decode()) body = loads(response.content.decode())
self.assertIn("gravatar", body["user"]["avatar"]) self.assertIn("gravatar", body["user"]["avatar"])
@CONFIG.patch("avatars", "initials")
def test_avatars_initials(self): def test_avatars_initials(self):
"""Test avatars initials""" """Test avatars initials"""
self.set_avatar_mode("initials")
self.client.force_login(self.admin) self.client.force_login(self.admin)
response = self.client.get(reverse("authentik_api:user-me")) response = self.client.get(reverse("authentik_api:user-me"))
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
body = loads(response.content.decode()) body = loads(response.content.decode())
self.assertIn("data:image/svg+xml;base64,", body["user"]["avatar"]) self.assertIn("data:image/svg+xml;base64,", body["user"]["avatar"])
@CONFIG.patch("avatars", "foo://%(username)s")
def test_avatars_custom(self): def test_avatars_custom(self):
"""Test avatars custom""" """Test avatars custom"""
self.set_avatar_mode("foo://%(username)s")
self.client.force_login(self.admin) self.client.force_login(self.admin)
response = self.client.get(reverse("authentik_api:user-me")) response = self.client.get(reverse("authentik_api:user-me"))
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
body = loads(response.content.decode()) body = loads(response.content.decode())
self.assertEqual(body["user"]["avatar"], f"foo://{self.admin.username}") self.assertEqual(body["user"]["avatar"], f"foo://{self.admin.username}")
@CONFIG.patch("avatars", "attributes.foo.avatar")
def test_avatars_attributes(self): def test_avatars_attributes(self):
"""Test avatars attributes""" """Test avatars attributes"""
self.set_avatar_mode("attributes.foo.avatar")
self.admin.attributes = {"foo": {"avatar": "bar"}} self.admin.attributes = {"foo": {"avatar": "bar"}}
self.admin.save() self.admin.save()
self.client.force_login(self.admin) self.client.force_login(self.admin)
@ -74,9 +82,9 @@ class TestUsersAvatars(APITestCase):
body = loads(response.content.decode()) body = loads(response.content.decode())
self.assertEqual(body["user"]["avatar"], "bar") self.assertEqual(body["user"]["avatar"], "bar")
@CONFIG.patch("avatars", "attributes.foo.avatar,initials")
def test_avatars_fallback(self): def test_avatars_fallback(self):
"""Test fallback""" """Test fallback"""
self.set_avatar_mode("attributes.foo.avatar,initials")
self.client.force_login(self.admin) self.client.force_login(self.admin)
response = self.client.get(reverse("authentik_api:user-me")) response = self.client.get(reverse("authentik_api:user-me"))
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)

View File

@ -3,12 +3,12 @@ from typing import Optional
from django.utils.text import slugify from django.utils.text import slugify
from authentik.brands.models import Brand
from authentik.core.models import Group, User from authentik.core.models import Group, User
from authentik.crypto.builder import CertificateBuilder from authentik.crypto.builder import CertificateBuilder
from authentik.crypto.models import CertificateKeyPair from authentik.crypto.models import CertificateKeyPair
from authentik.flows.models import Flow, FlowDesignation from authentik.flows.models import Flow, FlowDesignation
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.tenants.models import Tenant
def create_test_flow( def create_test_flow(
@ -43,12 +43,12 @@ def create_test_admin_user(name: Optional[str] = None, **kwargs) -> User:
return user return user
def create_test_tenant(**kwargs) -> Tenant: def create_test_brand(**kwargs) -> Brand:
"""Generate a test tenant, removing all other tenants to make sure this one """Generate a test brand, removing all other brands to make sure this one
matches.""" matches."""
uid = generate_id(20) uid = generate_id(20)
Tenant.objects.all().delete() Brand.objects.all().delete()
return Tenant.objects.create(domain=uid, default=True, **kwargs) return Brand.objects.create(domain=uid, default=True, **kwargs)
def create_test_cert(use_ec_private_key=False) -> CertificateKeyPair: def create_test_cert(use_ec_private_key=False) -> CertificateKeyPair:

View File

@ -9,8 +9,8 @@ from rest_framework.request import Request
from authentik import get_build_hash from authentik import get_build_hash
from authentik.admin.tasks import LOCAL_VERSION from authentik.admin.tasks import LOCAL_VERSION
from authentik.api.v3.config import ConfigView from authentik.api.v3.config import ConfigView
from authentik.brands.api import CurrentBrandSerializer
from authentik.flows.models import Flow from authentik.flows.models import Flow
from authentik.tenants.api import CurrentTenantSerializer
class InterfaceView(TemplateView): class InterfaceView(TemplateView):
@ -18,7 +18,7 @@ class InterfaceView(TemplateView):
def get_context_data(self, **kwargs: Any) -> dict[str, Any]: def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
kwargs["config_json"] = dumps(ConfigView(request=Request(self.request)).get_config().data) kwargs["config_json"] = dumps(ConfigView(request=Request(self.request)).get_config().data)
kwargs["tenant_json"] = dumps(CurrentTenantSerializer(self.request.tenant).data) kwargs["brand_json"] = dumps(CurrentBrandSerializer(self.request.brand).data)
kwargs["version_family"] = f"{LOCAL_VERSION.major}.{LOCAL_VERSION.minor}" kwargs["version_family"] = f"{LOCAL_VERSION.major}.{LOCAL_VERSION.minor}"
kwargs["version_subdomain"] = f"version-{LOCAL_VERSION.major}-{LOCAL_VERSION.minor}" kwargs["version_subdomain"] = f"version-{LOCAL_VERSION.major}-{LOCAL_VERSION.minor}"
kwargs["build"] = get_build_hash() kwargs["build"] = get_build_hash()

View File

@ -16,7 +16,7 @@ class AuthentikCryptoConfig(ManagedAppConfig):
verbose_name = "authentik Crypto" verbose_name = "authentik Crypto"
default = True default = True
def reconcile_load_crypto_tasks(self): def reconcile_global_load_crypto_tasks(self):
"""Load crypto tasks""" """Load crypto tasks"""
self.import_module("authentik.crypto.tasks") self.import_module("authentik.crypto.tasks")
@ -39,7 +39,7 @@ class AuthentikCryptoConfig(ManagedAppConfig):
}, },
) )
def reconcile_managed_jwt_cert(self): def reconcile_tenant_managed_jwt_cert(self):
"""Ensure managed JWT certificate""" """Ensure managed JWT certificate"""
from authentik.crypto.models import CertificateKeyPair from authentik.crypto.models import CertificateKeyPair
@ -52,7 +52,7 @@ class AuthentikCryptoConfig(ManagedAppConfig):
): ):
self._create_update_cert() self._create_update_cert()
def reconcile_self_signed(self): def reconcile_tenant_self_signed(self):
"""Create self-signed keypair""" """Create self-signed keypair"""
from authentik.crypto.builder import CertificateBuilder from authentik.crypto.builder import CertificateBuilder
from authentik.crypto.models import CertificateKeyPair from authentik.crypto.models import CertificateKeyPair

View File

@ -1,21 +1,22 @@
"""Import certificate""" """Import certificate"""
from sys import exit as sys_exit from sys import exit as sys_exit
from django.core.management.base import BaseCommand, no_translations from django.core.management.base import no_translations
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.crypto.api import CertificateKeyPairSerializer from authentik.crypto.api import CertificateKeyPairSerializer
from authentik.crypto.models import CertificateKeyPair from authentik.crypto.models import CertificateKeyPair
from authentik.tenants.management import TenantCommand
LOGGER = get_logger() LOGGER = get_logger()
class Command(BaseCommand): class Command(TenantCommand):
"""Import certificate""" """Import certificate"""
@no_translations @no_translations
def handle(self, *args, **options): def handle_per_tenant(self, *args, **options):
"""Import certificate""" """Import certificate"""
keypair = CertificateKeyPair.objects.filter(name=options["name"]).first() keypair = CertificateKeyPair.objects.filter(name=options["name"]).first()
dirty = False dirty = False

View File

@ -1,4 +1,8 @@
"""Enterprise app config""" """Enterprise app config"""
from functools import lru_cache
from django.conf import settings
from authentik.blueprints.apps import ManagedAppConfig from authentik.blueprints.apps import ManagedAppConfig
@ -14,6 +18,17 @@ class AuthentikEnterpriseConfig(EnterpriseConfig):
verbose_name = "authentik Enterprise" verbose_name = "authentik Enterprise"
default = True default = True
def reconcile_load_enterprise_signals(self): def reconcile_global_load_enterprise_signals(self):
"""Load enterprise signals""" """Load enterprise signals"""
self.import_module("authentik.enterprise.signals") self.import_module("authentik.enterprise.signals")
def enabled(self):
"""Return true if enterprise is enabled and valid"""
return self.check_enabled() or settings.TEST
@lru_cache()
def check_enabled(self):
"""Actual enterprise check, cached"""
from authentik.enterprise.models import LicenseKey
return LicenseKey.get_total().is_valid()

View File

@ -12,6 +12,6 @@ class AuthentikEnterpriseProviderRAC(EnterpriseConfig):
mountpoint = "" mountpoint = ""
ws_mountpoint = "authentik.enterprise.providers.rac.urls" ws_mountpoint = "authentik.enterprise.providers.rac.urls"
def reconcile_load_rac_signals(self): def reconcile_global_load_rac_signals(self):
"""Load rac signals""" """Load rac signals"""
self.import_module("authentik.enterprise.providers.rac.signals") self.import_module("authentik.enterprise.providers.rac.signals")

View File

@ -11,6 +11,6 @@ CELERY_BEAT_SCHEDULE = {
} }
} }
INSTALLED_APPS = [ TENANT_APPS = [
"authentik.enterprise.providers.rac", "authentik.enterprise.providers.rac",
] ]

View File

@ -36,7 +36,7 @@ class EventSerializer(ModelSerializer):
"client_ip", "client_ip",
"created", "created",
"expires", "expires",
"tenant", "brand",
] ]
@ -77,10 +77,10 @@ class EventsFilter(django_filters.FilterSet):
field_name="action", field_name="action",
lookup_expr="icontains", lookup_expr="icontains",
) )
tenant_name = django_filters.CharFilter( brand_name = django_filters.CharFilter(
field_name="tenant", field_name="brand",
lookup_expr="name", lookup_expr="name",
label="Tenant name", label="Brand name",
) )
def filter_context_model_pk(self, queryset, name, value): def filter_context_model_pk(self, queryset, name, value):

View File

@ -7,7 +7,7 @@ from authentik.lib.config import CONFIG, ENV_PREFIX
GAUGE_TASKS = Gauge( GAUGE_TASKS = Gauge(
"authentik_system_tasks", "authentik_system_tasks",
"System tasks and their status", "System tasks and their status",
["task_name", "task_uid", "status"], ["tenant", "task_name", "task_uid", "status"],
) )
@ -19,11 +19,11 @@ class AuthentikEventsConfig(ManagedAppConfig):
verbose_name = "authentik Events" verbose_name = "authentik Events"
default = True default = True
def reconcile_load_events_signals(self): def reconcile_global_load_events_signals(self):
"""Load events signals""" """Load events signals"""
self.import_module("authentik.events.signals") self.import_module("authentik.events.signals")
def reconcile_check_deprecations(self): def reconcile_global_check_deprecations(self):
"""Check for config deprecations""" """Check for config deprecations"""
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction

View File

@ -305,7 +305,7 @@ class Migration(migrations.Migration):
migrations.AddField( migrations.AddField(
model_name="event", model_name="event",
name="tenant", name="tenant",
field=models.JSONField(blank=True, default=authentik.events.models.default_tenant), field=models.JSONField(blank=True, default=authentik.events.models.default_brand),
), ),
migrations.AlterField( migrations.AlterField(
model_name="event", model_name="event",

View File

@ -0,0 +1,17 @@
# Generated by Django 4.2.7 on 2023-11-06 18:58
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("authentik_events", "0002_alter_notificationtransport_mode"),
]
operations = [
migrations.RenameField(
model_name="event",
old_name="tenant",
new_name="brand",
),
]

View File

@ -21,6 +21,8 @@ from requests import RequestException
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik import get_full_version from authentik import get_full_version
from authentik.brands.models import Brand
from authentik.brands.utils import DEFAULT_BRAND
from authentik.core.middleware import ( from authentik.core.middleware import (
SESSION_KEY_IMPERSONATE_ORIGINAL_USER, SESSION_KEY_IMPERSONATE_ORIGINAL_USER,
SESSION_KEY_IMPERSONATE_USER, SESSION_KEY_IMPERSONATE_USER,
@ -42,7 +44,6 @@ from authentik.policies.models import PolicyBindingModel
from authentik.root.middleware import ClientIPMiddleware from authentik.root.middleware import ClientIPMiddleware
from authentik.stages.email.utils import TemplateEmailMessage from authentik.stages.email.utils import TemplateEmailMessage
from authentik.tenants.models import Tenant from authentik.tenants.models import Tenant
from authentik.tenants.utils import DEFAULT_TENANT
LOGGER = get_logger() LOGGER = get_logger()
if TYPE_CHECKING: if TYPE_CHECKING:
@ -51,13 +52,13 @@ if TYPE_CHECKING:
def default_event_duration(): def default_event_duration():
"""Default duration an Event is saved. """Default duration an Event is saved.
This is used as a fallback when no tenant is available""" This is used as a fallback when no brand is available"""
return now() + timedelta(days=365) return now() + timedelta(days=365)
def default_tenant(): def default_brand():
"""Get a default value for tenant""" """Get a default value for brand"""
return sanitize_dict(model_to_dict(DEFAULT_TENANT)) return sanitize_dict(model_to_dict(DEFAULT_BRAND))
class NotificationTransportError(SentryIgnoredException): class NotificationTransportError(SentryIgnoredException):
@ -171,7 +172,7 @@ class Event(SerializerModel, ExpiringModel):
context = models.JSONField(default=dict, blank=True) context = models.JSONField(default=dict, blank=True)
client_ip = models.GenericIPAddressField(null=True) client_ip = models.GenericIPAddressField(null=True)
created = models.DateTimeField(auto_now_add=True) created = models.DateTimeField(auto_now_add=True)
tenant = models.JSONField(default=default_tenant, blank=True) brand = models.JSONField(default=default_brand, blank=True)
# Shadow the expires attribute from ExpiringModel to override the default duration # Shadow the expires attribute from ExpiringModel to override the default duration
expires = models.DateTimeField(default=default_event_duration) expires = models.DateTimeField(default=default_event_duration)
@ -231,7 +232,9 @@ class Event(SerializerModel, ExpiringModel):
# hence we set self.created to now and then use it # hence we set self.created to now and then use it
self.created = now() self.created = now()
self.expires = self.created + timedelta_from_string(tenant.event_retention) self.expires = self.created + timedelta_from_string(tenant.event_retention)
self.tenant = sanitize_dict(model_to_dict(tenant)) if hasattr(request, "brand"):
brand: Brand = request.brand
self.brand = sanitize_dict(model_to_dict(brand))
if hasattr(request, "user"): if hasattr(request, "user"):
original_user = None original_user = None
if hasattr(request, "session"): if hasattr(request, "session"):

View File

@ -5,10 +5,11 @@ from enum import Enum
from timeit import default_timer from timeit import default_timer
from typing import Any, Optional from typing import Any, Optional
from celery import Task
from django.core.cache import cache from django.core.cache import cache
from django.db import connection
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from tenant_schemas_celery.task import TenantTask
from authentik.events.apps import GAUGE_TASKS from authentik.events.apps import GAUGE_TASKS
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction
@ -101,6 +102,7 @@ class TaskInfo:
except TypeError: except TypeError:
duration = 0 duration = 0
GAUGE_TASKS.labels( GAUGE_TASKS.labels(
tenant=connection.schema_name,
task_name=self.task_name.split(":")[0], task_name=self.task_name.split(":")[0],
task_uid=self.result.uid or "", task_uid=self.result.uid or "",
status=self.result.status.name.lower(), status=self.result.status.name.lower(),
@ -112,7 +114,7 @@ class TaskInfo:
cache.set(self.full_name, self, timeout=timeout_hours * 60 * 60) cache.set(self.full_name, self, timeout=timeout_hours * 60 * 60)
class MonitoredTask(Task): class MonitoredTask(TenantTask):
"""Task which can save its state to the cache""" """Task which can save its state to the cache"""
# For tasks that should only be listed if they failed, set this to False # For tasks that should only be listed if they failed, set this to False

View File

@ -13,11 +13,11 @@ from authentik.events.tasks import event_notification_handler, gdpr_cleanup
from authentik.flows.models import Stage from authentik.flows.models import Stage
from authentik.flows.planner import PLAN_CONTEXT_SOURCE, FlowPlan from authentik.flows.planner import PLAN_CONTEXT_SOURCE, FlowPlan
from authentik.flows.views.executor import SESSION_KEY_PLAN from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.lib.config import CONFIG
from authentik.stages.invitation.models import Invitation from authentik.stages.invitation.models import Invitation
from authentik.stages.invitation.signals import invitation_used from authentik.stages.invitation.signals import invitation_used
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS
from authentik.stages.user_write.signals import user_write from authentik.stages.user_write.signals import user_write
from authentik.tenants.utils import get_current_tenant
SESSION_LOGIN_EVENT = "login_event" SESSION_LOGIN_EVENT = "login_event"
@ -98,5 +98,5 @@ def event_post_save_notification(sender, instance: Event, **_):
@receiver(pre_delete, sender=User) @receiver(pre_delete, sender=User)
def event_user_pre_delete_cleanup(sender, instance: User, **_): def event_user_pre_delete_cleanup(sender, instance: User, **_):
"""If gdpr_compliance is enabled, remove all the user's events""" """If gdpr_compliance is enabled, remove all the user's events"""
if CONFIG.get_bool("gdpr_compliance", True): if get_current_tenant().gdpr_compliance:
gdpr_cleanup.delay(instance.pk) gdpr_cleanup.delay(instance.pk)

View File

@ -6,12 +6,12 @@ from django.test import RequestFactory, TestCase
from django.views.debug import SafeExceptionReporterFilter from django.views.debug import SafeExceptionReporterFilter
from guardian.shortcuts import get_anonymous_user from guardian.shortcuts import get_anonymous_user
from authentik.brands.models import Brand
from authentik.core.models import Group from authentik.core.models import Group
from authentik.events.models import Event from authentik.events.models import Event
from authentik.flows.views.executor import QS_QUERY from authentik.flows.views.executor import QS_QUERY
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.policies.dummy.models import DummyPolicy from authentik.policies.dummy.models import DummyPolicy
from authentik.tenants.models import Tenant
class TestEvents(TestCase): class TestEvents(TestCase):
@ -97,19 +97,19 @@ class TestEvents(TestCase):
}, },
) )
def test_from_http_tenant(self): def test_from_http_brand(self):
"""Test from_http tenant""" """Test from_http brand"""
# Test tenant # Test brand
request = self.factory.get("/") request = self.factory.get("/")
tenant = Tenant(domain="test-tenant") brand = Brand(domain="test-brand")
setattr(request, "tenant", tenant) setattr(request, "brand", brand)
event = Event.new("unittest").from_http(request) event = Event.new("unittest").from_http(request)
self.assertEqual( self.assertEqual(
event.tenant, event.brand,
{ {
"app": "authentik_tenants", "app": "authentik_brands",
"model_name": "tenant", "model_name": "brand",
"name": "Tenant test-tenant", "name": "Brand test-brand",
"pk": tenant.pk.hex, "pk": brand.pk.hex,
}, },
) )

View File

@ -72,10 +72,13 @@ def model_to_dict(model: Model) -> dict[str, Any]:
} }
def get_user(user: User, original_user: Optional[User] = None) -> dict[str, Any]: def get_user(user: User | AnonymousUser, original_user: Optional[User] = None) -> dict[str, Any]:
"""Convert user object to dictionary, optionally including the original user""" """Convert user object to dictionary, optionally including the original user"""
if isinstance(user, AnonymousUser): if isinstance(user, AnonymousUser):
try:
user = get_anonymous_user() user = get_anonymous_user()
except User.DoesNotExist:
return {}
user_data = { user_data = {
"username": user.username, "username": user.username,
"pk": user.pk, "pk": user.pk,

View File

@ -7,6 +7,7 @@ from authentik.lib.utils.reflection import all_subclasses
GAUGE_FLOWS_CACHED = Gauge( GAUGE_FLOWS_CACHED = Gauge(
"authentik_flows_cached", "authentik_flows_cached",
"Cached flows", "Cached flows",
["tenant"],
) )
HIST_FLOW_EXECUTION_STAGE_TIME = Histogram( HIST_FLOW_EXECUTION_STAGE_TIME = Histogram(
"authentik_flows_execution_stage_time", "authentik_flows_execution_stage_time",
@ -29,11 +30,11 @@ class AuthentikFlowsConfig(ManagedAppConfig):
verbose_name = "authentik Flows" verbose_name = "authentik Flows"
default = True default = True
def reconcile_load_flows_signals(self): def reconcile_global_load_flows_signals(self):
"""Load flows signals""" """Load flows signals"""
self.import_module("authentik.flows.signals") self.import_module("authentik.flows.signals")
def reconcile_load_stages(self): def reconcile_global_load_stages(self):
"""Ensure all stages are loaded""" """Ensure all stages are loaded"""
from authentik.flows.models import Stage from authentik.flows.models import Stage

View File

@ -1,5 +1,6 @@
"""authentik flow signals""" """authentik flow signals"""
from django.core.cache import cache from django.core.cache import cache
from django.db import connection
from django.db.models.signals import post_save, pre_delete from django.db.models.signals import post_save, pre_delete
from django.dispatch import receiver from django.dispatch import receiver
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -21,7 +22,9 @@ def delete_cache_prefix(prefix: str) -> int:
@receiver(monitoring_set) @receiver(monitoring_set)
def monitoring_set_flows(sender, **kwargs): def monitoring_set_flows(sender, **kwargs):
"""set flow gauges""" """set flow gauges"""
GAUGE_FLOWS_CACHED.set(len(cache.keys(f"{CACHE_PREFIX}*") or [])) GAUGE_FLOWS_CACHED.labels(tenant=connection.schema_name).set(
len(cache.keys(f"{CACHE_PREFIX}*") or [])
)
@receiver(post_save) @receiver(post_save)

View File

@ -22,6 +22,7 @@ from sentry_sdk.api import set_tag
from sentry_sdk.hub import Hub from sentry_sdk.hub import Hub
from structlog.stdlib import BoundLogger, get_logger from structlog.stdlib import BoundLogger, get_logger
from authentik.brands.models import Brand
from authentik.core.models import Application from authentik.core.models import Application
from authentik.events.models import Event, EventAction, cleanse_dict from authentik.events.models import Event, EventAction, cleanse_dict
from authentik.flows.apps import HIST_FLOW_EXECUTION_STAGE_TIME from authentik.flows.apps import HIST_FLOW_EXECUTION_STAGE_TIME
@ -60,7 +61,6 @@ from authentik.lib.utils.errors import exception_to_string
from authentik.lib.utils.reflection import all_subclasses, class_to_path from authentik.lib.utils.reflection import all_subclasses, class_to_path
from authentik.lib.utils.urls import is_url_absolute, redirect_with_qs from authentik.lib.utils.urls import is_url_absolute, redirect_with_qs
from authentik.policies.engine import PolicyEngine from authentik.policies.engine import PolicyEngine
from authentik.tenants.models import Tenant
LOGGER = get_logger() LOGGER = get_logger()
# Argument used to redirect user after login # Argument used to redirect user after login
@ -490,11 +490,11 @@ class ToDefaultFlow(View):
def get_flow(self) -> Flow: def get_flow(self) -> Flow:
"""Get a flow for the selected designation""" """Get a flow for the selected designation"""
tenant: Tenant = self.request.tenant brand: Brand = self.request.brand
flow = None flow = None
# First, attempt to get default flow from tenant # First, attempt to get default flow from brand
if self.designation == FlowDesignation.AUTHENTICATION: if self.designation == FlowDesignation.AUTHENTICATION:
flow = tenant.flow_authentication flow = brand.flow_authentication
# Check if we have a default flow from application # Check if we have a default flow from application
application: Optional[Application] = self.request.session.get( application: Optional[Application] = self.request.session.get(
SESSION_KEY_APPLICATION_PRE SESSION_KEY_APPLICATION_PRE
@ -502,7 +502,7 @@ class ToDefaultFlow(View):
if application and application.provider and application.provider.authentication_flow: if application and application.provider and application.provider.authentication_flow:
flow = application.provider.authentication_flow flow = application.provider.authentication_flow
elif self.designation == FlowDesignation.INVALIDATION: elif self.designation == FlowDesignation.INVALIDATION:
flow = tenant.flow_invalidation flow = brand.flow_invalidation
if flow: if flow:
return flow return flow
# If no flow was set, get the first based on slug and policy # If no flow was set, get the first based on slug and policy

View File

@ -11,8 +11,9 @@ from lxml import etree # nosec
from lxml.etree import Element, SubElement # nosec from lxml.etree import Element, SubElement # nosec
from requests.exceptions import RequestException from requests.exceptions import RequestException
from authentik.lib.config import CONFIG, get_path_from_dict from authentik.lib.config import get_path_from_dict
from authentik.lib.utils.http import get_http_session from authentik.lib.utils.http import get_http_session
from authentik.tenants.utils import get_current_tenant
GRAVATAR_URL = "https://secure.gravatar.com" GRAVATAR_URL = "https://secure.gravatar.com"
DEFAULT_AVATAR = static("dist/assets/images/user_default.png") DEFAULT_AVATAR = static("dist/assets/images/user_default.png")
@ -183,7 +184,7 @@ def get_avatar(user: "User") -> str:
"initials": avatar_mode_generated, "initials": avatar_mode_generated,
"gravatar": avatar_mode_gravatar, "gravatar": avatar_mode_gravatar,
} }
modes: str = CONFIG.get("avatars", "none") modes: str = get_current_tenant().avatars
for mode in modes.split(","): for mode in modes.split(","):
avatar = None avatar = None
if mode in mode_map: if mode in mode_map:

View File

@ -34,6 +34,7 @@ REDIS_ENV_KEYS = [
f"{ENV_PREFIX}_REDIS__TLS_REQS", f"{ENV_PREFIX}_REDIS__TLS_REQS",
] ]
# Old key -> new key
DEPRECATIONS = { DEPRECATIONS = {
"geoip": "events.context_processors.geoip", "geoip": "events.context_processors.geoip",
"redis.broker_url": "broker.url", "redis.broker_url": "broker.url",
@ -201,12 +202,13 @@ class ConfigLoader:
root[key] = value root[key] = value
return root return root
def refresh(self, key: str): def refresh(self, key: str, default=None, sep=".") -> Any:
"""Update a single value""" """Update a single value"""
attr: Attr = get_path_from_dict(self.raw, key) attr: Attr = get_path_from_dict(self.raw, key, sep=sep, default=Attr(default))
if attr.source_type != Attr.Source.URI: if attr.source_type != Attr.Source.URI:
return return attr.value
attr.value = self.parse_uri(attr.source).value attr.value = self.parse_uri(attr.source).value
return attr.value
def parse_uri(self, value: str) -> Attr: def parse_uri(self, value: str) -> Attr:
"""Parse string values which start with a URI""" """Parse string values which start with a URI"""

View File

@ -37,8 +37,8 @@ redis:
tls_reqs: "none" tls_reqs: "none"
# broker: # broker:
# url: "" # url: ""
# transport_options: "" # transport_options: ""
cache: cache:
# url: "" # url: ""
@ -48,13 +48,10 @@ cache:
timeout_reputation: 300 timeout_reputation: 300
# channel: # channel:
# url: "" # url: ""
# result_backend: # result_backend:
# url: "" # url: ""
paths:
media: ./media
debug: false debug: false
remote_debug: false remote_debug: false
@ -107,22 +104,17 @@ reputation:
cookie_domain: null cookie_domain: null
disable_update_check: false disable_update_check: false
disable_startup_analytics: false disable_startup_analytics: false
avatars: env://AUTHENTIK_AUTHENTIK__AVATARS?gravatar,initials
events: events:
context_processors: context_processors:
geoip: "/geoip/GeoLite2-City.mmdb" geoip: "/geoip/GeoLite2-City.mmdb"
asn: "/geoip/GeoLite2-ASN.mmdb" asn: "/geoip/GeoLite2-ASN.mmdb"
footer_links: []
default_user_change_name: true
default_user_change_email: false
default_user_change_username: false
gdpr_compliance: true
cert_discovery_dir: /certs cert_discovery_dir: /certs
default_token_length: 60 default_token_length: 60
impersonation: true
tenants:
enabled: false
api_key: ""
blueprints_dir: /blueprints blueprints_dir: /blueprints
@ -133,3 +125,20 @@ web:
worker: worker:
concurrency: 2 concurrency: 2
storage:
media:
backend: file # or s3
file:
path: ./media
s3:
# How to talk to s3
# region: "us-east-1"
# use_ssl: True
# endpoint: "https://s3.us-east-1.amazonaws.com"
# access_key: ""
# secret_key: ""
# bucket_name: "authentik-media"
# How to render file URLs
# custom_domain: null
secure_urls: True

View File

@ -180,6 +180,11 @@ class BaseEvaluator:
full_expression += f"\nresult = handler({handler_signature})" full_expression += f"\nresult = handler({handler_signature})"
return full_expression return full_expression
def compile(self, expression: str) -> Any:
"""Parse expression. Raises SyntaxError or ValueError if the syntax is incorrect."""
param_keys = self._context.keys()
return compile(self.wrap_expression(expression, param_keys), self._filename, "exec")
def evaluate(self, expression_source: str) -> Any: def evaluate(self, expression_source: str) -> Any:
"""Parse and evaluate expression. If the syntax is incorrect, a SyntaxError is raised. """Parse and evaluate expression. If the syntax is incorrect, a SyntaxError is raised.
If any exception is raised during execution, it is raised. If any exception is raised during execution, it is raised.
@ -188,13 +193,8 @@ class BaseEvaluator:
span: Span span: Span
span.description = self._filename span.description = self._filename
span.set_data("expression", expression_source) span.set_data("expression", expression_source)
param_keys = self._context.keys()
try: try:
ast_obj = compile( ast_obj = self.compile(expression_source)
self.wrap_expression(expression_source, param_keys),
self._filename,
"exec",
)
except (SyntaxError, ValueError) as exc: except (SyntaxError, ValueError) as exc:
self.handle_error(exc, expression_source) self.handle_error(exc, expression_source)
raise exc raise exc
@ -221,13 +221,8 @@ class BaseEvaluator:
def validate(self, expression: str) -> bool: def validate(self, expression: str) -> bool:
"""Validate expression's syntax, raise ValidationError if Syntax is invalid""" """Validate expression's syntax, raise ValidationError if Syntax is invalid"""
param_keys = self._context.keys()
try: try:
compile( self.compile(expression)
self.wrap_expression(expression, param_keys),
self._filename,
"exec",
)
return True return True
except (ValueError, SyntaxError) as exc: except (ValueError, SyntaxError) as exc:
raise ValidationError(f"Expression Syntax Error: {str(exc)}") from exc raise ValidationError(f"Expression Syntax Error: {str(exc)}") from exc

View File

@ -4,6 +4,7 @@ from logging import Logger
from os import getpid from os import getpid
import structlog import structlog
from django.db import connection
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
@ -37,6 +38,7 @@ def structlog_configure():
structlog.stdlib.add_logger_name, structlog.stdlib.add_logger_name,
structlog.contextvars.merge_contextvars, structlog.contextvars.merge_contextvars,
add_process_id, add_process_id,
add_tenant_information,
structlog.stdlib.PositionalArgumentsFormatter(), structlog.stdlib.PositionalArgumentsFormatter(),
structlog.processors.TimeStamper(fmt="iso", utc=False), structlog.processors.TimeStamper(fmt="iso", utc=False),
structlog.processors.StackInfoRenderer(), structlog.processors.StackInfoRenderer(),
@ -111,3 +113,15 @@ def add_process_id(logger: Logger, method_name: str, event_dict):
"""Add the current process ID""" """Add the current process ID"""
event_dict["pid"] = getpid() event_dict["pid"] = getpid()
return event_dict return event_dict
def add_tenant_information(logger: Logger, method_name: str, event_dict):
"""Add the current tenant"""
tenant = getattr(connection, "tenant", None)
schema_name = getattr(connection, "schema_name", None)
if tenant is not None:
event_dict["schema_name"] = tenant.schema_name
event_dict["domain_url"] = getattr(tenant, "domain_url", None)
elif schema_name is not None:
event_dict["schema_name"] = schema_name
return event_dict

View File

@ -3,16 +3,19 @@ from prometheus_client import Gauge
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.blueprints.apps import ManagedAppConfig from authentik.blueprints.apps import ManagedAppConfig
from authentik.lib.config import CONFIG
LOGGER = get_logger() LOGGER = get_logger()
GAUGE_OUTPOSTS_CONNECTED = Gauge( GAUGE_OUTPOSTS_CONNECTED = Gauge(
"authentik_outposts_connected", "Currently connected outposts", ["outpost", "uid", "expected"] "authentik_outposts_connected",
"Currently connected outposts",
["tenant", "outpost", "uid", "expected"],
) )
GAUGE_OUTPOSTS_LAST_UPDATE = Gauge( GAUGE_OUTPOSTS_LAST_UPDATE = Gauge(
"authentik_outposts_last_update", "authentik_outposts_last_update",
"Last update from any outpost", "Last update from any outpost",
["outpost", "uid", "version"], ["tenant", "outpost", "uid", "version"],
) )
MANAGED_OUTPOST = "goauthentik.io/outposts/embedded" MANAGED_OUTPOST = "goauthentik.io/outposts/embedded"
MANAGED_OUTPOST_NAME = "authentik Embedded Outpost" MANAGED_OUTPOST_NAME = "authentik Embedded Outpost"
@ -26,11 +29,11 @@ class AuthentikOutpostConfig(ManagedAppConfig):
verbose_name = "authentik Outpost" verbose_name = "authentik Outpost"
default = True default = True
def reconcile_load_outposts_signals(self): def reconcile_global_load_outposts_signals(self):
"""Load outposts signals""" """Load outposts signals"""
self.import_module("authentik.outposts.signals") self.import_module("authentik.outposts.signals")
def reconcile_embedded_outpost(self): def reconcile_tenant_embedded_outpost(self):
"""Ensure embedded outpost""" """Ensure embedded outpost"""
from authentik.outposts.models import ( from authentik.outposts.models import (
DockerServiceConnection, DockerServiceConnection,
@ -39,6 +42,7 @@ class AuthentikOutpostConfig(ManagedAppConfig):
OutpostType, OutpostType,
) )
if not CONFIG.get_bool("outposts.disable_embedded_outpost", False):
if outpost := Outpost.objects.filter(name=MANAGED_OUTPOST_NAME, managed="").first(): if outpost := Outpost.objects.filter(name=MANAGED_OUTPOST_NAME, managed="").first():
outpost.managed = MANAGED_OUTPOST outpost.managed = MANAGED_OUTPOST
outpost.save() outpost.save()
@ -56,3 +60,5 @@ class AuthentikOutpostConfig(ManagedAppConfig):
elif DockerServiceConnection.objects.exists(): elif DockerServiceConnection.objects.exists():
outpost.service_connection = DockerServiceConnection.objects.first() outpost.service_connection = DockerServiceConnection.objects.first()
outpost.save() outpost.save()
else:
Outpost.objects.filter(managed=MANAGED_OUTPOST).delete()

View File

@ -9,6 +9,7 @@ from channels.exceptions import DenyConnection
from channels.generic.websocket import JsonWebsocketConsumer from channels.generic.websocket import JsonWebsocketConsumer
from dacite.core import from_dict from dacite.core import from_dict
from dacite.data import Data from dacite.data import Data
from django.db import connection
from django.http.request import QueryDict from django.http.request import QueryDict
from guardian.shortcuts import get_objects_for_user from guardian.shortcuts import get_objects_for_user
from structlog.stdlib import BoundLogger, get_logger from structlog.stdlib import BoundLogger, get_logger
@ -82,6 +83,7 @@ class OutpostConsumer(JsonWebsocketConsumer):
self.channel_name, self.channel_name,
) )
GAUGE_OUTPOSTS_CONNECTED.labels( GAUGE_OUTPOSTS_CONNECTED.labels(
tenant=connection.schema_name,
outpost=self.outpost.name, outpost=self.outpost.name,
uid=self.instance_uid, uid=self.instance_uid,
expected=self.outpost.config.kubernetes_replicas, expected=self.outpost.config.kubernetes_replicas,
@ -100,6 +102,7 @@ class OutpostConsumer(JsonWebsocketConsumer):
) )
if self.outpost and self.instance_uid: if self.outpost and self.instance_uid:
GAUGE_OUTPOSTS_CONNECTED.labels( GAUGE_OUTPOSTS_CONNECTED.labels(
tenant=connection.schema_name,
outpost=self.outpost.name, outpost=self.outpost.name,
uid=self.instance_uid, uid=self.instance_uid,
expected=self.outpost.config.kubernetes_replicas, expected=self.outpost.config.kubernetes_replicas,
@ -121,6 +124,7 @@ class OutpostConsumer(JsonWebsocketConsumer):
elif msg.instruction == WebsocketMessageInstruction.ACK: elif msg.instruction == WebsocketMessageInstruction.ACK:
return return
GAUGE_OUTPOSTS_LAST_UPDATE.labels( GAUGE_OUTPOSTS_LAST_UPDATE.labels(
tenant=connection.schema_name,
outpost=self.outpost.name, outpost=self.outpost.name,
uid=self.instance_uid or "", uid=self.instance_uid or "",
version=state.version or "", version=state.version or "",

View File

@ -19,6 +19,7 @@ from structlog.stdlib import get_logger
from authentik import __version__, get_build_hash from authentik import __version__, get_build_hash
from authentik.blueprints.models import ManagedModel from authentik.blueprints.models import ManagedModel
from authentik.brands.models import Brand
from authentik.core.models import ( from authentik.core.models import (
USER_PATH_SYSTEM_PREFIX, USER_PATH_SYSTEM_PREFIX,
Provider, Provider,
@ -34,7 +35,6 @@ from authentik.lib.models import InheritanceForeignKey, SerializerModel
from authentik.lib.sentry import SentryIgnoredException from authentik.lib.sentry import SentryIgnoredException
from authentik.lib.utils.errors import exception_to_string from authentik.lib.utils.errors import exception_to_string
from authentik.outposts.controllers.k8s.utils import get_namespace from authentik.outposts.controllers.k8s.utils import get_namespace
from authentik.tenants.models import Tenant
OUR_VERSION = parse(__version__) OUR_VERSION = parse(__version__)
OUTPOST_HELLO_INTERVAL = 10 OUTPOST_HELLO_INTERVAL = 10
@ -408,9 +408,9 @@ class Outpost(SerializerModel, ManagedModel):
else: else:
objects.append(provider) objects.append(provider)
if self.managed: if self.managed:
for tenant in Tenant.objects.filter(web_certificate__isnull=False): for brand in Brand.objects.filter(web_certificate__isnull=False):
objects.append(tenant) objects.append(brand)
objects.append(tenant.web_certificate) objects.append(brand.web_certificate)
return objects return objects
def __str__(self) -> str: def __str__(self) -> str:

View File

@ -5,12 +5,12 @@ from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_sav
from django.dispatch import receiver from django.dispatch import receiver
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.brands.models import Brand
from authentik.core.models import Provider from authentik.core.models import Provider
from authentik.crypto.models import CertificateKeyPair from authentik.crypto.models import CertificateKeyPair
from authentik.lib.utils.reflection import class_to_path from authentik.lib.utils.reflection import class_to_path
from authentik.outposts.models import Outpost, OutpostServiceConnection from authentik.outposts.models import Outpost, OutpostServiceConnection
from authentik.outposts.tasks import CACHE_KEY_OUTPOST_DOWN, outpost_controller, outpost_post_save from authentik.outposts.tasks import CACHE_KEY_OUTPOST_DOWN, outpost_controller, outpost_post_save
from authentik.tenants.models import Tenant
LOGGER = get_logger() LOGGER = get_logger()
UPDATE_TRIGGERING_MODELS = ( UPDATE_TRIGGERING_MODELS = (
@ -18,7 +18,7 @@ UPDATE_TRIGGERING_MODELS = (
OutpostServiceConnection, OutpostServiceConnection,
Provider, Provider,
CertificateKeyPair, CertificateKeyPair,
Tenant, Brand,
) )

View File

@ -6,6 +6,7 @@ from authentik.blueprints.apps import ManagedAppConfig
GAUGE_POLICIES_CACHED = Gauge( GAUGE_POLICIES_CACHED = Gauge(
"authentik_policies_cached", "authentik_policies_cached",
"Cached Policies", "Cached Policies",
["tenant"],
) )
HIST_POLICIES_ENGINE_TOTAL_TIME = Histogram( HIST_POLICIES_ENGINE_TOTAL_TIME = Histogram(
"authentik_policies_engine_time_total_seconds", "authentik_policies_engine_time_total_seconds",
@ -34,6 +35,6 @@ class AuthentikPoliciesConfig(ManagedAppConfig):
verbose_name = "authentik Policies" verbose_name = "authentik Policies"
default = True default = True
def reconcile_load_policies_signals(self): def reconcile_global_load_policies_signals(self):
"""Load policies signals""" """Load policies signals"""
self.import_module("authentik.policies.signals") self.import_module("authentik.policies.signals")

View File

@ -161,7 +161,7 @@ class Migration(migrations.Migration):
("authentik.stages.user_login", "authentik Stages.User Login"), ("authentik.stages.user_login", "authentik Stages.User Login"),
("authentik.stages.user_logout", "authentik Stages.User Logout"), ("authentik.stages.user_logout", "authentik Stages.User Logout"),
("authentik.stages.user_write", "authentik Stages.User Write"), ("authentik.stages.user_write", "authentik Stages.User Write"),
("authentik.tenants", "authentik Tenants"), ("authentik.brands", "authentik Brands"),
("authentik.core", "authentik Core"), ("authentik.core", "authentik Core"),
("authentik.blueprints", "authentik Blueprints"), ("authentik.blueprints", "authentik Blueprints"),
], ],

View File

@ -67,7 +67,7 @@ class Migration(migrations.Migration):
("authentik.stages.user_login", "authentik Stages.User Login"), ("authentik.stages.user_login", "authentik Stages.User Login"),
("authentik.stages.user_logout", "authentik Stages.User Logout"), ("authentik.stages.user_logout", "authentik Stages.User Logout"),
("authentik.stages.user_write", "authentik Stages.User Write"), ("authentik.stages.user_write", "authentik Stages.User Write"),
("authentik.tenants", "authentik Tenants"), ("authentik.brands", "authentik Brands"),
("authentik.blueprints", "authentik Blueprints"), ("authentik.blueprints", "authentik Blueprints"),
("authentik.core", "authentik Core"), ("authentik.core", "authentik Core"),
], ],

View File

@ -143,7 +143,7 @@ class PasswordPolicy(Policy):
user_inputs.append(request.user.name) user_inputs.append(request.user.name)
user_inputs.append(request.user.email) user_inputs.append(request.user.email)
if request.http_request: if request.http_request:
user_inputs.append(request.http_request.tenant.branding_title) user_inputs.append(request.http_request.brand.branding_title)
# Only calculate result for the first 100 characters, as with over 100 char # Only calculate result for the first 100 characters, as with over 100 char
# long passwords we can be reasonably sure that they'll surpass the score anyways # long passwords we can be reasonably sure that they'll surpass the score anyways
# See https://github.com/dropbox/zxcvbn#runtime-latency # See https://github.com/dropbox/zxcvbn#runtime-latency

View File

@ -10,10 +10,10 @@ class AuthentikPolicyReputationConfig(ManagedAppConfig):
verbose_name = "authentik Policies.Reputation" verbose_name = "authentik Policies.Reputation"
default = True default = True
def reconcile_load_policies_reputation_signals(self): def reconcile_global_load_policies_reputation_signals(self):
"""Load policies.reputation signals""" """Load policies.reputation signals"""
self.import_module("authentik.policies.reputation.signals") self.import_module("authentik.policies.reputation.signals")
def reconcile_load_policies_reputation_tasks(self): def reconcile_global_load_policies_reputation_tasks(self):
"""Load policies.reputation tasks""" """Load policies.reputation tasks"""
self.import_module("authentik.policies.reputation.tasks") self.import_module("authentik.policies.reputation.tasks")

View File

@ -1,5 +1,6 @@
"""authentik policy signals""" """authentik policy signals"""
from django.core.cache import cache from django.core.cache import cache
from django.db import connection
from django.db.models.signals import post_save from django.db.models.signals import post_save
from django.dispatch import receiver from django.dispatch import receiver
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -17,7 +18,9 @@ LOGGER = get_logger()
@receiver(monitoring_set) @receiver(monitoring_set)
def monitoring_set_policies(sender, **kwargs): def monitoring_set_policies(sender, **kwargs):
"""set policy gauges""" """set policy gauges"""
GAUGE_POLICIES_CACHED.set(len(cache.keys(f"{CACHE_PREFIX}*") or [])) GAUGE_POLICIES_CACHED.labels(tenant=connection.schema_name).set(
len(cache.keys(f"{CACHE_PREFIX}*") or [])
)
@receiver(post_save, sender=Policy) @receiver(post_save, sender=Policy)

View File

@ -4,7 +4,7 @@
{% load i18n %} {% load i18n %}
{% block title %} {% block title %}
{% trans 'Permission denied' %} - {{ tenant.branding_title }} {% trans 'Permission denied' %} - {{ brand.branding_title }}
{% endblock %} {% endblock %}
{% block card_title %} {% block card_title %}

View File

@ -4,7 +4,7 @@ from urllib.parse import urlencode
from django.urls import reverse from django.urls import reverse
from authentik.core.models import Application from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_flow, create_test_tenant from authentik.core.tests.utils import create_test_admin_user, create_test_brand, create_test_flow
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider
from authentik.providers.oauth2.tests.utils import OAuthTestCase from authentik.providers.oauth2.tests.utils import OAuthTestCase
@ -28,9 +28,9 @@ class TesOAuth2DeviceInit(OAuthTestCase):
self.user = create_test_admin_user() self.user = create_test_admin_user()
self.client.force_login(self.user) self.client.force_login(self.user)
self.device_flow = create_test_flow() self.device_flow = create_test_flow()
self.tenant = create_test_tenant() self.brand = create_test_brand()
self.tenant.flow_device_code = self.device_flow self.brand.flow_device_code = self.device_flow
self.tenant.save() self.brand.save()
def test_device_init(self): def test_device_init(self):
"""Test device init""" """Test device init"""
@ -48,8 +48,8 @@ class TesOAuth2DeviceInit(OAuthTestCase):
def test_no_flow(self): def test_no_flow(self):
"""Test no flow""" """Test no flow"""
self.tenant.flow_device_code = None self.brand.flow_device_code = None
self.tenant.save() self.brand.save()
res = self.client.get(reverse("authentik_providers_oauth2_root:device-login")) res = self.client.get(reverse("authentik_providers_oauth2_root:device-login"))
self.assertEqual(res.status_code, 404) self.assertEqual(res.status_code, 404)

View File

@ -8,6 +8,7 @@ from rest_framework.exceptions import ErrorDetail
from rest_framework.fields import CharField, IntegerField from rest_framework.fields import CharField, IntegerField
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.brands.models import Brand
from authentik.core.models import Application from authentik.core.models import Application
from authentik.flows.challenge import Challenge, ChallengeResponse, ChallengeTypes from authentik.flows.challenge import Challenge, ChallengeResponse, ChallengeTypes
from authentik.flows.exceptions import FlowNonApplicableException from authentik.flows.exceptions import FlowNonApplicableException
@ -26,7 +27,6 @@ from authentik.stages.consent.stage import (
PLAN_CONTEXT_CONSENT_HEADER, PLAN_CONTEXT_CONSENT_HEADER,
PLAN_CONTEXT_CONSENT_PERMISSIONS, PLAN_CONTEXT_CONSENT_PERMISSIONS,
) )
from authentik.tenants.models import Tenant
LOGGER = get_logger() LOGGER = get_logger()
QS_KEY_CODE = "code" # nosec QS_KEY_CODE = "code" # nosec
@ -88,10 +88,10 @@ class DeviceEntryView(View):
"""View used to initiate the device-code flow, url entered by endusers""" """View used to initiate the device-code flow, url entered by endusers"""
def dispatch(self, request: HttpRequest) -> HttpResponse: def dispatch(self, request: HttpRequest) -> HttpResponse:
tenant: Tenant = request.tenant brand: Brand = request.brand
device_flow = tenant.flow_device_code device_flow = brand.flow_device_code
if not device_flow: if not device_flow:
LOGGER.info("Tenant has no device code flow configured", tenant=tenant) LOGGER.info("Brand has no device code flow configured", brand=brand)
return HttpResponse(status=404) return HttpResponse(status=404)
if QS_KEY_CODE in request.GET: if QS_KEY_CODE in request.GET:
validation = validate_code(request.GET[QS_KEY_CODE], request) validation = validate_code(request.GET[QS_KEY_CODE], request)

View File

@ -97,7 +97,7 @@ class GitHubUserTeamsView(View):
"created_at": "", "created_at": "",
"updated_at": "", "updated_at": "",
"organization": { "organization": {
"login": slugify(request.tenant.branding_title), "login": slugify(request.brand.branding_title),
"id": 1, "id": 1,
"node_id": "", "node_id": "",
"url": "", "url": "",
@ -109,7 +109,7 @@ class GitHubUserTeamsView(View):
"public_members_url": "", "public_members_url": "",
"avatar_url": "", "avatar_url": "",
"description": "", "description": "",
"name": request.tenant.branding_title, "name": request.brand.branding_title,
"company": "", "company": "",
"blog": "", "blog": "",
"location": "", "location": "",

View File

@ -10,6 +10,6 @@ class AuthentikProviderProxyConfig(ManagedAppConfig):
verbose_name = "authentik Providers.Proxy" verbose_name = "authentik Providers.Proxy"
default = True default = True
def reconcile_load_providers_proxy_signals(self): def reconcile_global_load_providers_proxy_signals(self):
"""Load proxy signals""" """Load proxy signals"""
self.import_module("authentik.providers.proxy.signals") self.import_module("authentik.providers.proxy.signals")

View File

@ -10,6 +10,6 @@ class AuthentikProviderSCIMConfig(ManagedAppConfig):
verbose_name = "authentik Providers.SCIM" verbose_name = "authentik Providers.SCIM"
default = True default = True
def reconcile_load_signals(self): def reconcile_global_load_signals(self):
"""Load signals""" """Load signals"""
self.import_module("authentik.providers.scim.signals") self.import_module("authentik.providers.scim.signals")

View File

@ -1,20 +1,20 @@
"""SCIM Sync""" """SCIM Sync"""
from django.core.management.base import BaseCommand
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.providers.scim.models import SCIMProvider from authentik.providers.scim.models import SCIMProvider
from authentik.providers.scim.tasks import scim_sync from authentik.providers.scim.tasks import scim_sync
from authentik.tenants.management import TenantCommand
LOGGER = get_logger() LOGGER = get_logger()
class Command(BaseCommand): class Command(TenantCommand):
"""Run sync for an SCIM Provider""" """Run sync for an SCIM Provider"""
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument("providers", nargs="+", type=str) parser.add_argument("providers", nargs="+", type=str)
def handle(self, **options): def handle_per_tenant(self, **options):
for provider_name in options["providers"]: for provider_name in options["providers"]:
provider = SCIMProvider.objects.filter(name=provider_name).first() provider = SCIMProvider.objects.filter(name=provider_name).first()
if not provider: if not provider:

View File

@ -9,6 +9,7 @@ from authentik.lib.generators import generate_id
from authentik.providers.scim.clients.schema import ServiceProviderConfiguration from authentik.providers.scim.clients.schema import ServiceProviderConfiguration
from authentik.providers.scim.models import SCIMMapping, SCIMProvider from authentik.providers.scim.models import SCIMMapping, SCIMProvider
from authentik.providers.scim.tasks import scim_sync from authentik.providers.scim.tasks import scim_sync
from authentik.tenants.models import Tenant
class SCIMMembershipTests(TestCase): class SCIMMembershipTests(TestCase):
@ -22,6 +23,7 @@ class SCIMMembershipTests(TestCase):
# which will cause errors with multiple users # which will cause errors with multiple users
User.objects.all().exclude(pk=get_anonymous_user().pk).delete() User.objects.all().exclude(pk=get_anonymous_user().pk).delete()
Group.objects.all().delete() Group.objects.all().delete()
Tenant.objects.update(avatars="none")
@apply_blueprint("system/providers-scim.yaml") @apply_blueprint("system/providers-scim.yaml")
def configure(self) -> None: def configure(self) -> None:

View File

@ -11,6 +11,7 @@ from authentik.core.models import Application, Group, User
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.providers.scim.models import SCIMMapping, SCIMProvider from authentik.providers.scim.models import SCIMMapping, SCIMProvider
from authentik.providers.scim.tasks import scim_sync from authentik.providers.scim.tasks import scim_sync
from authentik.tenants.models import Tenant
class SCIMUserTests(TestCase): class SCIMUserTests(TestCase):
@ -20,6 +21,7 @@ class SCIMUserTests(TestCase):
def setUp(self) -> None: def setUp(self) -> None:
# Delete all users and groups as the mocked HTTP responses only return one ID # Delete all users and groups as the mocked HTTP responses only return one ID
# which will cause errors with multiple users # which will cause errors with multiple users
Tenant.objects.update(avatars="none")
User.objects.all().exclude(pk=get_anonymous_user().pk).delete() User.objects.all().exclude(pk=get_anonymous_user().pk).delete()
Group.objects.all().delete() Group.objects.all().delete()
self.provider: SCIMProvider = SCIMProvider.objects.create( self.provider: SCIMProvider = SCIMProvider.objects.create(

View File

@ -10,6 +10,6 @@ class AuthentikRBACConfig(ManagedAppConfig):
verbose_name = "authentik RBAC" verbose_name = "authentik RBAC"
default = True default = True
def reconcile_load_rbac_signals(self): def reconcile_global_load_rbac_signals(self):
"""Load rbac signals""" """Load rbac signals"""
self.import_module("authentik.rbac.signals") self.import_module("authentik.rbac.signals")

View File

@ -0,0 +1,29 @@
# Generated by Django 4.2.8 on 2023-12-20 10:02
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("authentik_rbac", "0002_systempermission"),
]
operations = [
migrations.AlterModelOptions(
name="systempermission",
options={
"default_permissions": (),
"managed": False,
"permissions": [
("view_system_info", "Can view system info"),
("view_system_tasks", "Can view system tasks"),
("run_system_tasks", "Can run system tasks"),
("access_admin_interface", "Can access admin interface"),
("view_system_settings", "Can view system settings"),
("edit_system_settings", "Can edit system settings"),
],
"verbose_name": "System permission",
"verbose_name_plural": "System permissions",
},
),
]

View File

@ -70,4 +70,6 @@ class SystemPermission(models.Model):
("view_system_tasks", _("Can view system tasks")), ("view_system_tasks", _("Can view system tasks")),
("run_system_tasks", _("Can run system tasks")), ("run_system_tasks", _("Can run system tasks")),
("access_admin_interface", _("Can access admin interface")), ("access_admin_interface", _("Can access admin interface")),
("view_system_settings", _("Can view system settings")),
("edit_system_settings", _("Can edit system settings")),
] ]

35
authentik/recovery/lib.py Normal file
View File

@ -0,0 +1,35 @@
"""Recovery helper functions."""
from datetime import datetime
from django.urls import reverse
from django.utils.text import slugify
from django.utils.timezone import now
from authentik.core.models import Group, Token, TokenIntents, User
def create_admin_group(user: User) -> Group:
"""Create admin group and add user to it."""
group, _ = Group.objects.update_or_create(
name="authentik Admins",
defaults={
"is_superuser": True,
},
)
group.users.add(user)
return group
def create_recovery_token(user: User, expiry: datetime, generated_from: str) -> (Token, str):
"""Create recovery token and associated link"""
_now = now()
token = Token.objects.create(
expires=expiry,
user=user,
intent=TokenIntents.INTENT_RECOVERY,
description=f"Recovery Token generated by {generated_from} on {_now}",
identifier=slugify(f"ak-recovery-{user}-{_now}"),
)
url = reverse("authentik_recovery:use-token", kwargs={"key": str(token.key)})
return token, url

View File

@ -1,11 +1,12 @@
"""authentik recovery create_admin_group""" """authentik recovery create_admin_group"""
from django.core.management.base import BaseCommand
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from authentik.core.models import Group, User from authentik.core.models import User
from authentik.recovery.lib import create_admin_group
from authentik.tenants.management import TenantCommand
class Command(BaseCommand): class Command(TenantCommand):
"""Create admin group if the default group gets deleted""" """Create admin group if the default group gets deleted"""
help = _("Create admin group if the default group gets deleted.") help = _("Create admin group if the default group gets deleted.")
@ -13,18 +14,12 @@ class Command(BaseCommand):
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument("user", action="store", help="User to add to the admin group.") parser.add_argument("user", action="store", help="User to add to the admin group.")
def handle(self, *args, **options): def handle_per_tenant(self, *args, **options):
"""Create admin group if the default group gets deleted""" """Create admin group if the default group gets deleted"""
username = options.get("user") username = options.get("user")
user = User.objects.filter(username=username).first() user = User.objects.filter(username=username).first()
if not user: if not user:
self.stderr.write(f"User '{username}' not found.") self.stderr.write(f"User '{username}' not found.")
return return
group, _ = Group.objects.update_or_create( group = create_admin_group(user)
name="authentik Admins", self.stdout.write(f"User '{username}' successfully added to the group '{group.name}'.")
defaults={
"is_superuser": True,
},
)
group.users.add(user)
self.stdout.write(f"User '{username}' successfully added to the group 'authentik Admins'.")

View File

@ -2,16 +2,15 @@
from datetime import timedelta from datetime import timedelta
from getpass import getuser from getpass import getuser
from django.core.management.base import BaseCommand
from django.urls import reverse
from django.utils.text import slugify
from django.utils.timezone import now from django.utils.timezone import now
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from authentik.core.models import Token, TokenIntents, User from authentik.core.models import User
from authentik.recovery.lib import create_recovery_token
from authentik.tenants.management import TenantCommand
class Command(BaseCommand): class Command(TenantCommand):
"""Create Token used to recover access""" """Create Token used to recover access"""
help = _("Create a Key which can be used to restore access to authentik.") help = _("Create a Key which can be used to restore access to authentik.")
@ -25,28 +24,16 @@ class Command(BaseCommand):
) )
parser.add_argument("user", action="store", help="Which user the Token gives access to.") parser.add_argument("user", action="store", help="Which user the Token gives access to.")
def get_url(self, token: Token) -> str: def handle_per_tenant(self, *args, **options):
"""Get full recovery link"""
return reverse("authentik_recovery:use-token", kwargs={"key": str(token.key)})
def handle(self, *args, **options):
"""Create Token used to recover access""" """Create Token used to recover access"""
duration = int(options.get("duration", 1)) duration = int(options.get("duration", 1))
_now = now() expiry = now() + timedelta(days=duration * 365.2425)
expiry = _now + timedelta(days=duration * 365.2425) user = User.objects.filter(username=options.get("user")).first()
users = User.objects.filter(username=options.get("user")) if not user:
if not users.exists():
self.stderr.write(f"User '{options.get('user')}' not found.") self.stderr.write(f"User '{options.get('user')}' not found.")
return return
user = users.first() _, url = create_recovery_token(user, expiry, getuser())
token = Token.objects.create(
expires=expiry,
user=user,
intent=TokenIntents.INTENT_RECOVERY,
description=f"Recovery Token generated by {getuser()} on {_now}",
identifier=slugify(f"ak-recovery-{user}-{_now}"),
)
self.stdout.write( self.stdout.write(
f"Store this link safely, as it will allow anyone to access authentik as {user}." f"Store this link safely, as it will allow anyone to access authentik as {user}."
) )
self.stdout.write(self.get_url(token)) self.stdout.write(url)

View File

@ -4,6 +4,7 @@ from io import StringIO
from django.core.management import call_command from django.core.management import call_command
from django.test import TestCase from django.test import TestCase
from django.urls import reverse from django.urls import reverse
from django_tenants.utils import get_public_schema_name
from authentik.core.models import Token, TokenIntents, User from authentik.core.models import Token, TokenIntents, User
@ -18,7 +19,13 @@ class TestRecovery(TestCase):
"""Test creation of a new key""" """Test creation of a new key"""
out = StringIO() out = StringIO()
self.assertEqual(len(Token.objects.all()), 0) self.assertEqual(len(Token.objects.all()), 0)
call_command("create_recovery_key", "1", self.user.username, stdout=out) call_command(
"create_recovery_key",
"1",
self.user.username,
schema=get_public_schema_name(),
stdout=out,
)
token = Token.objects.get(intent=TokenIntents.INTENT_RECOVERY, user=self.user) token = Token.objects.get(intent=TokenIntents.INTENT_RECOVERY, user=self.user)
self.assertIn(token.key, out.getvalue()) self.assertIn(token.key, out.getvalue())
self.assertEqual(len(Token.objects.all()), 1) self.assertEqual(len(Token.objects.all()), 1)
@ -27,13 +34,19 @@ class TestRecovery(TestCase):
"""Test creation of a new key (invalid)""" """Test creation of a new key (invalid)"""
out = StringIO() out = StringIO()
self.assertEqual(len(Token.objects.all()), 0) self.assertEqual(len(Token.objects.all()), 0)
call_command("create_recovery_key", "1", "foo", stderr=out) call_command("create_recovery_key", "1", "foo", schema=get_public_schema_name(), stderr=out)
self.assertIn("not found", out.getvalue()) self.assertIn("not found", out.getvalue())
def test_recovery_view(self): def test_recovery_view(self):
"""Test recovery view""" """Test recovery view"""
out = StringIO() out = StringIO()
call_command("create_recovery_key", "1", self.user.username, stdout=out) call_command(
"create_recovery_key",
"1",
self.user.username,
schema=get_public_schema_name(),
stdout=out,
)
token = Token.objects.get(intent=TokenIntents.INTENT_RECOVERY, user=self.user) token = Token.objects.get(intent=TokenIntents.INTENT_RECOVERY, user=self.user)
self.client.get(reverse("authentik_recovery:use-token", kwargs={"key": token.key})) self.client.get(reverse("authentik_recovery:use-token", kwargs={"key": token.key}))
self.assertEqual(int(self.client.session["_auth_user_id"]), token.user.pk) self.assertEqual(int(self.client.session["_auth_user_id"]), token.user.pk)
@ -46,12 +59,14 @@ class TestRecovery(TestCase):
def test_recovery_admin_group_invalid(self): def test_recovery_admin_group_invalid(self):
"""Test creation of admin group""" """Test creation of admin group"""
out = StringIO() out = StringIO()
call_command("create_admin_group", "1", stderr=out) call_command("create_admin_group", "1", schema=get_public_schema_name(), stderr=out)
self.assertIn("not found", out.getvalue()) self.assertIn("not found", out.getvalue())
def test_recovery_admin_group(self): def test_recovery_admin_group(self):
"""Test creation of admin group""" """Test creation of admin group"""
out = StringIO() out = StringIO()
call_command("create_admin_group", self.user.username, stdout=out) call_command(
"create_admin_group", self.user.username, schema=get_public_schema_name(), stdout=out
)
self.assertIn("successfully added to", out.getvalue()) self.assertIn("successfully added to", out.getvalue())
self.assertTrue(self.user.is_superuser) self.assertTrue(self.user.is_superuser)

View File

@ -6,7 +6,7 @@ from pathlib import Path
from tempfile import gettempdir from tempfile import gettempdir
from typing import Callable from typing import Callable
from celery import Celery, bootsteps from celery import bootsteps
from celery.apps.worker import Worker from celery.apps.worker import Worker
from celery.signals import ( from celery.signals import (
after_task_publish, after_task_publish,
@ -19,8 +19,10 @@ from celery.signals import (
) )
from django.conf import settings from django.conf import settings
from django.db import ProgrammingError from django.db import ProgrammingError
from django_tenants.utils import get_public_schema_name
from structlog.contextvars import STRUCTLOG_KEY_PREFIX from structlog.contextvars import STRUCTLOG_KEY_PREFIX
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from tenant_schemas_celery.app import CeleryApp as TenantAwareCeleryApp
from authentik.lib.sentry import before_send from authentik.lib.sentry import before_send
from authentik.lib.utils.errors import exception_to_string from authentik.lib.utils.errors import exception_to_string
@ -29,7 +31,7 @@ from authentik.lib.utils.errors import exception_to_string
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings") os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings")
LOGGER = get_logger() LOGGER = get_logger()
CELERY_APP = Celery("authentik") CELERY_APP = TenantAwareCeleryApp("authentik")
CTX_TASK_ID = ContextVar(STRUCTLOG_KEY_PREFIX + "task_id", default=Ellipsis) CTX_TASK_ID = ContextVar(STRUCTLOG_KEY_PREFIX + "task_id", default=Ellipsis)
HEARTBEAT_FILE = Path(gettempdir() + "/authentik-worker") HEARTBEAT_FILE = Path(gettempdir() + "/authentik-worker")
@ -80,8 +82,13 @@ def task_error_hook(task_id, exception: Exception, traceback, *args, **kwargs):
Event.new(EventAction.SYSTEM_EXCEPTION, message=exception_to_string(exception)).save() Event.new(EventAction.SYSTEM_EXCEPTION, message=exception_to_string(exception)).save()
def _get_startup_tasks() -> list[Callable]: def _get_startup_tasks_default_tenant() -> list[Callable]:
"""Get all tasks to be run on startup""" """Get all tasks to be run on startup for the default tenant"""
return []
def _get_startup_tasks_all_tenants() -> list[Callable]:
"""Get all tasks to be run on startup for all tenants"""
from authentik.admin.tasks import clear_update_notifications from authentik.admin.tasks import clear_update_notifications
from authentik.outposts.tasks import outpost_connection_discovery, outpost_controller_all from authentik.outposts.tasks import outpost_connection_discovery, outpost_controller_all
from authentik.providers.proxy.tasks import proxy_set_defaults from authentik.providers.proxy.tasks import proxy_set_defaults
@ -97,13 +104,25 @@ def _get_startup_tasks() -> list[Callable]:
@worker_ready.connect @worker_ready.connect
def worker_ready_hook(*args, **kwargs): def worker_ready_hook(*args, **kwargs):
"""Run certain tasks on worker start""" """Run certain tasks on worker start"""
from authentik.tenants.models import Tenant
LOGGER.info("Dispatching startup tasks...") LOGGER.info("Dispatching startup tasks...")
for task in _get_startup_tasks():
def _run_task(task: Callable):
try: try:
task.delay() task.delay()
except ProgrammingError as exc: except ProgrammingError as exc:
LOGGER.warning("Startup task failed", task=task, exc=exc) LOGGER.warning("Startup task failed", task=task, exc=exc)
for task in _get_startup_tasks_default_tenant():
with Tenant.objects.get(schema_name=get_public_schema_name()):
_run_task(task)
for task in _get_startup_tasks_all_tenants():
for tenant in Tenant.objects.filter(ready=True):
with tenant:
_run_task(task)
from authentik.blueprints.v1.tasks import start_blueprint_watcher from authentik.blueprints.v1.tasks import start_blueprint_watcher
start_blueprint_watcher() start_blueprint_watcher()

View File

@ -1,5 +1,5 @@
"""authentik database backend""" """authentik database backend"""
from django_prometheus.db.backends.postgresql.base import DatabaseWrapper as BaseDatabaseWrapper from django_tenants.postgresql_backend.base import DatabaseWrapper as BaseDatabaseWrapper
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
@ -8,6 +8,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
"""database backend which supports rotating credentials""" """database backend which supports rotating credentials"""
def get_connection_params(self): def get_connection_params(self):
"""Refresh DB credentials before getting connection params"""
CONFIG.refresh("postgresql.password") CONFIG.refresh("postgresql.password")
conn_params = super().get_connection_params() conn_params = super().get_connection_params()
conn_params["user"] = CONFIG.get("postgresql.user") conn_params["user"] = CONFIG.get("postgresql.user")

View File

@ -17,7 +17,7 @@ def get_install_id() -> str:
if settings.TEST: if settings.TEST:
return str(uuid4()) return str(uuid4())
with connection.cursor() as cursor: with connection.cursor() as cursor:
cursor.execute("SELECT id FROM authentik_install_id LIMIT 1;") cursor.execute("SELECT id FROM public.authentik_install_id LIMIT 1;")
return cursor.fetchone()[0] return cursor.fetchone()[0]
@ -37,5 +37,5 @@ def get_install_id_raw():
sslkey=CONFIG.get("postgresql.sslkey"), sslkey=CONFIG.get("postgresql.sslkey"),
) )
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT id FROM authentik_install_id LIMIT 1;") cursor.execute("SELECT id FROM public.authentik_install_id LIMIT 1;")
return cursor.fetchone()[0] return cursor.fetchone()[0]

View File

@ -1,6 +1,7 @@
"""root settings for authentik""" """root settings for authentik"""
import importlib import importlib
import os import os
from collections import OrderedDict
from hashlib import sha512 from hashlib import sha512
from pathlib import Path from pathlib import Path
from urllib.parse import quote_plus from urllib.parse import quote_plus
@ -16,8 +17,6 @@ from authentik.lib.utils.reflection import get_env
from authentik.stages.password import BACKEND_APP_PASSWORD, BACKEND_INBUILT, BACKEND_LDAP from authentik.stages.password import BACKEND_APP_PASSWORD, BACKEND_INBUILT, BACKEND_LDAP
BASE_DIR = Path(__file__).absolute().parent.parent.parent BASE_DIR = Path(__file__).absolute().parent.parent.parent
STATICFILES_DIRS = [BASE_DIR / Path("web")]
MEDIA_ROOT = BASE_DIR / Path("media")
DEBUG = CONFIG.get_bool("debug") DEBUG = CONFIG.get_bool("debug")
SECRET_KEY = CONFIG.get("secret_key") SECRET_KEY = CONFIG.get("secret_key")
@ -49,14 +48,23 @@ AUTHENTICATION_BACKENDS = [
DEFAULT_AUTO_FIELD = "django.db.models.AutoField" DEFAULT_AUTO_FIELD = "django.db.models.AutoField"
# Application definition # Application definition
INSTALLED_APPS = [ SHARED_APPS = [
"django_tenants",
"authentik.tenants",
"daphne", "daphne",
"django.contrib.auth",
"django.contrib.contenttypes",
"django.contrib.sessions",
"django.contrib.messages", "django.contrib.messages",
"django.contrib.staticfiles", "django.contrib.staticfiles",
"django.contrib.humanize", "django.contrib.humanize",
"rest_framework",
"django_filters",
"drf_spectacular",
"django_prometheus",
"channels",
]
TENANT_APPS = [
"django.contrib.auth",
"django.contrib.contenttypes",
"django.contrib.sessions",
"authentik.admin", "authentik.admin",
"authentik.api", "authentik.api",
"authentik.crypto", "authentik.crypto",
@ -102,16 +110,17 @@ INSTALLED_APPS = [
"authentik.stages.user_login", "authentik.stages.user_login",
"authentik.stages.user_logout", "authentik.stages.user_logout",
"authentik.stages.user_write", "authentik.stages.user_write",
"authentik.tenants", "authentik.brands",
"authentik.blueprints", "authentik.blueprints",
"rest_framework",
"django_filters",
"drf_spectacular",
"guardian", "guardian",
"django_prometheus",
"channels",
] ]
TENANT_MODEL = "authentik_tenants.Tenant"
TENANT_DOMAIN_MODEL = "authentik_tenants.Domain"
TENANT_CREATION_FAKES_MIGRATIONS = True
TENANT_BASE_SCHEMA = "template"
GUARDIAN_MONKEY_PATCH = False GUARDIAN_MONKEY_PATCH = False
SPECTACULAR_SETTINGS = { SPECTACULAR_SETTINGS = {
@ -199,6 +208,8 @@ CACHES = {
"TIMEOUT": CONFIG.get_int("cache.timeout", 300), "TIMEOUT": CONFIG.get_int("cache.timeout", 300),
"OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"}, "OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"},
"KEY_PREFIX": "authentik_cache", "KEY_PREFIX": "authentik_cache",
"KEY_FUNCTION": "django_tenants.cache.make_key",
"REVERSE_KEY_FUNCTION": "django_tenants.cache.reverse_key",
} }
} }
DJANGO_REDIS_SCAN_ITERSIZE = 1000 DJANGO_REDIS_SCAN_ITERSIZE = 1000
@ -215,13 +226,14 @@ SESSION_EXPIRE_AT_BROWSER_CLOSE = True
MESSAGE_STORAGE = "authentik.root.messages.storage.ChannelsStorage" MESSAGE_STORAGE = "authentik.root.messages.storage.ChannelsStorage"
MIDDLEWARE = [ MIDDLEWARE = [
"django_tenants.middleware.default.DefaultTenantMiddleware",
"authentik.root.middleware.LoggingMiddleware", "authentik.root.middleware.LoggingMiddleware",
"django_prometheus.middleware.PrometheusBeforeMiddleware", "django_prometheus.middleware.PrometheusBeforeMiddleware",
"authentik.root.middleware.ClientIPMiddleware", "authentik.root.middleware.ClientIPMiddleware",
"authentik.stages.user_login.middleware.BoundSessionMiddleware", "authentik.stages.user_login.middleware.BoundSessionMiddleware",
"django.contrib.auth.middleware.AuthenticationMiddleware", "django.contrib.auth.middleware.AuthenticationMiddleware",
"authentik.core.middleware.RequestIDMiddleware", "authentik.core.middleware.RequestIDMiddleware",
"authentik.tenants.middleware.TenantMiddleware", "authentik.brands.middleware.BrandMiddleware",
"authentik.events.middleware.AuditMiddleware", "authentik.events.middleware.AuditMiddleware",
"django.middleware.security.SecurityMiddleware", "django.middleware.security.SecurityMiddleware",
"django.middleware.common.CommonMiddleware", "django.middleware.common.CommonMiddleware",
@ -245,7 +257,7 @@ TEMPLATES = [
"django.template.context_processors.request", "django.template.context_processors.request",
"django.contrib.auth.context_processors.auth", "django.contrib.auth.context_processors.auth",
"django.contrib.messages.context_processors.messages", "django.contrib.messages.context_processors.messages",
"authentik.tenants.utils.context_processor", "authentik.brands.utils.context_processor",
], ],
}, },
}, },
@ -267,6 +279,7 @@ CHANNEL_LAYERS = {
# Database # Database
# https://docs.djangoproject.com/en/2.1/ref/settings/#databases # https://docs.djangoproject.com/en/2.1/ref/settings/#databases
ORIGINAL_BACKEND = "django_prometheus.db.backends.postgresql"
DATABASES = { DATABASES = {
"default": { "default": {
"ENGINE": "authentik.root.db", "ENGINE": "authentik.root.db",
@ -294,6 +307,8 @@ if CONFIG.get_bool("postgresql.use_pgbouncer", False):
# https://docs.djangoproject.com/en/4.0/ref/databases/#persistent-connections # https://docs.djangoproject.com/en/4.0/ref/databases/#persistent-connections
DATABASES["default"]["CONN_MAX_AGE"] = None # persistent DATABASES["default"]["CONN_MAX_AGE"] = None # persistent
DATABASE_ROUTERS = ("django_tenants.routers.TenantSyncRouter",)
# Email # Email
# These values should never actually be used, emails are only sent from email stages, which # These values should never actually be used, emails are only sent from email stages, which
# loads the config directly from CONFIG # loads the config directly from CONFIG
@ -351,6 +366,7 @@ CELERY = {
"options": {"queue": "authentik_scheduled"}, "options": {"queue": "authentik_scheduled"},
}, },
}, },
"beat_scheduler": "authentik.tenants.scheduler:TenantAwarePersistentScheduler",
"task_create_missing_queues": True, "task_create_missing_queues": True,
"task_default_queue": "authentik", "task_default_queue": "authentik",
"broker_url": CONFIG.get("broker.url") "broker_url": CONFIG.get("broker.url")
@ -372,8 +388,54 @@ if _ERROR_REPORTING:
# Static files (CSS, JavaScript, Images) # Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/2.1/howto/static-files/ # https://docs.djangoproject.com/en/2.1/howto/static-files/
STATICFILES_DIRS = [BASE_DIR / Path("web")]
STATIC_URL = "/static/" STATIC_URL = "/static/"
MEDIA_URL = "/media/"
STORAGES = {
"staticfiles": {
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage",
},
}
# Media files
if CONFIG.get("storage.media.backend", "file") == "s3":
STORAGES["default"] = {
"BACKEND": "authentik.root.storages.S3Storage",
"OPTIONS": {
# How to talk to S3
"session_profile": CONFIG.get("storage.media.s3.session_profile", None),
"access_key": CONFIG.get("storage.media.s3.access_key", None),
"secret_key": CONFIG.get("storage.media.s3.secret_key", None),
"security_token": CONFIG.get("storage.media.s3.security_token", None),
"region_name": CONFIG.get("storage.media.s3.region", None),
"use_ssl": CONFIG.get_bool("storage.media.s3.use_ssl", True),
"endpoint_url": CONFIG.get("storage.media.s3.endpoint", None),
"bucket_name": CONFIG.get("storage.media.s3.bucket_name"),
"default_acl": "private",
"querystring_auth": True,
"signature_version": "s3v4",
"file_overwrite": False,
"location": "media",
"url_protocol": "https:"
if CONFIG.get("storage.media.s3.secure_urls", True)
else "http:",
"custom_domain": CONFIG.get("storage.media.s3.custom_domain", None),
},
}
# Fallback on file storage backend
else:
STORAGES["default"] = {
"BACKEND": "authentik.root.storages.FileStorage",
"OPTIONS": {
"location": Path(CONFIG.get("storage.media.file.path")),
"base_url": "/media/",
},
}
# Compatibility for apps not supporting top-level STORAGES
# such as django-tenants
MEDIA_ROOT = STORAGES["default"]["OPTIONS"]["location"]
MEDIA_URL = STORAGES["default"]["OPTIONS"]["base_url"]
TEST = False TEST = False
TEST_RUNNER = "authentik.root.test_runner.PytestTestRunner" TEST_RUNNER = "authentik.root.test_runner.PytestTestRunner"
@ -383,6 +445,8 @@ LOGGING = get_logger_config()
_DISALLOWED_ITEMS = [ _DISALLOWED_ITEMS = [
"SHARED_APPS",
"TENANT_APPS",
"INSTALLED_APPS", "INSTALLED_APPS",
"MIDDLEWARE", "MIDDLEWARE",
"AUTHENTICATION_BACKENDS", "AUTHENTICATION_BACKENDS",
@ -394,7 +458,8 @@ def _update_settings(app_path: str):
try: try:
settings_module = importlib.import_module(app_path) settings_module = importlib.import_module(app_path)
CONFIG.log("debug", "Loaded app settings", path=app_path) CONFIG.log("debug", "Loaded app settings", path=app_path)
INSTALLED_APPS.extend(getattr(settings_module, "INSTALLED_APPS", [])) SHARED_APPS.extend(getattr(settings_module, "SHARED_APPS", []))
TENANT_APPS.extend(getattr(settings_module, "TENANT_APPS", []))
MIDDLEWARE.extend(getattr(settings_module, "MIDDLEWARE", [])) MIDDLEWARE.extend(getattr(settings_module, "MIDDLEWARE", []))
AUTHENTICATION_BACKENDS.extend(getattr(settings_module, "AUTHENTICATION_BACKENDS", [])) AUTHENTICATION_BACKENDS.extend(getattr(settings_module, "AUTHENTICATION_BACKENDS", []))
CELERY["beat_schedule"].update(getattr(settings_module, "CELERY_BEAT_SCHEDULE", {})) CELERY["beat_schedule"].update(getattr(settings_module, "CELERY_BEAT_SCHEDULE", {}))
@ -406,7 +471,7 @@ def _update_settings(app_path: str):
# Load subapps's settings # Load subapps's settings
for _app in INSTALLED_APPS: for _app in set(SHARED_APPS + TENANT_APPS):
if not _app.startswith("authentik"): if not _app.startswith("authentik"):
continue continue
_update_settings(f"{_app}.settings") _update_settings(f"{_app}.settings")
@ -419,7 +484,7 @@ if DEBUG:
"rest_framework.renderers.BrowsableAPIRenderer" "rest_framework.renderers.BrowsableAPIRenderer"
) )
INSTALLED_APPS.append("authentik.core") TENANT_APPS.append("authentik.core")
CONFIG.log("info", "Booting authentik", version=__version__) CONFIG.log("info", "Booting authentik", version=__version__)
@ -427,7 +492,10 @@ CONFIG.log("info", "Booting authentik", version=__version__)
try: try:
importlib.import_module("authentik.enterprise.apps") importlib.import_module("authentik.enterprise.apps")
CONFIG.log("info", "Enabled authentik enterprise") CONFIG.log("info", "Enabled authentik enterprise")
INSTALLED_APPS.append("authentik.enterprise") TENANT_APPS.append("authentik.enterprise")
_update_settings("authentik.enterprise.settings") _update_settings("authentik.enterprise.settings")
except ImportError: except ImportError:
pass pass
SHARED_APPS = list(OrderedDict.fromkeys(SHARED_APPS + TENANT_APPS))
INSTALLED_APPS = list(OrderedDict.fromkeys(SHARED_APPS + TENANT_APPS))

116
authentik/root/storages.py Normal file
View File

@ -0,0 +1,116 @@
"""authentik storage backends"""
import os
from django.conf import settings
from django.core.exceptions import SuspiciousOperation
from django.core.files.storage import FileSystemStorage
from django.db import connection
from storages.backends.s3 import S3Storage as BaseS3Storage
from storages.utils import clean_name, safe_join
from authentik.lib.config import CONFIG
class FileStorage(FileSystemStorage):
"""File storage backend"""
# pylint: disable=invalid-overridden-method
@property
def base_location(self):
return os.path.join(
self._value_or_setting(self._location, settings.MEDIA_ROOT), connection.schema_name
)
# pylint: disable=invalid-overridden-method
@property
def location(self):
return os.path.abspath(self.base_location)
# pylint: disable=invalid-overridden-method
@property
def base_url(self):
if self._base_url is not None and not self._base_url.endswith("/"):
self._base_url += "/"
return f"{self._base_url}/{connection.schema_name}/"
# pylint: disable=abstract-method
class S3Storage(BaseS3Storage):
"""S3 storage backend"""
@property
def session_profile(self) -> str | None:
"""Get session profile"""
return CONFIG.refresh("storage.media.s3.session_profile", None)
@session_profile.setter
def session_profile(self, value: str):
pass
@property
def access_key(self) -> str | None:
"""Get access key"""
return CONFIG.refresh("storage.media.s3.access_key", None)
@access_key.setter
def access_key(self, value: str):
pass
@property
def secret_key(self) -> str | None:
"""Get secret key"""
return CONFIG.refresh("storage.media.s3.secret_key", None)
@secret_key.setter
def secret_key(self, value: str):
pass
@property
def security_token(self) -> str | None:
"""Get security token"""
return CONFIG.refresh("storage.media.s3.security_token", None)
@security_token.setter
def security_token(self, value: str):
pass
def _normalize_name(self, name):
try:
# pylint: disable=no-member
return safe_join(self.location, connection.schema_name, name)
except ValueError:
raise SuspiciousOperation("Attempted access to '%s' denied." % name)
# This is a fix for https://github.com/jschneier/django-storages/pull/839
# pylint: disable=arguments-differ,no-member
def url(self, name, parameters=None, expire=None, http_method=None):
# Preserve the trailing slash after normalizing the path.
name = self._normalize_name(clean_name(name))
params = parameters.copy() if parameters else {}
if expire is None:
expire = self.querystring_expire
params["Bucket"] = self.bucket.name
params["Key"] = name
url = self.bucket.meta.client.generate_presigned_url(
"get_object",
Params=params,
ExpiresIn=expire,
HttpMethod=http_method,
)
if self.custom_domain:
# Key parameter can't be empty. Use "/" and remove it later.
params["Key"] = "/"
root_url_signed = self.bucket.meta.client.generate_presigned_url(
"get_object", Params=params, ExpiresIn=expire
)
# Remove signing parameter and previously added key "/".
root_url = self._strip_signing_parameters(root_url_signed)[:-1]
# Replace bucket domain with custom domain.
custom_url = "{}//{}/".format(self.url_protocol, self.custom_domain)
url = url.replace(root_url, custom_url)
if self.querystring_auth:
return url
return self._strip_signing_parameters(url)

View File

@ -31,7 +31,6 @@ class PytestTestRunner(DiscoverRunner): # pragma: no cover
settings.TEST = True settings.TEST = True
settings.CELERY["task_always_eager"] = True settings.CELERY["task_always_eager"] = True
CONFIG.set("avatars", "none")
CONFIG.set("events.context_processors.geoip", "tests/GeoLite2-City-Test.mmdb") CONFIG.set("events.context_processors.geoip", "tests/GeoLite2-City-Test.mmdb")
CONFIG.set("events.context_processors.asn", "tests/GeoLite2-ASN-Test.mmdb") CONFIG.set("events.context_processors.asn", "tests/GeoLite2-ASN-Test.mmdb")
CONFIG.set("blueprints_dir", "./blueprints") CONFIG.set("blueprints_dir", "./blueprints")
@ -39,6 +38,8 @@ class PytestTestRunner(DiscoverRunner): # pragma: no cover
"outposts.container_image_base", "outposts.container_image_base",
f"ghcr.io/goauthentik/dev-%(type)s:{get_docker_tag()}", f"ghcr.io/goauthentik/dev-%(type)s:{get_docker_tag()}",
) )
CONFIG.set("tenants.enabled", False)
CONFIG.set("outposts.disable_embedded_outpost", False)
CONFIG.set("error_reporting.sample_rate", 0) CONFIG.set("error_reporting.sample_rate", 0)
CONFIG.set("error_reporting.environment", "testing") CONFIG.set("error_reporting.environment", "testing")
CONFIG.set("error_reporting.send_pii", True) CONFIG.set("error_reporting.send_pii", True)

View File

@ -10,6 +10,6 @@ class AuthentikSourceLDAPConfig(ManagedAppConfig):
verbose_name = "authentik Sources.LDAP" verbose_name = "authentik Sources.LDAP"
default = True default = True
def reconcile_load_sources_ldap_signals(self): def reconcile_global_load_sources_ldap_signals(self):
"""Load sources.ldap signals""" """Load sources.ldap signals"""
self.import_module("authentik.sources.ldap.signals") self.import_module("authentik.sources.ldap.signals")

View File

@ -1,21 +1,21 @@
"""LDAP Connection check""" """LDAP Connection check"""
from json import dumps from json import dumps
from django.core.management.base import BaseCommand
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.sources.ldap.models import LDAPSource from authentik.sources.ldap.models import LDAPSource
from authentik.tenants.management import TenantCommand
LOGGER = get_logger() LOGGER = get_logger()
class Command(BaseCommand): class Command(TenantCommand):
"""Check connectivity to LDAP servers for a source""" """Check connectivity to LDAP servers for a source"""
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument("source_slugs", nargs="?", type=str) parser.add_argument("source_slugs", nargs="?", type=str)
def handle(self, **options): def handle_per_tenant(self, **options):
sources = LDAPSource.objects.filter(enabled=True) sources = LDAPSource.objects.filter(enabled=True)
if options["source_slugs"]: if options["source_slugs"]:
sources = LDAPSource.objects.filter(slug__in=options["source_slugs"]) sources = LDAPSource.objects.filter(slug__in=options["source_slugs"])

View File

@ -1,5 +1,4 @@
"""LDAP Sync""" """LDAP Sync"""
from django.core.management.base import BaseCommand
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.sources.ldap.models import LDAPSource from authentik.sources.ldap.models import LDAPSource
@ -7,17 +6,18 @@ from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer
from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer
from authentik.sources.ldap.sync.users import UserLDAPSynchronizer from authentik.sources.ldap.sync.users import UserLDAPSynchronizer
from authentik.sources.ldap.tasks import ldap_sync_paginator from authentik.sources.ldap.tasks import ldap_sync_paginator
from authentik.tenants.management import TenantCommand
LOGGER = get_logger() LOGGER = get_logger()
class Command(BaseCommand): class Command(TenantCommand):
"""Run sync for an LDAP Source""" """Run sync for an LDAP Source"""
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument("source_slugs", nargs="+", type=str) parser.add_argument("source_slugs", nargs="+", type=str)
def handle(self, **options): def handle_per_tenant(self, **options):
for source_slug in options["source_slugs"]: for source_slug in options["source_slugs"]:
source = LDAPSource.objects.filter(slug=source_slug).first() source = LDAPSource.objects.filter(slug=source_slug).first()
if not source: if not source:

Some files were not shown because too many files have changed in this diff Show More