84 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			84 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """websocket proxy consumer"""
 | |
| import threading
 | |
| from ssl import CERT_NONE
 | |
| 
 | |
| import websocket
 | |
| from channels.generic.websocket import WebsocketConsumer
 | |
| from structlog import get_logger
 | |
| 
 | |
| from passbook.app_gw.models import ApplicationGatewayProvider
 | |
| 
 | |
| LOGGER = get_logger(__name__)  # module-level structlog logger, used by ProxyConsumer.connect
 | |
| 
 | |
class ProxyConsumer(WebsocketConsumer):
    """Proxy websocket connection to upstream.

    The ``Host`` header of the incoming handshake is used to look up an
    enabled ``ApplicationGatewayProvider``. When one is found, a
    ``websocket.WebSocketApp`` connection to the provider's first upstream is
    run in a separate thread and frames are relayed in both directions.
    """

    # Request headers of the downstream handshake, decoded to ``str``
    # (reassigned per instance in ``connect``).
    _headers_dict = {}
    # Matched ApplicationGatewayProvider, or None when nothing matched.
    _app_gw = None
    # websocket.WebSocketApp connected to the upstream, or None.
    _client = None
    # Thread running ``_client.run_forever``, or None.
    _thread = None

    def _fix_headers(self, input_dict):
        """Fix headers from bytestrings to normal strings"""
        return {
            key.decode('utf-8'): value.decode('utf-8')
            for key, value in dict(input_dict).items()
        }

    @staticmethod
    def _build_upstream_url(upstream, path, query_string):
        """Build the websocket URL for `upstream`, `path` and `query_string`.

        Only a leading ``http`` is rewritten to ``ws`` (so ``https`` becomes
        ``wss``); later occurrences of "http" in the host or path are left
        untouched. `query_string` is appended after ``?`` when non-empty.
        """
        if upstream.startswith('http'):
            upstream = 'ws' + upstream[len('http'):]
        url = upstream + path
        if query_string:
            url += '?' + query_string
        return url

    def connect(self):
        """Extract host header, lookup in database and proxy connection.

        The downstream connection is accepted only once the upstream
        connection has opened (see `_client_on_open_handler`). When the Host
        header is missing or no enabled provider matches it, the connection
        is rejected with ``self.close()`` instead of being left pending.
        """
        self._headers_dict = self._fix_headers(dict(self.scope.get('headers') or ()))
        host = self._headers_dict.pop('host', None)
        if host is None:
            self.close()
            return
        query_string = (self.scope.get('query_string') or b'').decode('utf-8')
        # A single query: first() returns None when nothing matches.
        self._app_gw = ApplicationGatewayProvider.objects.filter(
            server_name__contains=[host],
            enabled=True).first()
        if self._app_gw is None:
            self.close()
            return
        # TODO: prefer an upstream that already starts with ws:// or wss://
        upstream = self._build_upstream_url(
            self._app_gw.upstream[0], self.scope.get('path'), query_string)
        sslopt = {}
        if not self._app_gw.upstream_ssl_verification:
            # Upstream certificate verification is disabled for this provider.
            sslopt = {"cert_reqs": CERT_NONE}
        self._client = websocket.WebSocketApp(
            url=upstream,
            subprotocols=self.scope.get('subprotocols'),
            header=self._headers_dict,
            on_message=self._client_on_message_handler(),
            on_error=self._client_on_error_handler(),
            on_close=self._client_on_close_handler(),
            on_open=self._client_on_open_handler())
        LOGGER.debug("Accepting connection for %s", host)
        # run_forever blocks, so the upstream connection gets its own thread.
        self._thread = threading.Thread(
            target=self._client.run_forever, kwargs={'sslopt': sslopt})
        self._thread.start()

    def _client_on_open_handler(self):
        """Return the upstream on_open callback.

        The callback accepts the downstream connection with the subprotocol
        reported in the upstream handshake response.
        """
        return lambda ws: self.accept(self._client.sock.handshake_response.subprotocol)

    def _client_on_message_handler(self):
        """Return the upstream on_message callback relaying frames downstream."""
        # pylint: disable=unused-argument,invalid-name
        def message_handler(ws, message):
            if isinstance(message, str):
                self.send(text_data=message)
            else:
                self.send(bytes_data=message)
        return message_handler

    def _client_on_error_handler(self):
        """Return the upstream on_error callback, which logs the error."""
        # pylint: disable=unused-argument,invalid-name
        def error_handler(ws, error):
            LOGGER.warning("Upstream websocket error: %s", error)
        return error_handler

    def _client_on_close_handler(self):
        """Return the upstream on_close callback.

        Extra positional arguments are accepted because websocket-client may
        pass the close status code and message in addition to the app object.
        """
        # pylint: disable=unused-argument,invalid-name
        def close_handler(ws, *args):
            self.disconnect(0)
        return close_handler

    def disconnect(self, code):
        """Close the upstream connection, if one was ever created."""
        if self._client is not None:
            self._client.close()

    def receive(self, text_data=None, bytes_data=None):
        """Relay a frame received from the downstream client to the upstream.

        The frame type is chosen by which argument is not None, so empty
        text or binary payloads are forwarded as well. Nothing is sent when
        both are None or no upstream client exists.
        """
        if self._client is None:
            return
        if text_data is not None:
            self._client.send(text_data, websocket.ABNF.OPCODE_TEXT)
        elif bytes_data is not None:
            self._client.send(bytes_data, websocket.ABNF.OPCODE_BINARY)
 | 
