internal: fix missing trailing slash in outpost websocket (#12470) Signed-off-by: Jens Langhammer <jens@goauthentik.io> Co-authored-by: Jens L. <jens@goauthentik.io>
This commit is contained in:
![98988430+gcp-cherry-pick-bot[bot]@users.noreply.github.com](/assets/img/avatar_default.png)
committed by
GitHub

parent
e87a17fd81
commit
09b3a1d0bd
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
@ -16,13 +17,16 @@ import (
|
||||
"goauthentik.io/internal/constants"
|
||||
)
|
||||
|
||||
func (ac *APIController) getWebsocketURL(akURL url.URL, outpostUUID string) *url.URL {
|
||||
func (ac *APIController) getWebsocketURL(akURL url.URL, outpostUUID string, query url.Values) *url.URL {
|
||||
wsUrl := &url.URL{}
|
||||
wsUrl.Scheme = strings.ReplaceAll(akURL.Scheme, "http", "ws")
|
||||
wsUrl.Host = akURL.Host
|
||||
_p, _ := url.JoinPath(akURL.Path, "ws/outpost/", outpostUUID)
|
||||
_p, _ := url.JoinPath(akURL.Path, "ws/outpost/", outpostUUID, "/")
|
||||
wsUrl.Path = _p
|
||||
wsUrl.RawQuery = akURL.Query().Encode()
|
||||
v := url.Values{}
|
||||
maps.Insert(v, maps.All(akURL.Query()))
|
||||
maps.Insert(v, maps.All(query))
|
||||
wsUrl.RawQuery = v.Encode()
|
||||
return wsUrl
|
||||
}
|
||||
|
||||
@ -45,7 +49,9 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error {
|
||||
},
|
||||
}
|
||||
|
||||
ws, _, err := dialer.Dial(ac.getWebsocketURL(akURL, outpostUUID).String(), header)
|
||||
wsu := ac.getWebsocketURL(akURL, outpostUUID, query).String()
|
||||
ac.logger.WithField("url", wsu).Debug("connecting to websocket")
|
||||
ws, _, err := dialer.Dial(wsu, header)
|
||||
if err != nil {
|
||||
ac.logger.WithError(err).Warning("failed to connect websocket")
|
||||
return err
|
||||
|
@ -19,14 +19,24 @@ func TestWebsocketURL(t *testing.T) {
|
||||
u := URLMustParse("http://localhost:9000?foo=bar")
|
||||
uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77"
|
||||
ac := &APIController{}
|
||||
nu := ac.getWebsocketURL(*u, uuid)
|
||||
assert.Equal(t, "ws://localhost:9000/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77?foo=bar", nu.String())
|
||||
nu := ac.getWebsocketURL(*u, uuid, url.Values{})
|
||||
assert.Equal(t, "ws://localhost:9000/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77/?foo=bar", nu.String())
|
||||
}
|
||||
|
||||
func TestWebsocketURL_Query(t *testing.T) {
|
||||
u := URLMustParse("http://localhost:9000?foo=bar")
|
||||
uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77"
|
||||
ac := &APIController{}
|
||||
v := url.Values{}
|
||||
v.Set("bar", "baz")
|
||||
nu := ac.getWebsocketURL(*u, uuid, v)
|
||||
assert.Equal(t, "ws://localhost:9000/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77/?bar=baz&foo=bar", nu.String())
|
||||
}
|
||||
|
||||
func TestWebsocketURL_Subpath(t *testing.T) {
|
||||
u := URLMustParse("http://localhost:9000/foo/bar/")
|
||||
uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77"
|
||||
ac := &APIController{}
|
||||
nu := ac.getWebsocketURL(*u, uuid)
|
||||
assert.Equal(t, "ws://localhost:9000/foo/bar/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77", nu.String())
|
||||
nu := ac.getWebsocketURL(*u, uuid, url.Values{})
|
||||
assert.Equal(t, "ws://localhost:9000/foo/bar/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77/", nu.String())
|
||||
}
|
||||
|
Reference in New Issue
Block a user