root: use channel send workaround for sync sending of websocket messages
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
		@ -7,7 +7,6 @@ from urllib.parse import urlparse
 | 
			
		||||
 | 
			
		||||
import yaml
 | 
			
		||||
from asgiref.sync import async_to_sync
 | 
			
		||||
from channels.layers import get_channel_layer
 | 
			
		||||
from django.core.cache import cache
 | 
			
		||||
from django.db import DatabaseError, InternalError, ProgrammingError
 | 
			
		||||
from django.db.models.base import Model
 | 
			
		||||
@ -43,6 +42,7 @@ from authentik.providers.ldap.controllers.kubernetes import LDAPKubernetesContro
 | 
			
		||||
from authentik.providers.proxy.controllers.docker import ProxyDockerController
 | 
			
		||||
from authentik.providers.proxy.controllers.kubernetes import ProxyKubernetesController
 | 
			
		||||
from authentik.root.celery import CELERY_APP
 | 
			
		||||
from authentik.root.messages.storage import closing_send
 | 
			
		||||
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
CACHE_KEY_OUTPOST_DOWN = "outpost_teardown_%s"
 | 
			
		||||
@ -217,26 +217,23 @@ def outpost_post_save(model_class: str, model_pk: Any):
 | 
			
		||||
def outpost_send_update(model_instace: Model):
 | 
			
		||||
    """Send outpost update to all registered outposts, regardless to which authentik
 | 
			
		||||
    instance they are connected"""
 | 
			
		||||
    channel_layer = get_channel_layer()
 | 
			
		||||
    if isinstance(model_instace, OutpostModel):
 | 
			
		||||
        for outpost in model_instace.outpost_set.all():
 | 
			
		||||
            _outpost_single_update(outpost, channel_layer)
 | 
			
		||||
            _outpost_single_update(outpost)
 | 
			
		||||
    elif isinstance(model_instace, Outpost):
 | 
			
		||||
        _outpost_single_update(model_instace, channel_layer)
 | 
			
		||||
        _outpost_single_update(model_instace)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _outpost_single_update(outpost: Outpost, layer=None):
 | 
			
		||||
def _outpost_single_update(outpost: Outpost):
 | 
			
		||||
    """Update outpost instances connected to a single outpost"""
 | 
			
		||||
    # Ensure token again, because this function is called when anything related to an
 | 
			
		||||
    # OutpostModel is saved, so we can be sure permissions are right
 | 
			
		||||
    _ = outpost.token
 | 
			
		||||
    outpost.build_user_permissions(outpost.user)
 | 
			
		||||
    if not layer:  # pragma: no cover
 | 
			
		||||
        layer = get_channel_layer()
 | 
			
		||||
    for state in OutpostState.for_outpost(outpost):
 | 
			
		||||
        for channel in state.channel_ids:
 | 
			
		||||
            LOGGER.debug("sending update", channel=channel, instance=state.uid, outpost=outpost)
 | 
			
		||||
            async_to_sync(layer.send)(channel, {"type": "event.update"})
 | 
			
		||||
            async_to_sync(closing_send)(channel, {"type": "event.update"})
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@CELERY_APP.task()
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,7 @@
 | 
			
		||||
"""Channels Messages storage"""
 | 
			
		||||
from asgiref.sync import async_to_sync
 | 
			
		||||
from channels.layers import get_channel_layer
 | 
			
		||||
from channels import DEFAULT_CHANNEL_LAYER
 | 
			
		||||
from channels.layers import channel_layers
 | 
			
		||||
from django.contrib.messages.storage.base import Message
 | 
			
		||||
from django.contrib.messages.storage.session import SessionStorage
 | 
			
		||||
from django.core.cache import cache
 | 
			
		||||
@ -10,13 +11,21 @@ SESSION_KEY = "_messages"
 | 
			
		||||
CACHE_PREFIX = "goauthentik.io/root/messages_"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def closing_send(channel, message):
 | 
			
		||||
    """Wrapper around layer send that closes the connection"""
 | 
			
		||||
    # See https://github.com/django/channels_redis/issues/332
 | 
			
		||||
    # TODO: Remove this after channels_redis 4.1 is released
 | 
			
		||||
    channel_layer = channel_layers.make_backend(DEFAULT_CHANNEL_LAYER)
 | 
			
		||||
    await channel_layer.send(channel, message)
 | 
			
		||||
    await channel_layer.close_pools()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChannelsStorage(SessionStorage):
 | 
			
		||||
    """Send contrib.messages over websocket"""
 | 
			
		||||
 | 
			
		||||
    def __init__(self, request: HttpRequest) -> None:
 | 
			
		||||
        # pyright: reportGeneralTypeIssues=false
 | 
			
		||||
        super().__init__(request)
 | 
			
		||||
        self.channel = get_channel_layer()
 | 
			
		||||
 | 
			
		||||
    def _store(self, messages: list[Message], response, *args, **kwargs):
 | 
			
		||||
        prefix = f"{CACHE_PREFIX}{self.request.session.session_key}_messages_"
 | 
			
		||||
@ -28,7 +37,7 @@ class ChannelsStorage(SessionStorage):
 | 
			
		||||
        for key in keys:
 | 
			
		||||
            uid = key.replace(prefix, "")
 | 
			
		||||
            for message in messages:
 | 
			
		||||
                async_to_sync(self.channel.send)(
 | 
			
		||||
                async_to_sync(closing_send)(
 | 
			
		||||
                    uid,
 | 
			
		||||
                    {
 | 
			
		||||
                        "type": "event.update",
 | 
			
		||||
 | 
			
		||||
@ -1,8 +1,5 @@
 | 
			
		||||
[tool.pyright]
 | 
			
		||||
ignore = [
 | 
			
		||||
  "**/migrations/**",
 | 
			
		||||
  "**/node_modules/**"
 | 
			
		||||
]
 | 
			
		||||
ignore = ["**/migrations/**", "**/node_modules/**"]
 | 
			
		||||
reportMissingTypeStubs = false
 | 
			
		||||
strictParameterNoneValue = true
 | 
			
		||||
strictDictionaryInference = true
 | 
			
		||||
@ -63,14 +60,7 @@ exclude_lines = [
 | 
			
		||||
show_missing = true
 | 
			
		||||
 | 
			
		||||
[tool.pylint.basic]
 | 
			
		||||
good-names = [
 | 
			
		||||
  "pk",
 | 
			
		||||
  "id",
 | 
			
		||||
  "i",
 | 
			
		||||
  "j",
 | 
			
		||||
  "k",
 | 
			
		||||
  "_",
 | 
			
		||||
]
 | 
			
		||||
good-names = ["pk", "id", "i", "j", "k", "_"]
 | 
			
		||||
 | 
			
		||||
[tool.pylint.master]
 | 
			
		||||
disable = [
 | 
			
		||||
@ -85,6 +75,7 @@ disable = [
 | 
			
		||||
  "protected-access",
 | 
			
		||||
  "unused-argument",
 | 
			
		||||
  "raise-missing-from",
 | 
			
		||||
  "fixme",
 | 
			
		||||
  # To preserve django's translation function we need to use %-formatting
 | 
			
		||||
  "consider-using-f-string",
 | 
			
		||||
]
 | 
			
		||||
@ -120,7 +111,7 @@ authors = ["authentik Team <hello@goauthentik.io>"]
 | 
			
		||||
 | 
			
		||||
[tool.poetry.dependencies]
 | 
			
		||||
celery = "*"
 | 
			
		||||
channels = {version = "*", extras = ["daphne"]}
 | 
			
		||||
channels = { version = "*", extras = ["daphne"] }
 | 
			
		||||
channels-redis = "*"
 | 
			
		||||
codespell = "*"
 | 
			
		||||
colorama = "*"
 | 
			
		||||
@ -147,7 +138,7 @@ gunicorn = "*"
 | 
			
		||||
kubernetes = "*"
 | 
			
		||||
ldap3 = "*"
 | 
			
		||||
lxml = "*"
 | 
			
		||||
opencontainers = {extras = ["reggie"],version = "*"}
 | 
			
		||||
opencontainers = { extras = ["reggie"], version = "*" }
 | 
			
		||||
packaging = "*"
 | 
			
		||||
paramiko = "*"
 | 
			
		||||
psycopg2-binary = "*"
 | 
			
		||||
@ -163,8 +154,8 @@ swagger-spec-validator = "*"
 | 
			
		||||
twilio = "*"
 | 
			
		||||
twisted = "*"
 | 
			
		||||
ua-parser = "*"
 | 
			
		||||
urllib3 = {extras = ["secure"],version = "*"}
 | 
			
		||||
uvicorn = {extras = ["standard"],version = "*"}
 | 
			
		||||
urllib3 = { extras = ["secure"], version = "*" }
 | 
			
		||||
uvicorn = { extras = ["standard"], version = "*" }
 | 
			
		||||
webauthn = "*"
 | 
			
		||||
wsproto = "*"
 | 
			
		||||
xmlsec = "*"
 | 
			
		||||
@ -176,7 +167,7 @@ bandit = "*"
 | 
			
		||||
black = "*"
 | 
			
		||||
bump2version = "*"
 | 
			
		||||
colorama = "*"
 | 
			
		||||
coverage = {extras = ["toml"],version = "*"}
 | 
			
		||||
coverage = { extras = ["toml"], version = "*" }
 | 
			
		||||
importlib-metadata = "*"
 | 
			
		||||
pylint = "*"
 | 
			
		||||
pylint-django = "*"
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user