diff --git a/internal/outpost/radius/eap/handler.go b/internal/outpost/radius/eap/handler.go index f05ff37324..0457d0ff80 100644 --- a/internal/outpost/radius/eap/handler.go +++ b/internal/outpost/radius/eap/handler.go @@ -33,7 +33,7 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac rres := r.Response(radius.CodeAccessChallenge) if p, ok := res.Payload.(protocol.EmptyPayload); ok { - // This is a bit hacky here + // TODO: This is a bit hacky here res.code = CodeSuccess res.id -= 1 rres = p.ModifyPacket(rres) diff --git a/internal/outpost/radius/eap/tls/buff_conn.go b/internal/outpost/radius/eap/tls/buff_conn.go index 05e303625f..fe85f325f5 100644 --- a/internal/outpost/radius/eap/tls/buff_conn.go +++ b/internal/outpost/radius/eap/tls/buff_conn.go @@ -74,22 +74,26 @@ func (conn BuffConn) NeedsMoreData() bool { func (conn *BuffConn) Read(p []byte) (int, error) { d, err := retry.DoWithData( func() (int, error) { - n, err := conn.reader.Read(p) - if n == 0 { + if conn.reader.Len() == 0 { log.Debugf("TLS(buffcon): Attempted read %d from empty buffer, stalling...", len(p)) return 0, errStall } - if conn.expectedWriterByteCount > 0 && conn.writtenByteCount < int(conn.expectedWriterByteCount) { - log.Debugf("TLS(buffcon): Attempted read %d while waiting for bytes %d, stalling...", len(p), conn.expectedWriterByteCount-conn.reader.Len()) - return 0, errStall - } - if conn.expectedWriterByteCount > 0 && conn.writtenByteCount == int(conn.expectedWriterByteCount) { - conn.expectedWriterByteCount = 0 + if conn.expectedWriterByteCount > 0 { + // If we're waiting for more data, we need to stall + if conn.writtenByteCount < int(conn.expectedWriterByteCount) { + log.Debugf("TLS(buffcon): Attempted read %d while waiting for bytes %d, stalling...", len(p), conn.expectedWriterByteCount-conn.reader.Len()) + return 0, errStall + } + // If we have all the data, reset how much we're expecting to still get + if conn.writtenByteCount == int(conn.expectedWriterByteCount) { + conn.expectedWriterByteCount = 0 + } } if conn.reader.Len() == 0 { conn.writtenByteCount = 0 } - log.Debugf("TLS(buffcon): Read: %d from %d", len(p), n) + n, err := conn.reader.Read(p) + log.Debugf("TLS(buffcon): Read: %d into %d (total %d)", n, len(p), conn.reader.Len()) return n, err }, conn.retryOptions..., diff --git a/internal/outpost/radius/eap/tls/payload.go b/internal/outpost/radius/eap/tls/payload.go index 0a199bac4d..97b9c8ec16 100644 --- a/internal/outpost/radius/eap/tls/payload.go +++ b/internal/outpost/radius/eap/tls/payload.go @@ -89,33 +89,7 @@ func (p *Payload) Handle(stt any) (protocol.Payload, *State) { } if st.TLS == nil { - log.Debug("TLS: no TLS connection in state yet, starting connection") - st.Context, st.ContextCancel = context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second) - st.Conn = NewBuffConn(p.Data, st.Context) - st.TLS = tls.Server(st.Conn, &tls.Config{ - GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { - log.Debugf("TLS: ClientHello: %+v\n", ch) - st.ClientHello = ch - return nil, nil - }, - ClientAuth: tls.RequireAnyClientCert, - Certificates: certs, - CipherSuites: []uint16{ - tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, - tls.TLS_RSA_WITH_RC4_128_SHA, - tls.TLS_RSA_WITH_AES_128_CBC_SHA, - // tls.TLS_RSA_WITH_RC4_128_MD5, - }, - }) - go func() { - err := st.TLS.HandshakeContext(st.Context) - if err != nil { - log.WithError(err).Debug("TLS: Handshake error") - return - } - log.Debug("TLS: handshake done") - p.handshakeFinished(st) - }() + st = p.tlsInit(st) } else if len(p.Data) > 0 { log.Debug("TLS: Updating buffer with new TLS data from packet") if p.Flags&FlagLengthIncluded != 0 && st.Conn.expectedWriterByteCount == 0 { @@ -158,7 +132,33 @@ func (p *Payload) Handle(stt any) (protocol.Payload, *State) { return p.startChunkedTransfer(st.Conn.OutboundData(), st) } -func (p *Payload) handshakeFinished(st *State) { +func (p *Payload) tlsInit(st *State) *State { + log.Debug("TLS: no TLS connection in state yet, starting connection") + st.Context, st.ContextCancel = context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second) + st.Conn = NewBuffConn(p.Data, st.Context) + st.TLS = tls.Server(st.Conn, &tls.Config{ + GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { + log.Debugf("TLS: ClientHello: %+v\n", ch) + st.ClientHello = ch + return nil, nil + }, + ClientAuth: tls.RequireAnyClientCert, + Certificates: certs, + }) + go func() { + err := st.TLS.HandshakeContext(st.Context) + if err != nil { + log.WithError(err).Debug("TLS: Handshake error") + // TODO: Send a NAK to the client + return + } + log.Debug("TLS: handshake done") + p.tlsHandshakeFinished(st) + }() + return st +} + +func (p *Payload) tlsHandshakeFinished(st *State) { cs := st.TLS.ConnectionState() label := "client EAP encryption" var context []byte