fix a bunch more

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer
2025-05-15 01:21:12 +02:00
parent a71532b3e3
commit df21e678d6
5 changed files with 89 additions and 24 deletions

View File

@ -20,6 +20,7 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac
}
st := stm.GetEAPState(rst)
if st == nil {
log.Debug("EAP: blank state")
st = BlankState(stm.GetEAPSettings())
}
if len(st.ChallengesToOffer) < 1 {

View File

@ -11,17 +11,19 @@ import (
type TLSConnection struct {
reader *bytes.Buffer
writer *bytes.Buffer
bufferIncomingBytesCount uint32
}
func NewTLSConnection(initialData []byte) TLSConnection {
c := TLSConnection{
func NewTLSConnection(initialData []byte) *TLSConnection {
c := &TLSConnection{
reader: bytes.NewBuffer(initialData),
writer: bytes.NewBuffer([]byte{}),
}
return c
}
func (conn TLSConnection) GetData() []byte {
func (conn TLSConnection) OutboundData() []byte {
for {
b := conn.writer.Bytes()
if len(b) < 1 {
@ -34,21 +36,50 @@ func (conn TLSConnection) GetData() []byte {
}
func (conn TLSConnection) UpdateData(data []byte) {
conn.reader.Reset()
conn.reader.Write(data)
if conn.bufferIncomingBytesCount > 0 && conn.reader.Len() == int(conn.bufferIncomingBytesCount) {
conn.bufferIncomingBytesCount = 0
}
log.Debugf("TLS(buffer): Appending new data %d (total %d, expecting %d)", len(data), conn.reader.Len(), conn.bufferIncomingBytesCount)
}
// ----
// func (conn TLSConnection) Reset() {
// log.Debug("TLS(buffer): reset")
// conn.reader.Reset()
// conn.writer.Reset()
// }
func (conn TLSConnection) Read(p []byte) (int, error) {
log.Debugf("TLS(buffer): Read: %d from %d", len(p), conn.reader.Len())
func (conn TLSConnection) NeedsMoreData() bool {
if conn.bufferIncomingBytesCount > 0 {
return conn.reader.Len() < int(conn.bufferIncomingBytesCount)
}
return false
}
// func (conn TLSConnection) WaitForAttemptedRead() int {
// for {
// // log.Debug("TLS(buffer): waiting for attempted read")
// if conn.missingBytes == 0 {
// continue
// }
// return conn.missingBytes
// }
// }
func (conn *TLSConnection) Read(p []byte) (int, error) {
for {
n, err := conn.reader.Read(p)
if n == 0 {
log.Debug("TLS(buffer): Attempted read from empty buffer, stalling...")
time.Sleep(1 * time.Second)
log.Debugf("TLS(buffer): Attempted read %d from empty buffer, stalling...", len(p))
time.Sleep(500 * time.Millisecond)
continue
}
if conn.reader.Len() < int(conn.bufferIncomingBytesCount) {
log.Debugf("TLS(buffer): Attempted read %d while waiting for bytes %d, stalling...", len(p), conn.reader.Len()-int(conn.bufferIncomingBytesCount))
time.Sleep(500 * time.Millisecond)
continue
}
log.Debugf("TLS(buffer): Read: %d from %d", len(p), n)
return n, err
}
}

View File

@ -12,6 +12,8 @@ import (
"goauthentik.io/internal/outpost/radius/eap/debug"
)
const maxChunkSize = 1000
type Payload struct {
Flags Flag
Length uint32
@ -20,16 +22,17 @@ type Payload struct {
func (p *Payload) Decode(raw []byte) error {
p.Flags = Flag(raw[0])
raw = raw[1:]
if p.Flags&FlagLengthIncluded != 0 {
if len(raw) < 4 {
return errors.New("invalid size")
}
p.Length = binary.BigEndian.Uint32(raw)
p.Data = raw[5:]
p.Data = raw[4:]
} else {
p.Data = raw[1:]
p.Data = raw[0:]
}
log.WithField("raw", debug.FormatBytes(p.Data)).WithField("flags", p.Flags).Debug("TLS: decode raw")
log.WithField("raw", debug.FormatBytes(p.Data)).WithField("size", len(p.Data)).WithField("flags", p.Flags).Debug("TLS: decode raw")
return nil
}
@ -66,12 +69,14 @@ func init() {
certs = append(certs, cert)
}
func (p *Payload) Handle(stt any) (*Payload, State) {
func (p *Payload) Handle(stt any) (*Payload, *State) {
if stt == nil {
log.Debug("TLS: new state")
stt = NewState()
}
st := stt.(State)
st := stt.(*State)
if !st.HasStarted {
log.Debug("TLS: handshake starting")
st.HasStarted = true
return &Payload{
Flags: FlagTLSStart,
@ -89,8 +94,10 @@ func (p *Payload) Handle(stt any) (*Payload, State) {
ClientAuth: tls.RequireAnyClientCert,
Certificates: certs,
})
st.Context, _ = context.WithTimeout(context.Background(), 30*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
st.Context = ctx
go func() {
defer cancel()
err := st.TLS.HandshakeContext(st.Context)
if err != nil {
log.WithError(err).Debug("TLS: Handshake error")
@ -98,17 +105,35 @@ func (p *Payload) Handle(stt any) (*Payload, State) {
}()
} else if len(p.Data) > 0 {
log.Debug("TLS: Updating buffer with new TLS data from packet")
if p.Flags&FlagLengthMore != 0 && st.Conn.bufferIncomingBytesCount == 0 {
log.Debugf("TLS: Expecting %d total bytes, will buffer", p.Length)
st.Conn.bufferIncomingBytesCount = p.Length
}
st.Conn.UpdateData(p.Data)
return &Payload{
Flags: FlagNone,
Length: 0,
Data: []byte{},
}, st
}
// If we need more data, send the client the go-ahead
if st.Conn.NeedsMoreData() {
return &Payload{
Flags: FlagNone,
Length: 0,
Data: []byte{},
}, st
}
if st.HasMore() {
return p.sendNextChunk(st)
}
return p.startChunkedTransfer(st.Conn.GetData(), st)
if len(st.Conn.OutboundData()) > 0 {
return p.startChunkedTransfer(st.Conn.OutboundData(), st)
}
panic("we shouldn't get here")
}
const maxChunkSize = 1000
func (p *Payload) startChunkedTransfer(data []byte, st State) (*Payload, State) {
func (p *Payload) startChunkedTransfer(data []byte, st *State) (*Payload, *State) {
flags := FlagLengthIncluded
var dataToSend []byte
if len(data) > maxChunkSize {
@ -129,14 +154,21 @@ func (p *Payload) startChunkedTransfer(data []byte, st State) (*Payload, State)
}, st
}
func (p *Payload) sendNextChunk(st State) (*Payload, State) {
log.Debug("TLS: Sending next chunk")
func (p *Payload) sendNextChunk(st *State) (*Payload, *State) {
nextChunk := st.RemainingChunks[0]
log.WithField("raw", debug.FormatBytes(nextChunk)).Debug("TLS: Sending next chunk")
st.RemainingChunks = st.RemainingChunks[1:]
flags := FlagLengthIncluded
if st.HasMore() {
log.WithField("chunks", len(st.RemainingChunks)).Debug("TLS: More chunks left")
flags += FlagMoreFragments
} else {
// Last chunk, reset the connection buffers and pending payload size
defer func() {
log.Debug("TLS: Sent last chunk")
st.Conn.reader.Reset()
st.TotalPayloadSize = 0
}()
}
log.WithField("length", st.TotalPayloadSize).Debug("TLS: Total payload size")
return &Payload{

View File

@ -10,12 +10,12 @@ type State struct {
RemainingChunks [][]byte
TotalPayloadSize int
TLS *tls.Conn
Conn TLSConnection
Conn *TLSConnection
Context context.Context
}
func NewState() State {
return State{
func NewState() *State {
return &State{
RemainingChunks: make([][]byte, 0),
}
}

View File

@ -60,6 +60,7 @@ func (rs *RadiusServer) ServeRADIUS(w radius.ResponseWriter, r *radius.Request)
"code": r.Code.String(),
"request": rid,
"ip": host,
"id": r.Identifier,
})
selectedApp := ""
defer func() {