@ -20,6 +20,7 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac
|
||||
}
|
||||
st := stm.GetEAPState(rst)
|
||||
if st == nil {
|
||||
log.Debug("EAP: blank state")
|
||||
st = BlankState(stm.GetEAPSettings())
|
||||
}
|
||||
if len(st.ChallengesToOffer) < 1 {
|
||||
|
||||
@ -11,17 +11,19 @@ import (
|
||||
type TLSConnection struct {
|
||||
reader *bytes.Buffer
|
||||
writer *bytes.Buffer
|
||||
|
||||
bufferIncomingBytesCount uint32
|
||||
}
|
||||
|
||||
func NewTLSConnection(initialData []byte) TLSConnection {
|
||||
c := TLSConnection{
|
||||
func NewTLSConnection(initialData []byte) *TLSConnection {
|
||||
c := &TLSConnection{
|
||||
reader: bytes.NewBuffer(initialData),
|
||||
writer: bytes.NewBuffer([]byte{}),
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func (conn TLSConnection) GetData() []byte {
|
||||
func (conn TLSConnection) OutboundData() []byte {
|
||||
for {
|
||||
b := conn.writer.Bytes()
|
||||
if len(b) < 1 {
|
||||
@ -34,21 +36,50 @@ func (conn TLSConnection) GetData() []byte {
|
||||
}
|
||||
|
||||
func (conn TLSConnection) UpdateData(data []byte) {
|
||||
conn.reader.Reset()
|
||||
conn.reader.Write(data)
|
||||
if conn.bufferIncomingBytesCount > 0 && conn.reader.Len() == int(conn.bufferIncomingBytesCount) {
|
||||
conn.bufferIncomingBytesCount = 0
|
||||
}
|
||||
log.Debugf("TLS(buffer): Appending new data %d (total %d, expecting %d)", len(data), conn.reader.Len(), conn.bufferIncomingBytesCount)
|
||||
}
|
||||
|
||||
// ----
|
||||
// func (conn TLSConnection) Reset() {
|
||||
// log.Debug("TLS(buffer): reset")
|
||||
// conn.reader.Reset()
|
||||
// conn.writer.Reset()
|
||||
// }
|
||||
|
||||
func (conn TLSConnection) Read(p []byte) (int, error) {
|
||||
log.Debugf("TLS(buffer): Read: %d from %d", len(p), conn.reader.Len())
|
||||
func (conn TLSConnection) NeedsMoreData() bool {
|
||||
if conn.bufferIncomingBytesCount > 0 {
|
||||
return conn.reader.Len() < int(conn.bufferIncomingBytesCount)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// func (conn TLSConnection) WaitForAttemptedRead() int {
|
||||
// for {
|
||||
// // log.Debug("TLS(buffer): waiting for attempted read")
|
||||
// if conn.missingBytes == 0 {
|
||||
// continue
|
||||
// }
|
||||
// return conn.missingBytes
|
||||
// }
|
||||
// }
|
||||
|
||||
func (conn *TLSConnection) Read(p []byte) (int, error) {
|
||||
for {
|
||||
n, err := conn.reader.Read(p)
|
||||
if n == 0 {
|
||||
log.Debug("TLS(buffer): Attempted read from empty buffer, stalling...")
|
||||
time.Sleep(1 * time.Second)
|
||||
log.Debugf("TLS(buffer): Attempted read %d from empty buffer, stalling...", len(p))
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
if conn.reader.Len() < int(conn.bufferIncomingBytesCount) {
|
||||
log.Debugf("TLS(buffer): Attempted read %d while waiting for bytes %d, stalling...", len(p), conn.reader.Len()-int(conn.bufferIncomingBytesCount))
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
log.Debugf("TLS(buffer): Read: %d from %d", len(p), n)
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
@ -12,6 +12,8 @@ import (
|
||||
"goauthentik.io/internal/outpost/radius/eap/debug"
|
||||
)
|
||||
|
||||
const maxChunkSize = 1000
|
||||
|
||||
type Payload struct {
|
||||
Flags Flag
|
||||
Length uint32
|
||||
@ -20,16 +22,17 @@ type Payload struct {
|
||||
|
||||
func (p *Payload) Decode(raw []byte) error {
|
||||
p.Flags = Flag(raw[0])
|
||||
raw = raw[1:]
|
||||
if p.Flags&FlagLengthIncluded != 0 {
|
||||
if len(raw) < 4 {
|
||||
return errors.New("invalid size")
|
||||
}
|
||||
p.Length = binary.BigEndian.Uint32(raw)
|
||||
p.Data = raw[5:]
|
||||
p.Data = raw[4:]
|
||||
} else {
|
||||
p.Data = raw[1:]
|
||||
p.Data = raw[0:]
|
||||
}
|
||||
log.WithField("raw", debug.FormatBytes(p.Data)).WithField("flags", p.Flags).Debug("TLS: decode raw")
|
||||
log.WithField("raw", debug.FormatBytes(p.Data)).WithField("size", len(p.Data)).WithField("flags", p.Flags).Debug("TLS: decode raw")
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -66,12 +69,14 @@ func init() {
|
||||
certs = append(certs, cert)
|
||||
}
|
||||
|
||||
func (p *Payload) Handle(stt any) (*Payload, State) {
|
||||
func (p *Payload) Handle(stt any) (*Payload, *State) {
|
||||
if stt == nil {
|
||||
log.Debug("TLS: new state")
|
||||
stt = NewState()
|
||||
}
|
||||
st := stt.(State)
|
||||
st := stt.(*State)
|
||||
if !st.HasStarted {
|
||||
log.Debug("TLS: handshake starting")
|
||||
st.HasStarted = true
|
||||
return &Payload{
|
||||
Flags: FlagTLSStart,
|
||||
@ -89,8 +94,10 @@ func (p *Payload) Handle(stt any) (*Payload, State) {
|
||||
ClientAuth: tls.RequireAnyClientCert,
|
||||
Certificates: certs,
|
||||
})
|
||||
st.Context, _ = context.WithTimeout(context.Background(), 30*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
st.Context = ctx
|
||||
go func() {
|
||||
defer cancel()
|
||||
err := st.TLS.HandshakeContext(st.Context)
|
||||
if err != nil {
|
||||
log.WithError(err).Debug("TLS: Handshake error")
|
||||
@ -98,17 +105,35 @@ func (p *Payload) Handle(stt any) (*Payload, State) {
|
||||
}()
|
||||
} else if len(p.Data) > 0 {
|
||||
log.Debug("TLS: Updating buffer with new TLS data from packet")
|
||||
if p.Flags&FlagLengthMore != 0 && st.Conn.bufferIncomingBytesCount == 0 {
|
||||
log.Debugf("TLS: Expecting %d total bytes, will buffer", p.Length)
|
||||
st.Conn.bufferIncomingBytesCount = p.Length
|
||||
}
|
||||
st.Conn.UpdateData(p.Data)
|
||||
return &Payload{
|
||||
Flags: FlagNone,
|
||||
Length: 0,
|
||||
Data: []byte{},
|
||||
}, st
|
||||
}
|
||||
// If we need more data, send the client the go-ahead
|
||||
if st.Conn.NeedsMoreData() {
|
||||
return &Payload{
|
||||
Flags: FlagNone,
|
||||
Length: 0,
|
||||
Data: []byte{},
|
||||
}, st
|
||||
}
|
||||
if st.HasMore() {
|
||||
return p.sendNextChunk(st)
|
||||
}
|
||||
return p.startChunkedTransfer(st.Conn.GetData(), st)
|
||||
if len(st.Conn.OutboundData()) > 0 {
|
||||
return p.startChunkedTransfer(st.Conn.OutboundData(), st)
|
||||
}
|
||||
panic("we shouldn't get here")
|
||||
}
|
||||
|
||||
const maxChunkSize = 1000
|
||||
|
||||
func (p *Payload) startChunkedTransfer(data []byte, st State) (*Payload, State) {
|
||||
func (p *Payload) startChunkedTransfer(data []byte, st *State) (*Payload, *State) {
|
||||
flags := FlagLengthIncluded
|
||||
var dataToSend []byte
|
||||
if len(data) > maxChunkSize {
|
||||
@ -129,14 +154,21 @@ func (p *Payload) startChunkedTransfer(data []byte, st State) (*Payload, State)
|
||||
}, st
|
||||
}
|
||||
|
||||
func (p *Payload) sendNextChunk(st State) (*Payload, State) {
|
||||
log.Debug("TLS: Sending next chunk")
|
||||
func (p *Payload) sendNextChunk(st *State) (*Payload, *State) {
|
||||
nextChunk := st.RemainingChunks[0]
|
||||
log.WithField("raw", debug.FormatBytes(nextChunk)).Debug("TLS: Sending next chunk")
|
||||
st.RemainingChunks = st.RemainingChunks[1:]
|
||||
flags := FlagLengthIncluded
|
||||
if st.HasMore() {
|
||||
log.WithField("chunks", len(st.RemainingChunks)).Debug("TLS: More chunks left")
|
||||
flags += FlagMoreFragments
|
||||
} else {
|
||||
// Last chunk, reset the connection buffers and pending payload size
|
||||
defer func() {
|
||||
log.Debug("TLS: Sent last chunk")
|
||||
st.Conn.reader.Reset()
|
||||
st.TotalPayloadSize = 0
|
||||
}()
|
||||
}
|
||||
log.WithField("length", st.TotalPayloadSize).Debug("TLS: Total payload size")
|
||||
return &Payload{
|
||||
|
||||
@ -10,12 +10,12 @@ type State struct {
|
||||
RemainingChunks [][]byte
|
||||
TotalPayloadSize int
|
||||
TLS *tls.Conn
|
||||
Conn TLSConnection
|
||||
Conn *TLSConnection
|
||||
Context context.Context
|
||||
}
|
||||
|
||||
func NewState() State {
|
||||
return State{
|
||||
func NewState() *State {
|
||||
return &State{
|
||||
RemainingChunks: make([][]byte, 0),
|
||||
}
|
||||
}
|
||||
|
||||
@ -60,6 +60,7 @@ func (rs *RadiusServer) ServeRADIUS(w radius.ResponseWriter, r *radius.Request)
|
||||
"code": r.Code.String(),
|
||||
"request": rid,
|
||||
"ip": host,
|
||||
"id": r.Identifier,
|
||||
})
|
||||
selectedApp := ""
|
||||
defer func() {
|
||||
|
||||
Reference in New Issue
Block a user