internal: fix missing trailing slash in outpost websocket (#12470) Signed-off-by: Jens Langhammer <jens@goauthentik.io> Co-authored-by: Jens L. <jens@goauthentik.io>
This commit is contained in:
![98988430+gcp-cherry-pick-bot[bot]@users.noreply.github.com](/assets/img/avatar_default.png)
committed by
GitHub

parent
e87a17fd81
commit
09b3a1d0bd
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"maps"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -16,13 +17,16 @@ import (
|
|||||||
"goauthentik.io/internal/constants"
|
"goauthentik.io/internal/constants"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (ac *APIController) getWebsocketURL(akURL url.URL, outpostUUID string) *url.URL {
|
func (ac *APIController) getWebsocketURL(akURL url.URL, outpostUUID string, query url.Values) *url.URL {
|
||||||
wsUrl := &url.URL{}
|
wsUrl := &url.URL{}
|
||||||
wsUrl.Scheme = strings.ReplaceAll(akURL.Scheme, "http", "ws")
|
wsUrl.Scheme = strings.ReplaceAll(akURL.Scheme, "http", "ws")
|
||||||
wsUrl.Host = akURL.Host
|
wsUrl.Host = akURL.Host
|
||||||
_p, _ := url.JoinPath(akURL.Path, "ws/outpost/", outpostUUID)
|
_p, _ := url.JoinPath(akURL.Path, "ws/outpost/", outpostUUID, "/")
|
||||||
wsUrl.Path = _p
|
wsUrl.Path = _p
|
||||||
wsUrl.RawQuery = akURL.Query().Encode()
|
v := url.Values{}
|
||||||
|
maps.Insert(v, maps.All(akURL.Query()))
|
||||||
|
maps.Insert(v, maps.All(query))
|
||||||
|
wsUrl.RawQuery = v.Encode()
|
||||||
return wsUrl
|
return wsUrl
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,7 +49,9 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
ws, _, err := dialer.Dial(ac.getWebsocketURL(akURL, outpostUUID).String(), header)
|
wsu := ac.getWebsocketURL(akURL, outpostUUID, query).String()
|
||||||
|
ac.logger.WithField("url", wsu).Debug("connecting to websocket")
|
||||||
|
ws, _, err := dialer.Dial(wsu, header)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ac.logger.WithError(err).Warning("failed to connect websocket")
|
ac.logger.WithError(err).Warning("failed to connect websocket")
|
||||||
return err
|
return err
|
||||||
|
@ -19,14 +19,24 @@ func TestWebsocketURL(t *testing.T) {
|
|||||||
u := URLMustParse("http://localhost:9000?foo=bar")
|
u := URLMustParse("http://localhost:9000?foo=bar")
|
||||||
uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77"
|
uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77"
|
||||||
ac := &APIController{}
|
ac := &APIController{}
|
||||||
nu := ac.getWebsocketURL(*u, uuid)
|
nu := ac.getWebsocketURL(*u, uuid, url.Values{})
|
||||||
assert.Equal(t, "ws://localhost:9000/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77?foo=bar", nu.String())
|
assert.Equal(t, "ws://localhost:9000/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77/?foo=bar", nu.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebsocketURL_Query(t *testing.T) {
|
||||||
|
u := URLMustParse("http://localhost:9000?foo=bar")
|
||||||
|
uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77"
|
||||||
|
ac := &APIController{}
|
||||||
|
v := url.Values{}
|
||||||
|
v.Set("bar", "baz")
|
||||||
|
nu := ac.getWebsocketURL(*u, uuid, v)
|
||||||
|
assert.Equal(t, "ws://localhost:9000/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77/?bar=baz&foo=bar", nu.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWebsocketURL_Subpath(t *testing.T) {
|
func TestWebsocketURL_Subpath(t *testing.T) {
|
||||||
u := URLMustParse("http://localhost:9000/foo/bar/")
|
u := URLMustParse("http://localhost:9000/foo/bar/")
|
||||||
uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77"
|
uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77"
|
||||||
ac := &APIController{}
|
ac := &APIController{}
|
||||||
nu := ac.getWebsocketURL(*u, uuid)
|
nu := ac.getWebsocketURL(*u, uuid, url.Values{})
|
||||||
assert.Equal(t, "ws://localhost:9000/foo/bar/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77", nu.String())
|
assert.Equal(t, "ws://localhost:9000/foo/bar/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77/", nu.String())
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user