Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer
2025-05-14 14:13:20 +02:00
parent ba8f137885
commit d7cb0b3ea1
5 changed files with 66 additions and 29 deletions

View File

@ -6,6 +6,7 @@ import (
"encoding/base64" "encoding/base64"
"github.com/gorilla/securecookie" "github.com/gorilla/securecookie"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/radius/eap/tls" "goauthentik.io/internal/outpost/radius/eap/tls"
"layeh.com/radius" "layeh.com/radius"
"layeh.com/radius/rfc2865" "layeh.com/radius/rfc2865"
@ -28,6 +29,7 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac
res, newState := p.GetChallengeForType(st, nextChallengeToOffer) res, newState := p.GetChallengeForType(st, nextChallengeToOffer)
stm.SetEAPState(rst, newState) stm.SetEAPState(rst, newState)
log.Debug("EAP: encapsulating challenge")
rres := r.Response(radius.CodeAccessChallenge) rres := r.Response(radius.CodeAccessChallenge)
rfc2865.State_SetString(rres, rst) rfc2865.State_SetString(rres, rst)
eapEncoded, err := res.Encode() eapEncoded, err := res.Encode()
@ -36,7 +38,6 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac
} }
rfc2869.EAPMessage_Set(rres, eapEncoded) rfc2869.EAPMessage_Set(rres, eapEncoded)
p.setMessageAuthenticator(rres) p.setMessageAuthenticator(rres)
// debug.DebugPacket(rres)
err = w.Write(rres) err = w.Write(rres)
if err != nil { if err != nil {
panic(err) panic(err)

View File

@ -4,7 +4,7 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/radius/eap/debug" "goauthentik.io/internal/outpost/radius/eap/debug"
"goauthentik.io/internal/outpost/radius/eap/tls" "goauthentik.io/internal/outpost/radius/eap/tls"
) )
@ -63,7 +63,7 @@ func Decode(raw []byte) (*Packet, error) {
} }
packet.Payload = emptyPayload(packet.msgType) packet.Payload = emptyPayload(packet.msgType)
packet.rawPayload = raw[5:] packet.rawPayload = raw[5:]
logrus.WithField("raw", debug.FormatBytes(raw[5:])).Debug("EAP decode raw") log.WithField("raw", debug.FormatBytes(raw)).Debug("EAP: decode raw")
err := packet.Payload.Decode(raw[5:]) err := packet.Payload.Decode(raw[5:])
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -4,6 +4,8 @@ import (
"bytes" "bytes"
"net" "net"
"time" "time"
log "github.com/sirupsen/logrus"
) )
type TLSConnection struct { type TLSConnection struct {
@ -19,17 +21,38 @@ func NewTLSConnection(initialData []byte) TLSConnection {
return c return c
} }
func (conn TLSConnection) Read(p []byte) (int, error) { return conn.reader.Read(p) } func (conn TLSConnection) TLSData() []byte {
return conn.writer.Bytes()
}
func (conn TLSConnection) UpdateData(data []byte) {
conn.reader.Reset()
conn.reader.Write(data)
}
// ----
func (conn TLSConnection) Read(p []byte) (int, error) {
log.Debugf("TLS(buffer): Read: %d from %d", len(p), conn.reader.Len())
for {
n, err := conn.reader.Read(p)
if n == 0 {
log.Debug("TLS(buffer): Attempted read from empty buffer, stalling...")
time.Sleep(1 * time.Second)
continue
}
return n, err
}
}
func (conn TLSConnection) Write(p []byte) (int, error) { func (conn TLSConnection) Write(p []byte) (int, error) {
log.Debugf("TLS(buffer): Write: %d", len(p))
return conn.writer.Write(p) return conn.writer.Write(p)
} }
func (conn TLSConnection) Close() error { return nil } func (conn TLSConnection) Close() error { return nil }
func (conn TLSConnection) LocalAddr() net.Addr { return nil } func (conn TLSConnection) LocalAddr() net.Addr { return nil }
func (conn TLSConnection) RemoteAddr() net.Addr { return nil } func (conn TLSConnection) RemoteAddr() net.Addr { return nil }
func (conn TLSConnection) SetDeadline(t time.Time) error { return nil } func (conn TLSConnection) SetDeadline(t time.Time) error { return nil }
func (conn TLSConnection) SetReadDeadline(t time.Time) error { return nil } func (conn TLSConnection) SetReadDeadline(t time.Time) error { return nil }
func (conn TLSConnection) SetWriteDeadline(t time.Time) error { return nil } func (conn TLSConnection) SetWriteDeadline(t time.Time) error { return nil }
func (conn TLSConnection) TLSData() []byte {
return conn.writer.Bytes()
}

View File

@ -1,10 +1,12 @@
package tls package tls
import ( import (
"context"
"crypto/tls" "crypto/tls"
"encoding/binary" "encoding/binary"
"errors" "errors"
"slices" "slices"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/radius/eap/debug" "goauthentik.io/internal/outpost/radius/eap/debug"
@ -27,6 +29,7 @@ func (p *Payload) Decode(raw []byte) error {
} else { } else {
p.Data = raw[1:] p.Data = raw[1:]
} }
log.WithField("raw", debug.FormatBytes(p.Data)).WithField("flags", p.Flags).Debug("TLS: decode raw")
return nil return nil
} }
@ -68,69 +71,74 @@ func (p *Payload) Handle(stt any) (*Payload, State) {
stt = NewState() stt = NewState()
} }
st := stt.(State) st := stt.(State)
log.WithField("flags", p.Flags).Debug("Got TLS Packet")
if !st.HasStarted { if !st.HasStarted {
st.HasStarted = true st.HasStarted = true
return &Payload{ return &Payload{
Flags: FlagTLSStart, Flags: FlagTLSStart,
}, st }, st
} }
if st.HasMore() {
return p.sendNextChunk(st)
}
log.WithField("raw", debug.FormatBytes(p.Data)).Debug("TLS: Decode raw")
tc := NewTLSConnection(p.Data)
if st.TLS == nil { if st.TLS == nil {
log.Debug("no TLS connection in state yet, starting connection") log.Debug("TLS: no TLS connection in state yet, starting connection")
st.TLS = tls.Server(tc, &tls.Config{ st.Conn = NewTLSConnection(p.Data)
st.TLS = tls.Server(st.Conn, &tls.Config{
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) { GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
log.Debugf("%+v\n", argHello) log.Debugf("TLS: ClientHello: %+v\n", argHello)
return nil, nil return nil, nil
}, },
ClientAuth: tls.RequireAnyClientCert, ClientAuth: tls.RequireAnyClientCert,
Certificates: certs, Certificates: certs,
}) })
err := st.TLS.Handshake() st.Context, _ = context.WithTimeout(context.Background(), 30*time.Second)
log.WithError(err).Debug("TLS: Handshake error") go func() {
err := st.TLS.HandshakeContext(st.Context)
if err != nil {
log.WithError(err).Debug("TLS: Handshake error")
}
}()
} else if len(p.Data) > 0 {
log.Debug("TLS: Updating buffer with new TLS data from packet")
st.Conn.UpdateData(p.Data)
} }
return p.sendDataChunked(tc.TLSData(), st) if st.HasMore() {
return p.sendNextChunk(st)
}
return p.startChunkedTransfer(st.Conn.TLSData(), st)
} }
const maxChunkSize = 1000 const maxChunkSize = 1000
func (p *Payload) sendDataChunked(data []byte, st State) (*Payload, State) { func (p *Payload) startChunkedTransfer(data []byte, st State) (*Payload, State) {
flags := FlagLengthIncluded flags := FlagLengthIncluded
var dataToSend []byte var dataToSend []byte
if len(data) > maxChunkSize { if len(data) > maxChunkSize {
log.WithField("length", len(data)).Debug("Data needs to be chunked") log.WithField("length", len(data)).Debug("TLS: Data needs to be chunked")
flags += FlagMoreFragments flags += FlagMoreFragments
dataToSend = data[:maxChunkSize] dataToSend = data[:maxChunkSize]
remainingData := data[maxChunkSize:] remainingData := data[maxChunkSize:]
// Chunk remaining data into correct chunks and add them to the list // Chunk remaining data into correct chunks and add them to the list
st.RemainingChunks = append(st.RemainingChunks, slices.Collect(slices.Chunk(remainingData, maxChunkSize))...) st.RemainingChunks = append(st.RemainingChunks, slices.Collect(slices.Chunk(remainingData, maxChunkSize))...)
st.TotalPayloadSize = len(st.RemainingChunks) * maxChunkSize st.TotalPayloadSize = len(data)
} else { } else {
dataToSend = data dataToSend = data
} }
return &Payload{ return &Payload{
Flags: flags, Flags: flags,
Length: uint32(len(data) + 5), Length: uint32(st.TotalPayloadSize),
Data: dataToSend, Data: dataToSend,
}, st }, st
} }
func (p *Payload) sendNextChunk(st State) (*Payload, State) { func (p *Payload) sendNextChunk(st State) (*Payload, State) {
log.Debug("Sending next chunk") log.Debug("TLS: Sending next chunk")
nextChunk := st.RemainingChunks[0] nextChunk := st.RemainingChunks[0]
st.RemainingChunks = st.RemainingChunks[1:] st.RemainingChunks = st.RemainingChunks[1:]
flags := FlagLengthIncluded flags := FlagLengthIncluded
if st.HasMore() { if st.HasMore() {
log.WithField("chunks", len(st.RemainingChunks)).Debug("More chunks left") log.WithField("chunks", len(st.RemainingChunks)).Debug("TLS: More chunks left")
flags += FlagMoreFragments flags += FlagMoreFragments
} }
log.WithField("length", st.TotalPayloadSize).Debug("Total payload size") log.WithField("length", st.TotalPayloadSize).Debug("TLS: Total payload size")
return &Payload{ return &Payload{
Flags: flags, Flags: flags,
Length: uint32(st.TotalPayloadSize), Length: uint32(st.TotalPayloadSize),

View File

@ -1,12 +1,17 @@
package tls package tls
import "crypto/tls" import (
"context"
"crypto/tls"
)
type State struct { type State struct {
HasStarted bool HasStarted bool
RemainingChunks [][]byte RemainingChunks [][]byte
TotalPayloadSize int TotalPayloadSize int
TLS *tls.Conn TLS *tls.Conn
Conn TLSConnection
Context context.Context
} }
func NewState() State { func NewState() State {