@@ -1,30 +1,29 @@
-import logging
-from psycopg.errors import AdminShutdown
-import tenacity
-import functools
-from pglock.core import _cast_lock_id
-from django.utils import timezone
-from random import randint
-import logging
-import time
-from django.db.backends.postgresql.base import DatabaseWrapper
-from typing import Iterable
-from django.db import DEFAULT_DB_ALIAS, DatabaseError, InterfaceError, OperationalError, connections
-from queue import Queue, Empty
+from collections.abc import Iterable
+from queue import Empty, Queue
+from random import randint

-from django.db.models import QuerySet
-from dramatiq.broker import Broker, Consumer, MessageProxy
-from dramatiq.message import Message
-from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name
-from dramatiq.results import Results
-from dramatiq.errors import QueueJoinTimeout, ConnectionError
 import orjson
+import tenacity
+from django.db import DEFAULT_DB_ALIAS, DatabaseError, InterfaceError, OperationalError, connections
+from django.db.backends.postgresql.base import DatabaseWrapper
+from django.db.models import QuerySet
+from django.utils import timezone
+from dramatiq.broker import Broker, Consumer, MessageProxy
+from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name
+from dramatiq.errors import ConnectionError, QueueJoinTimeout
+from dramatiq.message import Message
+from dramatiq.results import Results
+from pglock.core import _cast_lock_id
 from psycopg import Notify
+from psycopg.errors import AdminShutdown
 from structlog.stdlib import get_logger

-from authentik.tasks.models import Queue
+from authentik.tasks.models import Queue as MQueue
 from authentik.tasks.results import PostgresBackend


 LOGGER = get_logger()

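Note on the import hunk above: the module imports both the standard-library Queue (from the queue module) and the authentik Queue model, so the model is now aliased to MQueue and that alias is used throughout the rest of the diff, presumably to keep the two names from shadowing each other. A two-line illustration of the pair, taken from the imports above rather than new code:

from queue import Empty, Queue  # stdlib queue, presumably the consumer's in-process buffer
from authentik.tasks.models import Queue as MQueue  # Django model backing the broker's table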
		
	
	
		
			
				
					
@@ -61,7 +60,7 @@ class PostgresBroker(Broker):

     @property
     def query_set(self) -> QuerySet:
-        return Queue.objects.using(self.db_alias)
+        return MQueue.objects.using(self.db_alias)

     def consume(self, queue_name: str, prefetch: int = 1, timeout: int = 30000) -> Consumer:
         self.declare_queue(queue_name)
@@ -119,7 +118,7 @@ class PostgresBroker(Broker):
         self.query_set.create(
             message_id=message.message_id,
             queue_name=message.queue_name,
-            state=Queue.State.QUEUED,
+            state=MQueue.State.QUEUED,
             message=message.encode(),
         )
         self.emit_after("enqueue", message, delay)
@@ -154,7 +153,7 @@ class PostgresBroker(Broker):
             if (
                 self.query_set.filter(
                     queue_name=queue_name,
-                    state__in=(Queue.State.QUEUED, Queue.State.CONSUMED),
+                    state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED),
                 )
                 == 0
             ):
@@ -185,7 +184,7 @@ class _PostgresConsumer(Consumer):

     @property
     def query_set(self) -> QuerySet:
-        return Queue.objects.using(self.db_alias)
+        return MQueue.objects.using(self.db_alias)

     @property
     def listen_connection(self) -> DatabaseWrapper:
@@ -207,9 +206,9 @@ class _PostgresConsumer(Consumer):
         self.query_set.filter(
             message_id=message.message_id,
             queue_name=message.queue_name,
-            state=Queue.State.CONSUMED,
+            state=MQueue.State.CONSUMED,
         ).update(
-            state=Queue.State.DONE,
+            state=MQueue.State.DONE,
             message=message.encode(),
         )
         self.in_processing.remove(message.message_id)
@@ -220,9 +219,9 @@ class _PostgresConsumer(Consumer):
         self.query_set.filter(
             message_id=message.message_id,
             queue_name=message.queue_name,
-            state__ne=Queue.State.REJECTED,
+            state__ne=MQueue.State.REJECTED,
         ).update(
-            state=Queue.State.REJECT,
+            state=MQueue.State.REJECT,
             message=message.encode(),
         )
         self.in_processing.remove(message.message_id)
@@ -232,14 +231,14 @@ class _PostgresConsumer(Consumer):
         self.query_set.filter(
             message_id__in=[message.message_id for message in messages],
         ).update(
-            state=Queue.State.QUEUED,
+            state=MQueue.State.QUEUED,
         )
         # We don't care about locks, requeue occurs on worker stop

     def _fetch_pending_notifies(self) -> list[Notify]:
         self.logger.debug(f"Polling for lost messages in {self.queue_name}")
         notifies = self.query_set.filter(
-            state__in=(Queue.State.QUEUED, Queue.State.CONSUMED), queue_name=self.queue_name
+            state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED), queue_name=self.queue_name
         )
         channel = channel_name(self.connection, self.queue_name, "enqueue")
         return [Notify(pid=0, channel=channel, payload=item.message) for item in notifies]
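The _fetch_pending_notifies hunk above wraps rows that are still QUEUED or CONSUMED in synthetic Notify objects (pid=0), so messages that predate the LISTEN subscription can flow through the same handling path as real notifications. A minimal standalone sketch of that idea with plain psycopg; the table and column names are placeholders, not the real Queue model:

import psycopg
from psycopg import Notify


def pending_notifies(conn: psycopg.Connection, channel: str) -> list[Notify]:
    # Rehydrate unhandled rows as synthetic notifications; pid=0 marks them as
    # coming from a table poll rather than a live NOTIFY.
    with conn.cursor() as cur:
        cur.execute("SELECT message FROM task_queue WHERE state IN ('queued', 'consumed')")
        return [Notify(pid=0, channel=channel, payload=row[0]) for row in cur.fetchall()]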
		
	
	
		
			
				
					
@@ -252,7 +251,7 @@ class _PostgresConsumer(Consumer):

     def _get_message_lock_id(self, message: Message) -> int:
         return _cast_lock_id(
-            f"{channel_name(connections[self.connection], self.queue_name, 'lock')}.{message.message_id}"
+            f"{channel_name(connections[self.connection], self.queue_name, 'lock')}.{message.message_id}"  # noqa: E501
         )

     def _consume_one(self, message: Message) -> bool:
@@ -262,9 +261,10 @@ class _PostgresConsumer(Consumer):

         result = (
             self.query_set.filter(
-                message_id=message.message_id, state__in=(Queue.State.QUEUED, Queue.State.CONSUMED)
+                message_id=message.message_id,
+                state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED),
             )
-            .update(state=Queue.State.CONSUMED, mtime=timezone.now())
+            .update(state=MQueue.State.CONSUMED, mtime=timezone.now())
             .extra(where=["pg_try_advisory_lock(%s)"], params=[self._get_message_lock_id(message)])
         )
         return result == 1
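In _consume_one above, the UPDATE is additionally gated on pg_try_advisory_lock(%s) with a lock id derived from the channel name and message id via _cast_lock_id, so only one consumer session can claim a given message. A rough sketch of the same locking idea outside Django; the lock-id derivation below is a stand-in, not pglock's actual _cast_lock_id:

import hashlib

import psycopg


def lock_id_for(key: str) -> int:
    # Fold a string key into the signed 64-bit range that advisory locks expect.
    digest = hashlib.sha256(key.encode()).digest()
    return int.from_bytes(digest[:8], byteorder="big", signed=True)


def try_claim(conn: psycopg.Connection, message_id: str) -> bool:
    # True if this session now holds the advisory lock and may process the message.
    with conn.cursor() as cur:
        cur.execute("SELECT pg_try_advisory_lock(%s)", [lock_id_for(message_id)])
        return bool(cur.fetchone()[0])

Session-level advisory locks are released with pg_advisory_unlock() or when the holding session ends, which is what makes them usable as a per-message claim.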
		
	
	
		
			
				
					
@@ -276,7 +276,7 @@ class _PostgresConsumer(Consumer):
         # If we don't have a connection yet, fetch missed notifications from the table directly
         if self._listen_connection is None:
             # We might miss a notification between the initial query and the first time we wait for
-            # notitications, it doesn't matter because we re-fetch for missed messages later on.
+            # notifications, it doesn't matter because we re-fetch for missed messages later on.
             self.notifies = self._fetch_pending_notifies()
             self.logger.debug(
                 f"Found {len(self.notifies)} pending messages in queue {self.queue_name}"
@@ -340,7 +340,7 @@ class _PostgresConsumer(Consumer):
             return
         self.logger.debug("Running garbage collector")
         count = self.query_set.filter(
-            state__in=(Queue.State.DONE, Queue.State.REJECTED),
+            state__in=(MQueue.State.DONE, MQueue.State.REJECTED),
             mtime__lte=timezone.now() - timezone.timedelta(days=30),
         ).delete()
         self.logger.info(f"Purged {count} messages in all queues")
		
	
	
		
			
				
					