84 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			84 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""websocket proxy consumer"""
 | 
						|
import threading
 | 
						|
from logging import getLogger
 | 
						|
from ssl import CERT_NONE
 | 
						|
 | 
						|
import websocket
 | 
						|
from channels.generic.websocket import WebsocketConsumer
 | 
						|
 | 
						|
from passbook.app_gw.models import ApplicationGatewayProvider
 | 
						|
 | 
						|
LOGGER = getLogger(__name__)
 | 
						|
 | 
						|
class ProxyConsumer(WebsocketConsumer):
 | 
						|
    """Proxy websocket connection to upstream"""
 | 
						|
 | 
						|
    _headers_dict = {}
 | 
						|
    _app_gw = None
 | 
						|
    _client = None
 | 
						|
    _thread = None
 | 
						|
 | 
						|
    def _fix_headers(self, input_dict):
 | 
						|
        """Fix headers from bytestrings to normal strings"""
 | 
						|
        return {
 | 
						|
            key.decode('utf-8'): value.decode('utf-8')
 | 
						|
            for key, value in dict(input_dict).items()
 | 
						|
        }
 | 
						|
 | 
						|
    def connect(self):
 | 
						|
        """Extract host header, lookup in database and proxy connection"""
 | 
						|
        self._headers_dict = self._fix_headers(dict(self.scope.get('headers')))
 | 
						|
        host = self._headers_dict.pop('host')
 | 
						|
        query_string = self.scope.get('query_string').decode('utf-8')
 | 
						|
        matches = ApplicationGatewayProvider.objects.filter(
 | 
						|
            server_name__contains=[host],
 | 
						|
            enabled=True)
 | 
						|
        if matches.exists():
 | 
						|
            self._app_gw = matches.first()
 | 
						|
            # TODO: Get upstream that starts with wss or
 | 
						|
            upstream = self._app_gw.upstream[0].replace('http', 'ws') + self.scope.get('path')
 | 
						|
            if query_string:
 | 
						|
                upstream += '?' + query_string
 | 
						|
            sslopt = {}
 | 
						|
            if not self._app_gw.upstream_ssl_verification:
 | 
						|
                sslopt = {"cert_reqs": CERT_NONE}
 | 
						|
            self._client = websocket.WebSocketApp(
 | 
						|
                url=upstream,
 | 
						|
                subprotocols=self.scope.get('subprotocols'),
 | 
						|
                header=self._headers_dict,
 | 
						|
                on_message=self._client_on_message_handler(),
 | 
						|
                on_error=self._client_on_error_handler(),
 | 
						|
                on_close=self._client_on_close_handler(),
 | 
						|
                on_open=self._client_on_open_handler())
 | 
						|
            LOGGER.debug("Accepting connection for %s", host)
 | 
						|
            self._thread = threading.Thread(target=lambda: self._client.run_forever(sslopt=sslopt))
 | 
						|
            self._thread.start()
 | 
						|
 | 
						|
    def _client_on_open_handler(self):
 | 
						|
        return lambda ws: self.accept(self._client.sock.handshake_response.subprotocol)
 | 
						|
 | 
						|
    def _client_on_message_handler(self):
 | 
						|
        # pylint: disable=unused-argument,invalid-name
 | 
						|
        def message_handler(ws, message):
 | 
						|
            if isinstance(message, str):
 | 
						|
                self.send(text_data=message)
 | 
						|
            else:
 | 
						|
                self.send(bytes_data=message)
 | 
						|
        return message_handler
 | 
						|
 | 
						|
    def _client_on_error_handler(self):
 | 
						|
        return lambda ws, error: print(error)
 | 
						|
 | 
						|
    def _client_on_close_handler(self):
 | 
						|
        return lambda ws: self.disconnect(0)
 | 
						|
 | 
						|
    def disconnect(self, code):
 | 
						|
        self._client.close()
 | 
						|
 | 
						|
    def receive(self, text_data=None, bytes_data=None):
 | 
						|
        if text_data:
 | 
						|
            opcode = websocket.ABNF.OPCODE_TEXT
 | 
						|
        if bytes_data:
 | 
						|
            opcode = websocket.ABNF.OPCODE_BINARY
 | 
						|
        self._client.send(text_data or bytes_data, opcode)
 |