From a71532b3e3ef3c2d2fae09a384347cd6b4e0ea56 Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Wed, 14 May 2025 15:18:40 +0200 Subject: [PATCH] refactor more Signed-off-by: Jens Langhammer --- internal/outpost/radius/eap/handler.go | 10 ++++---- internal/outpost/radius/eap/tls/conn.go | 12 ++++++++-- internal/outpost/radius/eap/tls/payload.go | 10 ++++---- internal/outpost/radius/handler.go | 27 +++++++++++++++++++--- internal/outpost/radius/radius.go | 2 +- 5 files changed, 46 insertions(+), 15 deletions(-) diff --git a/internal/outpost/radius/eap/handler.go b/internal/outpost/radius/eap/handler.go index 601a9b250e..79d00b0931 100644 --- a/internal/outpost/radius/eap/handler.go +++ b/internal/outpost/radius/eap/handler.go @@ -29,13 +29,13 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac res, newState := p.GetChallengeForType(st, nextChallengeToOffer) stm.SetEAPState(rst, newState) - log.Debug("EAP: encapsulating challenge") rres := r.Response(radius.CodeAccessChallenge) rfc2865.State_SetString(rres, rst) eapEncoded, err := res.Encode() if err != nil { panic(err) } + log.WithField("length", len(eapEncoded)).Debug("EAP: encapsulating challenge") rfc2869.EAPMessage_Set(rres, eapEncoded) p.setMessageAuthenticator(rres) err = w.Write(rres) @@ -54,9 +54,11 @@ func (p *Packet) GetChallengeForType(st *State, t Type) (*Packet, *State) { var tst any switch t { case TypeTLS: - cp := tls.Payload{} - cp.Decode(p.rawPayload) - payload, tst = cp.Handle(st.TypeState[t]) + if _, ok := p.Payload.(*tls.Payload); !ok { + p.Payload = &tls.Payload{} + p.Payload.Decode(p.rawPayload) + } + payload, tst = p.Payload.(*tls.Payload).Handle(st.TypeState[t]) } st.TypeState[t] = tst res.Payload = payload.(Payload) diff --git a/internal/outpost/radius/eap/tls/conn.go b/internal/outpost/radius/eap/tls/conn.go index 77ecd4efa5..b6b2dd97ca 100644 --- a/internal/outpost/radius/eap/tls/conn.go +++ b/internal/outpost/radius/eap/tls/conn.go @@ -21,8 +21,16 @@ func NewTLSConnection(initialData []byte) TLSConnection { return c } -func (conn TLSConnection) TLSData() []byte { - return conn.writer.Bytes() +func (conn TLSConnection) GetData() []byte { + for { + b := conn.writer.Bytes() + if len(b) < 1 { + log.Debug("TLS(buffer): Attempted retrieve from empty buffer, stalling...") + time.Sleep(1 * time.Second) + continue + } + return b + } } func (conn TLSConnection) UpdateData(data []byte) { diff --git a/internal/outpost/radius/eap/tls/payload.go b/internal/outpost/radius/eap/tls/payload.go index e07802be2f..b6bf402552 100644 --- a/internal/outpost/radius/eap/tls/payload.go +++ b/internal/outpost/radius/eap/tls/payload.go @@ -103,7 +103,7 @@ func (p *Payload) Handle(stt any) (*Payload, State) { if st.HasMore() { return p.sendNextChunk(st) } - return p.startChunkedTransfer(st.Conn.TLSData(), st) + return p.startChunkedTransfer(st.Conn.GetData(), st) } const maxChunkSize = 1000 @@ -114,10 +114,10 @@ func (p *Payload) startChunkedTransfer(data []byte, st State) (*Payload, State) if len(data) > maxChunkSize { log.WithField("length", len(data)).Debug("TLS: Data needs to be chunked") flags += FlagMoreFragments - dataToSend = data[:maxChunkSize] - remainingData := data[maxChunkSize:] - // Chunk remaining data into correct chunks and add them to the list - st.RemainingChunks = append(st.RemainingChunks, slices.Collect(slices.Chunk(remainingData, maxChunkSize))...) + // Chunk data into correct chunks and add them to the list + st.RemainingChunks = append(st.RemainingChunks, slices.Collect(slices.Chunk(data, maxChunkSize))...) + dataToSend = st.RemainingChunks[0] + st.RemainingChunks = st.RemainingChunks[1:] st.TotalPayloadSize = len(data) } else { dataToSend = data diff --git a/internal/outpost/radius/handler.go b/internal/outpost/radius/handler.go index 0b94e86743..021d083fb6 100644 --- a/internal/outpost/radius/handler.go +++ b/internal/outpost/radius/handler.go @@ -3,6 +3,7 @@ package radius import ( "crypto/sha512" "encoding/hex" + "net" "time" "github.com/getsentry/sentry-go" @@ -35,12 +36,31 @@ func (r *RadiusRequest) ID() string { return r.id } +type LogWriter struct { + w radius.ResponseWriter + l *log.Entry +} + +func (lw LogWriter) Write(packet *radius.Packet) error { + lw.l.WithField("code", packet.Code.String()).Info("Radius Response") + return lw.w.Write(packet) +} + func (rs *RadiusServer) ServeRADIUS(w radius.ResponseWriter, r *radius.Request) { span := sentry.StartSpan(r.Context(), "authentik.providers.radius.connect", sentry.WithTransactionName("authentik.providers.radius.connect")) rid := uuid.New().String() span.SetTag("request_uid", rid) - rl := rs.log.WithField("code", r.Code.String()).WithField("request", rid) + host, _, err := net.SplitHostPort(r.RemoteAddr.String()) + if err != nil { + rs.log.WithError(err).Warning("Failed to get remote IP") + return + } + rl := rs.log.WithFields(log.Fields{ + "code": r.Code.String(), + "request": rid, + "ip": host, + }) selectedApp := "" defer func() { span.Finish() @@ -58,6 +78,7 @@ func (rs *RadiusServer) ServeRADIUS(w radius.ResponseWriter, r *radius.Request) } rl.Info("Radius Request") + ww := LogWriter{w, rl} // Lookup provider by shared secret var pi *ProviderInstance @@ -72,12 +93,12 @@ func (rs *RadiusServer) ServeRADIUS(w radius.ResponseWriter, r *radius.Request) hs := sha512.Sum512([]byte(r.Secret)) bs := hex.EncodeToString(hs[:]) nr.Log().WithField("hashed_secret", bs).Warning("No provider found") - _ = w.Write(r.Response(radius.CodeAccessReject)) + _ = ww.Write(r.Response(radius.CodeAccessReject)) return } nr.pi = pi if nr.Code == radius.CodeAccessRequest { - rs.Handle_AccessRequest(w, nr) + rs.Handle_AccessRequest(ww, nr) } } diff --git a/internal/outpost/radius/radius.go b/internal/outpost/radius/radius.go index e01c237c1f..9a753b5886 100644 --- a/internal/outpost/radius/radius.go +++ b/internal/outpost/radius/radius.go @@ -87,7 +87,7 @@ func (rs *RadiusServer) RADIUSSecret(ctx context.Context, remoteAddr net.Addr) ( return bi < bj }) candidate := matchedPrefixes[0] - rs.log.WithField("ip", ip.String()).WithField("cidr", candidate.c.String()).Debug("Matched CIDR") + rs.log.WithField("ip", ip.String()).WithField("cidr", candidate.c.String()).WithField("instance", candidate.p.appSlug).Debug("Matched CIDR") return candidate.p.SharedSecret, nil }