From 15be3f246126cb2381fef1811c41f7ddaa54be95 Mon Sep 17 00:00:00 2001 From: "Jens L." Date: Fri, 20 Dec 2024 19:18:50 +0100 Subject: [PATCH] internal: fix URL generation for websocket connection (#12439) Signed-off-by: Jens Langhammer --- internal/outpost/ak/api_ws.go | 14 ++++++++++--- internal/outpost/ak/api_ws_test.go | 32 ++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 3 deletions(-) create mode 100644 internal/outpost/ak/api_ws_test.go diff --git a/internal/outpost/ak/api_ws.go b/internal/outpost/ak/api_ws.go index 3b450736e6..dfab5255ac 100644 --- a/internal/outpost/ak/api_ws.go +++ b/internal/outpost/ak/api_ws.go @@ -16,11 +16,19 @@ import ( "goauthentik.io/internal/constants" ) +func (ac *APIController) getWebsocketURL(akURL url.URL, outpostUUID string) *url.URL { + wsUrl := &url.URL{} + wsUrl.Scheme = strings.ReplaceAll(akURL.Scheme, "http", "ws") + wsUrl.Host = akURL.Host + _p, _ := url.JoinPath(akURL.Path, "ws/outpost/", outpostUUID) + wsUrl.Path = _p + wsUrl.RawQuery = akURL.Query().Encode() + return wsUrl +} + func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error { - pathTemplate := "%s://%s%sws/outpost/%s/?%s" query := akURL.Query() query.Set("instance_uuid", ac.instanceUUID.String()) - scheme := strings.ReplaceAll(akURL.Scheme, "http", "ws") authHeader := fmt.Sprintf("Bearer %s", ac.token) @@ -37,7 +45,7 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error { }, } - ws, _, err := dialer.Dial(fmt.Sprintf(pathTemplate, scheme, akURL.Host, akURL.Path, outpostUUID, akURL.Query().Encode()), header) + ws, _, err := dialer.Dial(ac.getWebsocketURL(akURL, outpostUUID).String(), header) if err != nil { ac.logger.WithError(err).Warning("failed to connect websocket") return err diff --git a/internal/outpost/ak/api_ws_test.go b/internal/outpost/ak/api_ws_test.go new file mode 100644 index 0000000000..33284795c0 --- /dev/null +++ b/internal/outpost/ak/api_ws_test.go @@ -0,0 +1,32 @@ +package ak + +import ( + "net/url" + "testing" + + "github.com/stretchr/testify/assert" +) + +func URLMustParse(u string) *url.URL { + ur, err := url.Parse(u) + if err != nil { + panic(err) + } + return ur +} + +func TestWebsocketURL(t *testing.T) { + u := URLMustParse("http://localhost:9000?foo=bar") + uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77" + ac := &APIController{} + nu := ac.getWebsocketURL(*u, uuid) + assert.Equal(t, "ws://localhost:9000/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77?foo=bar", nu.String()) +} + +func TestWebsocketURL_Subpath(t *testing.T) { + u := URLMustParse("http://localhost:9000/foo/bar/") + uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77" + ac := &APIController{} + nu := ac.getWebsocketURL(*u, uuid) + assert.Equal(t, "ws://localhost:9000/foo/bar/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77", nu.String()) +}