diff --git a/internal/outpost/radius/eap/handler.go b/internal/outpost/radius/eap/handler.go index 0c5d130554..601a9b250e 100644 --- a/internal/outpost/radius/eap/handler.go +++ b/internal/outpost/radius/eap/handler.go @@ -6,6 +6,7 @@ import ( "encoding/base64" "github.com/gorilla/securecookie" + log "github.com/sirupsen/logrus" "goauthentik.io/internal/outpost/radius/eap/tls" "layeh.com/radius" "layeh.com/radius/rfc2865" @@ -28,6 +29,7 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac res, newState := p.GetChallengeForType(st, nextChallengeToOffer) stm.SetEAPState(rst, newState) + log.Debug("EAP: encapsulating challenge") rres := r.Response(radius.CodeAccessChallenge) rfc2865.State_SetString(rres, rst) eapEncoded, err := res.Encode() @@ -36,7 +38,6 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac } rfc2869.EAPMessage_Set(rres, eapEncoded) p.setMessageAuthenticator(rres) - // debug.DebugPacket(rres) err = w.Write(rres) if err != nil { panic(err) diff --git a/internal/outpost/radius/eap/packet.go b/internal/outpost/radius/eap/packet.go index a2add9ac60..43e81f6928 100644 --- a/internal/outpost/radius/eap/packet.go +++ b/internal/outpost/radius/eap/packet.go @@ -4,7 +4,7 @@ import ( "encoding/binary" "errors" - "github.com/sirupsen/logrus" + log "github.com/sirupsen/logrus" "goauthentik.io/internal/outpost/radius/eap/debug" "goauthentik.io/internal/outpost/radius/eap/tls" ) @@ -63,7 +63,7 @@ func Decode(raw []byte) (*Packet, error) { } packet.Payload = emptyPayload(packet.msgType) packet.rawPayload = raw[5:] - logrus.WithField("raw", debug.FormatBytes(raw[5:])).Debug("EAP decode raw") + log.WithField("raw", debug.FormatBytes(raw)).Debug("EAP: decode raw") err := packet.Payload.Decode(raw[5:]) if err != nil { return nil, err diff --git a/internal/outpost/radius/eap/tls/conn.go b/internal/outpost/radius/eap/tls/conn.go index 9f4a7e0362..77ecd4efa5 100644 --- a/internal/outpost/radius/eap/tls/conn.go +++ b/internal/outpost/radius/eap/tls/conn.go @@ -4,6 +4,8 @@ import ( "bytes" "net" "time" + + log "github.com/sirupsen/logrus" ) type TLSConnection struct { @@ -19,17 +21,38 @@ func NewTLSConnection(initialData []byte) TLSConnection { return c } -func (conn TLSConnection) Read(p []byte) (int, error) { return conn.reader.Read(p) } +func (conn TLSConnection) TLSData() []byte { + return conn.writer.Bytes() +} + +func (conn TLSConnection) UpdateData(data []byte) { + conn.reader.Reset() + conn.reader.Write(data) +} + +// ---- + +func (conn TLSConnection) Read(p []byte) (int, error) { + log.Debugf("TLS(buffer): Read: %d from %d", len(p), conn.reader.Len()) + for { + n, err := conn.reader.Read(p) + if n == 0 { + log.Debug("TLS(buffer): Attempted read from empty buffer, stalling...") + time.Sleep(1 * time.Second) + continue + } + return n, err + } +} + func (conn TLSConnection) Write(p []byte) (int, error) { + log.Debugf("TLS(buffer): Write: %d", len(p)) return conn.writer.Write(p) } + func (conn TLSConnection) Close() error { return nil } func (conn TLSConnection) LocalAddr() net.Addr { return nil } func (conn TLSConnection) RemoteAddr() net.Addr { return nil } func (conn TLSConnection) SetDeadline(t time.Time) error { return nil } func (conn TLSConnection) SetReadDeadline(t time.Time) error { return nil } func (conn TLSConnection) SetWriteDeadline(t time.Time) error { return nil } - -func (conn TLSConnection) TLSData() []byte { - return conn.writer.Bytes() -} diff --git a/internal/outpost/radius/eap/tls/payload.go b/internal/outpost/radius/eap/tls/payload.go index 292cc07c98..e07802be2f 100644 --- a/internal/outpost/radius/eap/tls/payload.go +++ b/internal/outpost/radius/eap/tls/payload.go @@ -1,10 +1,12 @@ package tls import ( + "context" "crypto/tls" "encoding/binary" "errors" "slices" + "time" log "github.com/sirupsen/logrus" "goauthentik.io/internal/outpost/radius/eap/debug" @@ -27,6 +29,7 @@ func (p *Payload) Decode(raw []byte) error { } else { p.Data = raw[1:] } + log.WithField("raw", debug.FormatBytes(p.Data)).WithField("flags", p.Flags).Debug("TLS: decode raw") return nil } @@ -68,69 +71,74 @@ func (p *Payload) Handle(stt any) (*Payload, State) { stt = NewState() } st := stt.(State) - log.WithField("flags", p.Flags).Debug("Got TLS Packet") if !st.HasStarted { st.HasStarted = true return &Payload{ Flags: FlagTLSStart, }, st } - if st.HasMore() { - return p.sendNextChunk(st) - } - log.WithField("raw", debug.FormatBytes(p.Data)).Debug("TLS: Decode raw") - - tc := NewTLSConnection(p.Data) if st.TLS == nil { - log.Debug("no TLS connection in state yet, starting connection") - st.TLS = tls.Server(tc, &tls.Config{ + log.Debug("TLS: no TLS connection in state yet, starting connection") + st.Conn = NewTLSConnection(p.Data) + st.TLS = tls.Server(st.Conn, &tls.Config{ GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) { - log.Debugf("%+v\n", argHello) + log.Debugf("TLS: ClientHello: %+v\n", argHello) return nil, nil }, ClientAuth: tls.RequireAnyClientCert, Certificates: certs, }) - err := st.TLS.Handshake() - log.WithError(err).Debug("TLS: Handshake error") + st.Context, _ = context.WithTimeout(context.Background(), 30*time.Second) + go func() { + err := st.TLS.HandshakeContext(st.Context) + if err != nil { + log.WithError(err).Debug("TLS: Handshake error") + } + }() + } else if len(p.Data) > 0 { + log.Debug("TLS: Updating buffer with new TLS data from packet") + st.Conn.UpdateData(p.Data) } - return p.sendDataChunked(tc.TLSData(), st) + if st.HasMore() { + return p.sendNextChunk(st) + } + return p.startChunkedTransfer(st.Conn.TLSData(), st) } const maxChunkSize = 1000 -func (p *Payload) sendDataChunked(data []byte, st State) (*Payload, State) { +func (p *Payload) startChunkedTransfer(data []byte, st State) (*Payload, State) { flags := FlagLengthIncluded var dataToSend []byte if len(data) > maxChunkSize { - log.WithField("length", len(data)).Debug("Data needs to be chunked") + log.WithField("length", len(data)).Debug("TLS: Data needs to be chunked") flags += FlagMoreFragments dataToSend = data[:maxChunkSize] remainingData := data[maxChunkSize:] // Chunk remaining data into correct chunks and add them to the list st.RemainingChunks = append(st.RemainingChunks, slices.Collect(slices.Chunk(remainingData, maxChunkSize))...) - st.TotalPayloadSize = len(st.RemainingChunks) * maxChunkSize + st.TotalPayloadSize = len(data) } else { dataToSend = data } return &Payload{ Flags: flags, - Length: uint32(len(data) + 5), + Length: uint32(st.TotalPayloadSize), Data: dataToSend, }, st } func (p *Payload) sendNextChunk(st State) (*Payload, State) { - log.Debug("Sending next chunk") + log.Debug("TLS: Sending next chunk") nextChunk := st.RemainingChunks[0] st.RemainingChunks = st.RemainingChunks[1:] flags := FlagLengthIncluded if st.HasMore() { - log.WithField("chunks", len(st.RemainingChunks)).Debug("More chunks left") + log.WithField("chunks", len(st.RemainingChunks)).Debug("TLS: More chunks left") flags += FlagMoreFragments } - log.WithField("length", st.TotalPayloadSize).Debug("Total payload size") + log.WithField("length", st.TotalPayloadSize).Debug("TLS: Total payload size") return &Payload{ Flags: flags, Length: uint32(st.TotalPayloadSize), diff --git a/internal/outpost/radius/eap/tls/state.go b/internal/outpost/radius/eap/tls/state.go index 26874a9d5c..960742a6ee 100644 --- a/internal/outpost/radius/eap/tls/state.go +++ b/internal/outpost/radius/eap/tls/state.go @@ -1,12 +1,17 @@ package tls -import "crypto/tls" +import ( + "context" + "crypto/tls" +) type State struct { HasStarted bool RemainingChunks [][]byte TotalPayloadSize int TLS *tls.Conn + Conn TLSConnection + Context context.Context } func NewState() State {