Files
authentik/authentik/root/test_runner.py
Jens L. f76becfd86 stages/user_login: fix session binding logging (#15175)
* add tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix logging

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* update test db?

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* ah there we go; fix mmdb not being reloaded with test settings

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-06-21 00:21:49 +02:00

180 lines
6.6 KiB
Python

"""Integrate ./manage.py test with pytest"""
import os
from argparse import ArgumentParser, ArgumentTypeError
from unittest import TestCase
from unittest.mock import patch

import pytest
from django.conf import settings
from django.contrib.contenttypes.models import ContentType
from django.test.runner import DiscoverRunner
from structlog.stdlib import get_logger

from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR
from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
from authentik.lib.config import CONFIG
from authentik.lib.sentry import sentry_init
from authentik.root.signals import post_startup, pre_startup, startup


# unittest truncates long assertion-failure diffs by default. Lifting the limit
# on the base TestCase makes every subclass that does not set its own maxDiff
# print the complete diff.
TestCase.maxDiff = None


def get_docker_tag() -> str:
    """Build the docker image tag for the current CI run.

    A non-empty ``GITHUB_HEAD_REF`` takes precedence. Otherwise ``GITHUB_REF``
    is used, falling back to ``"main"`` when it is unset. Every occurrence of
    ``refs/heads/`` is removed, slashes become dashes, and the result is
    prefixed with ``gh-``.
    """
    head_ref = os.environ.get("GITHUB_HEAD_REF", "")
    branch = head_ref if head_ref != "" else os.environ.get("GITHUB_REF", "main")
    sanitized = branch.replace("refs/heads/", "").replace("/", "-")
    return f"gh-{sanitized}"


def patched__get_ct_cached(app_label, codename):
    """Return the ``ContentType`` that owns permission ``codename`` in ``app_label``.

    ``PytestTestRunner.run_tests`` patches this in for
    ``guardian.shortcuts._get_ct_cached`` while pytest runs. Despite the name,
    nothing is cached here: every call issues a fresh database query.
    NOTE(review): presumably this avoids serving content types from a cache
    that outlives the test database. Confirm against guardian's implementation.

    Raises ``ContentType.DoesNotExist`` / ``MultipleObjectsReturned`` as
    Django's ``QuerySet.get`` does.
    """
    lookup = {"app_label": app_label, "permission__codename": codename}
    return ContentType.objects.get(**lookup)


class PytestTestRunner(DiscoverRunner):  # pragma: no cover
    """Run ``./manage.py test`` through pytest.

    Django's test command options are translated into pytest command-line
    arguments. Constructing the runner also switches the authentik
    configuration to test values and sends the startup signals.
    """

    # Characters rejected in test labels. "/" and ":" remain allowed so that
    # file paths and pytest node ids (``file.py::Class::method``) pass.
    _INVALID_LABEL_CHARS = frozenset('\\*?"<>|')

    def __init__(self, **kwargs):
        """Translate Django runner options into base pytest arguments.

        ``failfast`` and ``keepdb`` (stored by ``DiscoverRunner``) map to
        ``--exitfirst`` and ``--reuse-db``. The ``randomly_seed`` and
        ``no_capture`` options registered in ``add_arguments`` map to
        ``--randomly-seed`` and ``--capture=no``. The test environment is
        configured as a side effect of construction.
        """
        super().__init__(**kwargs)
        self.logger = get_logger().bind(runner="pytest")
        # Base pytest arguments; run_tests() appends test selections to a copy.
        self.args: list[str] = []
        if self.failfast:
            self.args.append("--exitfirst")
        if self.keepdb:
            self.args.append("--reuse-db")
        # Compare against None so an explicit seed of 0 is still forwarded.
        randomly_seed = kwargs.get("randomly_seed")
        if randomly_seed is not None:
            self.args.append(f"--randomly-seed={randomly_seed}")
        if kwargs.get("no_capture", False):
            self.args.append("--capture=no")
        self._setup_test_environment()

    def _setup_test_environment(self):
        """Apply test settings and CONFIG overrides, then send the startup signals."""
        settings.TEST = True
        # Execute Celery tasks synchronously in the test process.
        settings.CELERY["task_always_eager"] = True

        # Test-specific configuration overrides
        test_config = {
            "events.context_processors.geoip": "tests/GeoLite2-City-Test.mmdb",
            "events.context_processors.asn": "tests/GeoLite2-ASN-Test.mmdb",
            "blueprints_dir": "./blueprints",
            "outposts.container_image_base": f"ghcr.io/goauthentik/dev-%(type)s:{get_docker_tag()}",
            "tenants.enabled": False,
            "outposts.disable_embedded_outpost": False,
            "error_reporting.sample_rate": 0,
            "error_reporting.environment": "testing",
            "error_reporting.send_pii": True,
        }
        for key, value in test_config.items():
            CONFIG.set(key, value)

        # Reload the ASN/GeoIP processors so they use the test databases set
        # above rather than any database loaded before these overrides.
        ASN_CONTEXT_PROCESSOR.load()
        GEOIP_CONTEXT_PROCESSOR.load()
        # Re-initialise error reporting after the error_reporting.* overrides
        # (presumably sentry_init reads them from CONFIG; confirm).
        sentry_init()

        self.logger.debug("Test environment configured")

        # Send startup signals
        pre_startup.send(sender=self, mode="test")
        startup.send(sender=self, mode="test")
        post_startup.send(sender=self, mode="test")

    @staticmethod
    def _parse_randomly_seed(value: str):
        """Parse a ``--randomly-seed`` value.

        Returns the string ``"last"`` unchanged (pytest-randomly reuses the
        previous run's seed for it), otherwise the value converted to ``int``.

        Raises:
            ArgumentTypeError: if the value is neither ``last`` nor an integer.
        """
        if value == "last":
            return value
        try:
            return int(value)
        except ValueError:
            raise ArgumentTypeError(
                f"invalid seed {value!r}: expected an integer or 'last'"
            ) from None

    @classmethod
    def add_arguments(cls, parser: ArgumentParser):
        """Register Django's runner options plus pytest-specific ones."""
        super().add_arguments(parser)
        parser.add_argument(
            "--randomly-seed",
            type=cls._parse_randomly_seed,
            help=(
                "Set the seed that pytest-randomly uses (int), or pass the special value 'last' "
                "to reuse the seed from the previous run. "
                "Default behaviour: use random.Random().getrandbits(32), so the seed is "
                "different on each run."
            ),
        )
        parser.add_argument(
            "--no-capture",
            action="store_true",
            help="Disable any capturing of stdout/stderr during tests.",
        )

    def _validate_test_label(self, label: str) -> bool:
        """Return True if ``label`` is non-empty and free of disallowed characters.

        Rejections are logged.
        """
        if not label:
            self.logger.error("Empty test label")
            return False
        if any(char in self._INVALID_LABEL_CHARS for char in label):
            self.logger.error("Invalid characters in test label", label=label)
            return False
        return True

    @staticmethod
    def _resolve_label(label: str):
        """Translate one test label into a pytest argument, or return None.

        Resolution order:

        1. an existing filesystem path resolves to its absolute path;
        2. a label containing ``::`` is passed through as a pytest node id;
        3. a dotted name ``foo.bar[.Class[.method]]`` resolves to ``foo/bar.py``
           followed by the trailing pieces joined with ``::``. A dotted name
           that names a package directory resolves to that directory.

        Malformed dotted names (empty pieces such as ``foo..bar``) give None.
        """
        label_as_path = os.path.abspath(label)
        if os.path.exists(label_as_path):
            return label_as_path
        if "::" in label:
            return label

        pieces = label.split(".")
        if not all(pieces):
            return None
        # Treat the whole dotted name as the module first, then peel off one
        # (Class) and finally two (Class.method) trailing selector pieces.
        lowest_module_len = max(len(pieces) - 2, 1)
        for module_len in range(len(pieces), lowest_module_len - 1, -1):
            module_path = os.path.join(*pieces[:module_len])
            selectors = pieces[module_len:]
            if not selectors and os.path.isdir(module_path):
                return module_path
            if os.path.isfile(module_path + ".py"):
                return "::".join([module_path + ".py", *selectors])
        return None

    def run_tests(self, test_labels: list[str], extra_tests=None, **kwargs):
        """Run pytest and return the exitcode.

        It translates some of Django's test command option to pytest's.
        Each label may be a filesystem path, a pytest node id
        (``file.py::Class::method``), or a dotted module name
        i.e. foo.bar[.Class[.method]].

        Returns 1 without running pytest when no labels are given, or when a
        label is invalid or cannot be resolved. Also returns 1 when pytest
        raises.

        The extra_tests argument has been deprecated since Django 5.x
        It is kept for compatibility with PyCharm's Django test runner.
        """
        if not test_labels:
            self.logger.error("No test files specified")
            return 1

        # Work on a copy so repeated calls don't accumulate test selections.
        pytest_args = list(self.args)
        for label in test_labels:
            if not self._validate_test_label(label):
                return 1
            resolved = self._resolve_label(label)
            if resolved is None:
                self.logger.error("Test file not found", label=label)
                return 1
            pytest_args.append(resolved)

        self.logger.info("Running tests", test_files=pytest_args)
        with patch("guardian.shortcuts._get_ct_cached", patched__get_ct_cached):
            try:
                return pytest.main(pytest_args)
            except Exception as e:
                # Keep the traceback; str(e) alone hides where pytest failed.
                self.logger.exception(
                    "Error running tests", error=str(e), test_files=pytest_args
                )
                return 1