separate eap logic into protocol

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer
2025-05-21 00:53:18 +02:00
parent 19bb77638a
commit 1575b96262
10 changed files with 185 additions and 89 deletions

View File

@ -9,6 +9,7 @@ import (
type context struct { type context struct {
req *radius.Request req *radius.Request
state interface{} state interface{}
typeState map[protocol.Type]any
log *log.Entry log *log.Entry
settings interface{} settings interface{}
endStatus protocol.Status endStatus protocol.Status
@ -23,6 +24,10 @@ func (ctx context) ProtocolSettings() interface{} {
return ctx.settings return ctx.settings
} }
func (ctx *context) StateForProtocol(p protocol.Type) interface{} {
return ctx.typeState[p]
}
func (ctx *context) GetProtocolState() interface{} { func (ctx *context) GetProtocolState() interface{} {
return ctx.state return ctx.state
} }

View File

@ -9,6 +9,7 @@ import (
"github.com/gorilla/securecookie" "github.com/gorilla/securecookie"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/radius/eap/protocol" "goauthentik.io/internal/outpost/radius/eap/protocol"
"goauthentik.io/internal/outpost/radius/eap/protocol/eap"
"goauthentik.io/internal/outpost/radius/eap/protocol/legacy_nak" "goauthentik.io/internal/outpost/radius/eap/protocol/legacy_nak"
"layeh.com/radius" "layeh.com/radius"
"layeh.com/radius/rfc2865" "layeh.com/radius/rfc2865"
@ -30,16 +31,20 @@ func (p *Packet) HandleRadiusPacket(w radius.ResponseWriter, r *radius.Request)
} }
p.state = rst p.state = rst
rp, err := p.handleInner(r) rep, err := p.handleInner(r)
rp := &Packet{
eap: rep,
}
rres := r.Response(radius.CodeAccessReject) rres := r.Response(radius.CodeAccessReject)
if err == nil { if err == nil {
rres = p.endModifier(rres) rres = p.endModifier(rres)
switch rp.code { switch rp.eap.Code {
case CodeRequest: case protocol.CodeRequest:
rres.Code = radius.CodeAccessChallenge rres.Code = radius.CodeAccessChallenge
case CodeFailure: case protocol.CodeFailure:
rres.Code = radius.CodeAccessReject rres.Code = radius.CodeAccessReject
case CodeSuccess: case protocol.CodeSuccess:
rres.Code = radius.CodeAccessAccept rres.Code = radius.CodeAccessAccept
} }
} else { } else {
@ -54,7 +59,7 @@ func (p *Packet) HandleRadiusPacket(w radius.ResponseWriter, r *radius.Request)
sendErrorResponse(w, r) sendErrorResponse(w, r)
return return
} }
log.WithField("length", len(eapEncoded)).WithField("type", fmt.Sprintf("%T", rp.Payload)).Debug("EAP: encapsulated challenge") log.WithField("length", len(eapEncoded)).WithField("type", fmt.Sprintf("%T", rp.eap.Payload)).Debug("EAP: encapsulated challenge")
rfc2869.EAPMessage_Set(rres, eapEncoded) rfc2869.EAPMessage_Set(rres, eapEncoded)
err = p.setMessageAuthenticator(rres) err = p.setMessageAuthenticator(rres)
if err != nil { if err != nil {
@ -68,30 +73,39 @@ func (p *Packet) HandleRadiusPacket(w radius.ResponseWriter, r *radius.Request)
} }
} }
func (p *Packet) handleInner(r *radius.Request) (*Packet, error) { func (p *Packet) handleInner(r *radius.Request) (*eap.Payload, error) {
st := p.stm.GetEAPState(p.state) st := p.stm.GetEAPState(p.state)
if st == nil { if st == nil {
log.Debug("EAP: blank state") log.Debug("EAP: blank state")
st = BlankState(p.stm.GetEAPSettings()) st = BlankState(p.stm.GetEAPSettings())
} }
// FIXME: Statically call Handle of root EAP packet to make its data accessible
ectx := &context{
state: st.TypeState[eap.TypeEAP],
log: log.WithField("type", fmt.Sprintf("%T", &eap.Payload{})),
}
p.eap.Handle(ectx)
st.TypeState[eap.TypeEAP] = ectx.GetProtocolState()
p.stm.SetEAPState(p.state, st)
nextChallengeToOffer, err := st.GetNextProtocol() nextChallengeToOffer, err := st.GetNextProtocol()
if err != nil { if err != nil {
return &Packet{ return &eap.Payload{
code: CodeFailure, Code: protocol.CodeFailure,
id: p.id, ID: p.eap.ID,
}, err }, err
} }
next := func() (*Packet, error) { next := func() (*eap.Payload, error) {
st.ProtocolIndex += 1 st.ProtocolIndex += 1
p.stm.SetEAPState(p.state, st) p.stm.SetEAPState(p.state, st)
return p.handleInner(r) return p.handleInner(r)
} }
if _, ok := p.Payload.(*legacy_nak.Payload); ok { if _, ok := p.eap.Payload.(*legacy_nak.Payload); ok {
log.Debug("EAP: received NAK, trying next protocol") log.Debug("EAP: received NAK, trying next protocol")
p.Payload = nil p.eap.Payload = nil
return next() return next()
} }
@ -100,6 +114,7 @@ func (p *Packet) handleInner(r *radius.Request) (*Packet, error) {
ctx := &context{ ctx := &context{
req: r, req: r,
state: st.TypeState[np.Type()], state: st.TypeState[np.Type()],
typeState: st.TypeState,
log: log.WithField("type", fmt.Sprintf("%T", np)), log: log.WithField("type", fmt.Sprintf("%T", np)),
settings: p.stm.GetEAPSettings().ProtocolSettings[np.Type()], settings: p.stm.GetEAPSettings().ProtocolSettings[np.Type()],
} }
@ -119,11 +134,11 @@ func (p *Packet) handleInner(r *radius.Request) (*Packet, error) {
switch ctx.endStatus { switch ctx.endStatus {
case protocol.StatusSuccess: case protocol.StatusSuccess:
res.code = CodeSuccess res.Code = protocol.CodeSuccess
res.id -= 1 res.ID -= 1
case protocol.StatusError: case protocol.StatusError:
res.code = CodeFailure res.Code = protocol.CodeFailure
res.id -= 1 res.ID -= 1
case protocol.StatusNextProtocol: case protocol.StatusNextProtocol:
ctx.log.Debug("EAP: Protocol ended, starting next protocol") ctx.log.Debug("EAP: Protocol ended, starting next protocol")
return next() return next()
@ -132,18 +147,18 @@ func (p *Packet) handleInner(r *radius.Request) (*Packet, error) {
return res, nil return res, nil
} }
func (p *Packet) GetChallengeForType(ctx *context, np protocol.Payload) *Packet { func (p *Packet) GetChallengeForType(ctx *context, np protocol.Payload) *eap.Payload {
res := &Packet{ res := &eap.Payload{
code: CodeRequest, Code: protocol.CodeRequest,
id: p.id + 1, ID: p.eap.ID + 1,
msgType: np.Type(), MsgType: np.Type(),
} }
var payload any var payload any
if ctx.IsProtocolStart() { if ctx.IsProtocolStart() {
p.Payload = np p.eap.Payload = np
p.Payload.Decode(p.rawPayload) p.eap.Payload.Decode(p.eap.RawPayload)
} }
payload = p.Payload.Handle(ctx) payload = p.eap.Payload.Handle(ctx)
if payload != nil { if payload != nil {
res.Payload = payload.(protocol.Payload) res.Payload = payload.(protocol.Payload)
} }

View File

@ -1,40 +1,20 @@
package eap package eap
import ( import (
"encoding/binary"
"errors"
"fmt" "fmt"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/radius/eap/debug"
"goauthentik.io/internal/outpost/radius/eap/protocol" "goauthentik.io/internal/outpost/radius/eap/protocol"
"goauthentik.io/internal/outpost/radius/eap/protocol/eap"
"layeh.com/radius" "layeh.com/radius"
) )
type Code uint8
const (
CodeRequest Code = 1
CodeResponse Code = 2
CodeSuccess Code = 3
CodeFailure Code = 4
)
type Packet struct { type Packet struct {
code Code eap *eap.Payload
id uint8
length uint16
msgType protocol.Type
rawPayload []byte
Payload protocol.Payload
stm StateManager stm StateManager
state string state string
endModifier func(p *radius.Packet) *radius.Packet endModifier func(p *radius.Packet) *radius.Packet
} }
type PayloadWriter struct{}
func emptyPayload(stm StateManager, t protocol.Type) (protocol.Payload, error) { func emptyPayload(stm StateManager, t protocol.Type) (protocol.Payload, error) {
for _, cons := range stm.GetEAPSettings().Protocols { for _, cons := range stm.GetEAPSettings().Protocols {
if np := cons(); np.Type() == t { if np := cons(); np.Type() == t {
@ -46,28 +26,24 @@ func emptyPayload(stm StateManager, t protocol.Type) (protocol.Payload, error) {
func Decode(stm StateManager, raw []byte) (*Packet, error) { func Decode(stm StateManager, raw []byte) (*Packet, error) {
packet := &Packet{ packet := &Packet{
eap: &eap.Payload{},
stm: stm, stm: stm,
endModifier: func(p *radius.Packet) *radius.Packet { endModifier: func(p *radius.Packet) *radius.Packet {
return p return p
}, },
} }
packet.code = Code(raw[0]) // FIXME: We're decoding twice here, first to get the msg type, then come back to assign the payload type
packet.id = raw[1] // then re-parse to parse the payload correctly
packet.length = binary.BigEndian.Uint16(raw[2:]) err := packet.eap.Decode(raw)
if packet.length != uint16(len(raw)) {
return nil, errors.New("mismatched packet length")
}
if len(raw) > 4 && (packet.code == CodeRequest || packet.code == CodeResponse) {
packet.msgType = protocol.Type(raw[4])
}
p, err := emptyPayload(stm, packet.msgType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
packet.Payload = p p, err := emptyPayload(stm, packet.eap.MsgType)
packet.rawPayload = raw[5:] if err != nil {
log.WithField("raw", debug.FormatBytes(raw)).WithField("payload", fmt.Sprintf("%T", packet.Payload)).Debug("EAP: decode raw") return nil, err
err = packet.Payload.Decode(raw[5:]) }
packet.eap.Payload = p
err = packet.eap.Decode(raw)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -75,20 +51,5 @@ func Decode(stm StateManager, raw []byte) (*Packet, error) {
} }
func (p *Packet) Encode() ([]byte, error) { func (p *Packet) Encode() ([]byte, error) {
buff := make([]byte, 4) return p.eap.Encode()
buff[0] = uint8(p.code)
buff[1] = uint8(p.id)
if p.Payload != nil {
payloadBuffer, err := p.Payload.Encode()
if err != nil {
return buff, err
}
if p.code == CodeRequest || p.code == CodeResponse {
buff = append(buff, uint8(p.msgType))
}
buff = append(buff, payloadBuffer...)
}
binary.BigEndian.PutUint16(buff[2:], uint16(len(buff)))
return buff, nil
} }

View File

@ -18,6 +18,8 @@ type Context interface {
Packet() *radius.Request Packet() *radius.Request
ProtocolSettings() interface{} ProtocolSettings() interface{}
StateForProtocol(p Type) interface{}
GetProtocolState() interface{} GetProtocolState() interface{}
SetProtocolState(interface{}) SetProtocolState(interface{})

View File

@ -0,0 +1,83 @@
package eap
import (
"encoding/binary"
"errors"
"fmt"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/radius/eap/debug"
"goauthentik.io/internal/outpost/radius/eap/protocol"
)
const TypeEAP protocol.Type = 0
func Protocol() protocol.Payload {
return &Payload{}
}
type Payload struct {
Code protocol.Code
ID uint8
Length uint16
MsgType protocol.Type
Payload protocol.Payload
RawPayload []byte
}
func (ip *Payload) Type() protocol.Type {
return TypeEAP
}
func (ip *Payload) Offerable() bool {
return false
}
func (packet *Payload) Decode(raw []byte) error {
packet.Code = protocol.Code(raw[0])
packet.ID = raw[1]
packet.Length = binary.BigEndian.Uint16(raw[2:])
if packet.Length != uint16(len(raw)) {
return errors.New("mismatched packet length")
}
if len(raw) > 4 && (packet.Code == protocol.CodeRequest || packet.Code == protocol.CodeResponse) {
packet.MsgType = protocol.Type(raw[4])
}
packet.RawPayload = raw[5:]
if packet.Payload == nil {
return nil
}
log.WithField("raw", debug.FormatBytes(raw)).WithField("payload", fmt.Sprintf("%T", packet.Payload)).Debug("EAP: decode raw")
err := packet.Payload.Decode(raw[5:])
if err != nil {
return err
}
return nil
}
func (p *Payload) Encode() ([]byte, error) {
buff := make([]byte, 4)
buff[0] = uint8(p.Code)
buff[1] = uint8(p.ID)
if p.Payload != nil {
payloadBuffer, err := p.Payload.Encode()
if err != nil {
return buff, err
}
if p.Code == protocol.CodeRequest || p.Code == protocol.CodeResponse {
buff = append(buff, uint8(p.MsgType))
}
buff = append(buff, payloadBuffer...)
}
binary.BigEndian.PutUint16(buff[2:], uint16(len(buff)))
return buff, nil
}
func (ip *Payload) Handle(ctx protocol.Context) protocol.Payload {
ctx.Log().Debug("EAP: Handle")
ctx.SetProtocolState(&State{
PacketID: ip.ID,
})
return nil
}

View File

@ -0,0 +1,5 @@
package eap
type State struct {
PacketID uint8
}

View File

@ -9,3 +9,12 @@ type Payload interface {
} }
type Type uint8 type Type uint8
type Code uint8
const (
CodeRequest Code = 1
CodeResponse Code = 2
CodeSuccess Code = 3
CodeFailure Code = 4
)

View File

@ -4,6 +4,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/radius/eap/debug" "goauthentik.io/internal/outpost/radius/eap/debug"
"goauthentik.io/internal/outpost/radius/eap/protocol" "goauthentik.io/internal/outpost/radius/eap/protocol"
"goauthentik.io/internal/outpost/radius/eap/protocol/eap"
"goauthentik.io/internal/outpost/radius/eap/protocol/identity"
"goauthentik.io/internal/outpost/radius/eap/protocol/tls" "goauthentik.io/internal/outpost/radius/eap/protocol/tls"
) )
@ -11,11 +13,14 @@ const TypePEAP protocol.Type = 25
func Protocol() protocol.Payload { func Protocol() protocol.Payload {
return &tls.Payload{ return &tls.Payload{
Inner: &Payload{}, Inner: &Payload{
Inner: &eap.Payload{},
},
} }
} }
type Payload struct { type Payload struct {
Inner protocol.Payload
} }
func (p *Payload) Type() protocol.Type { func (p *Payload) Type() protocol.Type {
@ -33,7 +38,16 @@ func (p *Payload) Encode() ([]byte, error) {
} }
func (p *Payload) Handle(ctx protocol.Context) protocol.Payload { func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
log.Debug("PEAP: Handle") eapState := ctx.StateForProtocol(eap.TypeEAP).(*eap.State)
if !ctx.IsProtocolStart() {
ctx.Log().Debug("PEAP: Protocol start")
return &eap.Payload{
Code: protocol.CodeRequest,
ID: eapState.PacketID,
MsgType: identity.TypeIdentity,
Payload: &identity.Payload{},
}
}
return &Payload{} return &Payload{}
} }

View File

@ -4,18 +4,19 @@ import (
"goauthentik.io/internal/outpost/radius/eap/protocol" "goauthentik.io/internal/outpost/radius/eap/protocol"
) )
func (p *Payload) innerHandler(ctx protocol.Context) *Payload { func (p *Payload) innerHandler(ctx protocol.Context) {
// p.st.TLS.read // p.st.TLS.read
// d, _ := io.ReadAll(p.st.TLS) // d, _ := io.ReadAll(p.st.TLS)
err := p.Inner.Decode([]byte{}) err := p.Inner.Decode([]byte{})
if err != nil { if err != nil {
ctx.Log().WithError(err).Warning("TLS: failed to decode inner protocol") ctx.Log().WithError(err).Warning("TLS: failed to decode inner protocol")
ctx.EndInnerProtocol(protocol.StatusError, nil) ctx.EndInnerProtocol(protocol.StatusError, nil)
return nil return
} }
pl := p.Inner.Handle(ctx) pl := p.Inner.Handle(ctx)
enc, err := pl.Encode() enc, err := pl.Encode()
return &Payload{ p.st.TLS.Write(enc)
Data: enc, // return &Payload{
} // Data: enc,
// }
} }

View File

@ -124,7 +124,8 @@ func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
defer p.st.ContextCancel() defer p.st.ContextCancel()
if p.Inner != nil { if p.Inner != nil {
ctx.Log().Debug("TLS: Handshake is done, delegating to inner protocol") ctx.Log().Debug("TLS: Handshake is done, delegating to inner protocol")
return p.innerHandler(ctx) p.innerHandler(ctx)
return p.startChunkedTransfer(p.st.Conn.OutboundData())
} }
// If we don't have a final status from the handshake finished function, stall for time // If we don't have a final status from the handshake finished function, stall for time
pst, _ := retry.DoWithData( pst, _ := retry.DoWithData(