From df21e678d6603977cf11dc23756607d8042b3921 Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Thu, 15 May 2025 01:21:12 +0200 Subject: [PATCH] fix a bunch more Signed-off-by: Jens Langhammer --- internal/outpost/radius/eap/handler.go | 1 + internal/outpost/radius/eap/tls/conn.go | 49 +++++++++++++++---- internal/outpost/radius/eap/tls/payload.go | 56 +++++++++++++++++----- internal/outpost/radius/eap/tls/state.go | 6 +-- internal/outpost/radius/handler.go | 1 + 5 files changed, 89 insertions(+), 24 deletions(-) diff --git a/internal/outpost/radius/eap/handler.go b/internal/outpost/radius/eap/handler.go index 79d00b0931..af34a51868 100644 --- a/internal/outpost/radius/eap/handler.go +++ b/internal/outpost/radius/eap/handler.go @@ -20,6 +20,7 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac } st := stm.GetEAPState(rst) if st == nil { + log.Debug("EAP: blank state") st = BlankState(stm.GetEAPSettings()) } if len(st.ChallengesToOffer) < 1 { diff --git a/internal/outpost/radius/eap/tls/conn.go b/internal/outpost/radius/eap/tls/conn.go index b6b2dd97ca..97b23e5d19 100644 --- a/internal/outpost/radius/eap/tls/conn.go +++ b/internal/outpost/radius/eap/tls/conn.go @@ -11,17 +11,19 @@ import ( type TLSConnection struct { reader *bytes.Buffer writer *bytes.Buffer + + bufferIncomingBytesCount uint32 } -func NewTLSConnection(initialData []byte) TLSConnection { - c := TLSConnection{ +func NewTLSConnection(initialData []byte) *TLSConnection { + c := &TLSConnection{ reader: bytes.NewBuffer(initialData), writer: bytes.NewBuffer([]byte{}), } return c } -func (conn TLSConnection) GetData() []byte { +func (conn TLSConnection) OutboundData() []byte { for { b := conn.writer.Bytes() if len(b) < 1 { @@ -34,21 +36,50 @@ func (conn TLSConnection) GetData() []byte { } func (conn TLSConnection) UpdateData(data []byte) { - conn.reader.Reset() conn.reader.Write(data) + if conn.bufferIncomingBytesCount > 0 && conn.reader.Len() == int(conn.bufferIncomingBytesCount) { + conn.bufferIncomingBytesCount = 0 + } + log.Debugf("TLS(buffer): Appending new data %d (total %d, expecting %d)", len(data), conn.reader.Len(), conn.bufferIncomingBytesCount) } -// ---- +// func (conn TLSConnection) Reset() { +// log.Debug("TLS(buffer): reset") +// conn.reader.Reset() +// conn.writer.Reset() +// } -func (conn TLSConnection) Read(p []byte) (int, error) { - log.Debugf("TLS(buffer): Read: %d from %d", len(p), conn.reader.Len()) +func (conn TLSConnection) NeedsMoreData() bool { + if conn.bufferIncomingBytesCount > 0 { + return conn.reader.Len() < int(conn.bufferIncomingBytesCount) + } + return false +} + +// func (conn TLSConnection) WaitForAttemptedRead() int { +// for { +// // log.Debug("TLS(buffer): waiting for attempted read") +// if conn.missingBytes == 0 { +// continue +// } +// return conn.missingBytes +// } +// } + +func (conn *TLSConnection) Read(p []byte) (int, error) { for { n, err := conn.reader.Read(p) if n == 0 { - log.Debug("TLS(buffer): Attempted read from empty buffer, stalling...") - time.Sleep(1 * time.Second) + log.Debugf("TLS(buffer): Attempted read %d from empty buffer, stalling...", len(p)) + time.Sleep(500 * time.Millisecond) continue } + if conn.reader.Len() < int(conn.bufferIncomingBytesCount) { + log.Debugf("TLS(buffer): Attempted read %d while waiting for bytes %d, stalling...", len(p), conn.reader.Len()-int(conn.bufferIncomingBytesCount)) + time.Sleep(500 * time.Millisecond) + continue + } + log.Debugf("TLS(buffer): Read: %d from %d", len(p), n) return n, err } } diff --git a/internal/outpost/radius/eap/tls/payload.go b/internal/outpost/radius/eap/tls/payload.go index b6bf402552..6194e86175 100644 --- a/internal/outpost/radius/eap/tls/payload.go +++ b/internal/outpost/radius/eap/tls/payload.go @@ -12,6 +12,8 @@ import ( "goauthentik.io/internal/outpost/radius/eap/debug" ) +const maxChunkSize = 1000 + type Payload struct { Flags Flag Length uint32 @@ -20,16 +22,17 @@ type Payload struct { func (p *Payload) Decode(raw []byte) error { p.Flags = Flag(raw[0]) + raw = raw[1:] if p.Flags&FlagLengthIncluded != 0 { if len(raw) < 4 { return errors.New("invalid size") } p.Length = binary.BigEndian.Uint32(raw) - p.Data = raw[5:] + p.Data = raw[4:] } else { - p.Data = raw[1:] + p.Data = raw[0:] } - log.WithField("raw", debug.FormatBytes(p.Data)).WithField("flags", p.Flags).Debug("TLS: decode raw") + log.WithField("raw", debug.FormatBytes(p.Data)).WithField("size", len(p.Data)).WithField("flags", p.Flags).Debug("TLS: decode raw") return nil } @@ -66,12 +69,14 @@ func init() { certs = append(certs, cert) } -func (p *Payload) Handle(stt any) (*Payload, State) { +func (p *Payload) Handle(stt any) (*Payload, *State) { if stt == nil { + log.Debug("TLS: new state") stt = NewState() } - st := stt.(State) + st := stt.(*State) if !st.HasStarted { + log.Debug("TLS: handshake starting") st.HasStarted = true return &Payload{ Flags: FlagTLSStart, @@ -89,8 +94,10 @@ func (p *Payload) Handle(stt any) (*Payload, State) { ClientAuth: tls.RequireAnyClientCert, Certificates: certs, }) - st.Context, _ = context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + st.Context = ctx go func() { + defer cancel() err := st.TLS.HandshakeContext(st.Context) if err != nil { log.WithError(err).Debug("TLS: Handshake error") @@ -98,17 +105,35 @@ func (p *Payload) Handle(stt any) (*Payload, State) { }() } else if len(p.Data) > 0 { log.Debug("TLS: Updating buffer with new TLS data from packet") + if p.Flags&FlagLengthMore != 0 && st.Conn.bufferIncomingBytesCount == 0 { + log.Debugf("TLS: Expecting %d total bytes, will buffer", p.Length) + st.Conn.bufferIncomingBytesCount = p.Length + } st.Conn.UpdateData(p.Data) + return &Payload{ + Flags: FlagNone, + Length: 0, + Data: []byte{}, + }, st + } + // If we need more data, send the client the go-ahead + if st.Conn.NeedsMoreData() { + return &Payload{ + Flags: FlagNone, + Length: 0, + Data: []byte{}, + }, st } if st.HasMore() { return p.sendNextChunk(st) } - return p.startChunkedTransfer(st.Conn.GetData(), st) + if len(st.Conn.OutboundData()) > 0 { + return p.startChunkedTransfer(st.Conn.OutboundData(), st) + } + panic("we shouldn't get here") } -const maxChunkSize = 1000 - -func (p *Payload) startChunkedTransfer(data []byte, st State) (*Payload, State) { +func (p *Payload) startChunkedTransfer(data []byte, st *State) (*Payload, *State) { flags := FlagLengthIncluded var dataToSend []byte if len(data) > maxChunkSize { @@ -129,14 +154,21 @@ func (p *Payload) startChunkedTransfer(data []byte, st State) (*Payload, State) }, st } -func (p *Payload) sendNextChunk(st State) (*Payload, State) { - log.Debug("TLS: Sending next chunk") +func (p *Payload) sendNextChunk(st *State) (*Payload, *State) { nextChunk := st.RemainingChunks[0] + log.WithField("raw", debug.FormatBytes(nextChunk)).Debug("TLS: Sending next chunk") st.RemainingChunks = st.RemainingChunks[1:] flags := FlagLengthIncluded if st.HasMore() { log.WithField("chunks", len(st.RemainingChunks)).Debug("TLS: More chunks left") flags += FlagMoreFragments + } else { + // Last chunk, reset the connection buffers and pending payload size + defer func() { + log.Debug("TLS: Sent last chunk") + st.Conn.reader.Reset() + st.TotalPayloadSize = 0 + }() } log.WithField("length", st.TotalPayloadSize).Debug("TLS: Total payload size") return &Payload{ diff --git a/internal/outpost/radius/eap/tls/state.go b/internal/outpost/radius/eap/tls/state.go index 960742a6ee..04d1c8faf3 100644 --- a/internal/outpost/radius/eap/tls/state.go +++ b/internal/outpost/radius/eap/tls/state.go @@ -10,12 +10,12 @@ type State struct { RemainingChunks [][]byte TotalPayloadSize int TLS *tls.Conn - Conn TLSConnection + Conn *TLSConnection Context context.Context } -func NewState() State { - return State{ +func NewState() *State { + return &State{ RemainingChunks: make([][]byte, 0), } } diff --git a/internal/outpost/radius/handler.go b/internal/outpost/radius/handler.go index 021d083fb6..06cbe7cb97 100644 --- a/internal/outpost/radius/handler.go +++ b/internal/outpost/radius/handler.go @@ -60,6 +60,7 @@ func (rs *RadiusServer) ServeRADIUS(w radius.ResponseWriter, r *radius.Request) "code": r.Code.String(), "request": rid, "ip": host, + "id": r.Identifier, }) selectedApp := "" defer func() {