outposts: fix update signal not being sent to correct instances
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
		| @ -40,7 +40,7 @@ class WebsocketMessage: | |||||||
| class OutpostConsumer(AuthJsonConsumer): | class OutpostConsumer(AuthJsonConsumer): | ||||||
|     """Handler for Outposts that connect over websockets for health checks and live updates""" |     """Handler for Outposts that connect over websockets for health checks and live updates""" | ||||||
|  |  | ||||||
|     outpost: Optional[Outpost] = None |     outpost: Outpost | ||||||
|  |  | ||||||
|     last_uid: Optional[str] = None |     last_uid: Optional[str] = None | ||||||
|  |  | ||||||
| @ -64,7 +64,9 @@ class OutpostConsumer(AuthJsonConsumer): | |||||||
|     # pylint: disable=unused-argument |     # pylint: disable=unused-argument | ||||||
|     def disconnect(self, close_code): |     def disconnect(self, close_code): | ||||||
|         if self.outpost and self.last_uid: |         if self.outpost and self.last_uid: | ||||||
|             OutpostState.for_channel(self.outpost, self.last_uid).delete() |             state = OutpostState.for_instance_uid(self.outpost, self.last_uid) | ||||||
|  |             state.channel_ids.remove(self.channel_name) | ||||||
|  |             state.save() | ||||||
|         LOGGER.debug( |         LOGGER.debug( | ||||||
|             "removed outpost instance from cache", |             "removed outpost instance from cache", | ||||||
|             outpost=self.outpost, |             outpost=self.outpost, | ||||||
| @ -75,12 +77,10 @@ class OutpostConsumer(AuthJsonConsumer): | |||||||
|         msg = from_dict(WebsocketMessage, content) |         msg = from_dict(WebsocketMessage, content) | ||||||
|         uid = msg.args.get("uuid", self.channel_name) |         uid = msg.args.get("uuid", self.channel_name) | ||||||
|         self.last_uid = uid |         self.last_uid = uid | ||||||
|         state = OutpostState( |         state = OutpostState.for_instance_uid(self.outpost, uid) | ||||||
|             uid=uid, |         if self.channel_name not in state.channel_ids: | ||||||
|             channel_id=self.channel_name, |             state.channel_ids.append(self.channel_name) | ||||||
|             last_seen=datetime.now(), |         state.last_seen = datetime.now() | ||||||
|             _outpost=self.outpost, |  | ||||||
|         ) |  | ||||||
|         if msg.instruction == WebsocketMessageInstruction.HELLO: |         if msg.instruction == WebsocketMessageInstruction.HELLO: | ||||||
|             state.version = msg.args.get("version", None) |             state.version = msg.args.get("version", None) | ||||||
|             state.build_hash = msg.args.get("buildHash", "") |             state.build_hash = msg.args.get("buildHash", "") | ||||||
|  | |||||||
| @ -409,7 +409,7 @@ class OutpostState: | |||||||
|     """Outpost instance state, last_seen and version""" |     """Outpost instance state, last_seen and version""" | ||||||
|  |  | ||||||
|     uid: str |     uid: str | ||||||
|     channel_id: str |     channel_ids: list[str] = field(default_factory=list) | ||||||
|     last_seen: Optional[datetime] = field(default=None) |     last_seen: Optional[datetime] = field(default=None) | ||||||
|     version: Optional[str] = field(default=None) |     version: Optional[str] = field(default=None) | ||||||
|     version_should: Union[Version, LegacyVersion] = field(default=OUR_VERSION) |     version_should: Union[Version, LegacyVersion] = field(default=OUR_VERSION) | ||||||
| @ -432,21 +432,20 @@ class OutpostState: | |||||||
|         keys = cache.keys(f"{outpost.state_cache_prefix}_*") |         keys = cache.keys(f"{outpost.state_cache_prefix}_*") | ||||||
|         states = [] |         states = [] | ||||||
|         for key in keys: |         for key in keys: | ||||||
|             channel = key.replace(f"{outpost.state_cache_prefix}_", "") |             instance_uid = key.replace(f"{outpost.state_cache_prefix}_", "") | ||||||
|             states.append(OutpostState.for_channel(outpost, channel)) |             states.append(OutpostState.for_instance_uid(outpost, instance_uid)) | ||||||
|         return states |         return states | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def for_channel(outpost: Outpost, channel: str) -> "OutpostState": |     def for_instance_uid(outpost: Outpost, uid: str) -> "OutpostState": | ||||||
|         """Get state for a single channel""" |         """Get state for a single instance""" | ||||||
|         key = f"{outpost.state_cache_prefix}_{channel}" |         key = f"{outpost.state_cache_prefix}_{uid}" | ||||||
|         default_data = {"uid": channel, "channel_id": channel} |         default_data = {"uid": uid, "channel_ids": []} | ||||||
|         data = cache.get(key, default_data) |         data = cache.get(key, default_data) | ||||||
|         if isinstance(data, str): |         if isinstance(data, str): | ||||||
|             cache.delete(key) |             cache.delete(key) | ||||||
|             data = default_data |             data = default_data | ||||||
|         state = from_dict(OutpostState, data) |         state = from_dict(OutpostState, data) | ||||||
|         state.uid = channel |  | ||||||
|         # pylint: disable=protected-access |         # pylint: disable=protected-access | ||||||
|         state._outpost = outpost |         state._outpost = outpost | ||||||
|         return state |         return state | ||||||
|  | |||||||
| @ -202,8 +202,11 @@ def _outpost_single_update(outpost: Outpost, layer=None): | |||||||
|     if not layer:  # pragma: no cover |     if not layer:  # pragma: no cover | ||||||
|         layer = get_channel_layer() |         layer = get_channel_layer() | ||||||
|     for state in OutpostState.for_outpost(outpost): |     for state in OutpostState.for_outpost(outpost): | ||||||
|         LOGGER.debug("sending update", channel=state.channel_id, outpost=outpost) |         for channel in state.channel_ids: | ||||||
|         async_to_sync(layer.send)(state.channel_id, {"type": "event.update"}) |             LOGGER.debug( | ||||||
|  |                 "sending update", channel=channel, instance=state.uid, outpost=outpost | ||||||
|  |             ) | ||||||
|  |             async_to_sync(layer.send)(channel, {"type": "event.update"}) | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task() | @CELERY_APP.task() | ||||||
|  | |||||||
| @ -207,7 +207,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { | |||||||
| 	} | 	} | ||||||
| 	p.ClearCSRFCookie(rw, req) | 	p.ClearCSRFCookie(rw, req) | ||||||
| 	if c.Value != nonce { | 	if c.Value != nonce { | ||||||
| 		p.logger.WithField("user", session.Email).WithField("status", "AuthFailure").Info("Invalid authentication via OAuth2: CSRF token mismatch, potential attack") | 		p.logger.WithField("is", c.Value).WithField("should", nonce).WithField("user", session.Email).WithField("status", "AuthFailure").Info("Invalid authentication via OAuth2: CSRF token mismatch, potential attack") | ||||||
| 		p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", "CSRF Failed") | 		p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", "CSRF Failed") | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	 Jens Langhammer
					Jens Langhammer