From b6686cff14393eb688b70574d191ee3d7687db34 Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Tue, 20 May 2025 22:39:14 +0200 Subject: [PATCH] refactor v1, start support for more protocols and implement nak Signed-off-by: Jens Langhammer --- internal/outpost/radius/eap/context.go | 14 +- internal/outpost/radius/eap/handler.go | 137 +++++++++++------- .../outpost/radius/eap/identity/payload.go | 37 +++++ .../outpost/radius/eap/legacy_nak/payload.go | 37 +++++ internal/outpost/radius/eap/packet.go | 39 +++-- .../outpost/radius/eap/payload_identity.go | 14 -- .../outpost/radius/eap/protocol/context.go | 4 +- .../outpost/radius/eap/protocol/packet.go | 8 +- internal/outpost/radius/eap/state.go | 24 ++- internal/outpost/radius/eap/tls/payload.go | 23 ++- internal/outpost/radius/eap/tls/state.go | 1 - .../outpost/radius/handle_access_request.go | 20 ++- 12 files changed, 252 insertions(+), 106 deletions(-) create mode 100644 internal/outpost/radius/eap/identity/payload.go create mode 100644 internal/outpost/radius/eap/legacy_nak/payload.go delete mode 100644 internal/outpost/radius/eap/payload_identity.go diff --git a/internal/outpost/radius/eap/context.go b/internal/outpost/radius/eap/context.go index 8e2cc26d97..e680b15cb0 100644 --- a/internal/outpost/radius/eap/context.go +++ b/internal/outpost/radius/eap/context.go @@ -23,10 +23,7 @@ func (ctx context) ProtocolSettings() interface{} { return ctx.settings } -func (ctx *context) GetProtocolState(def func(protocol.Context) interface{}) interface{} { - if ctx.state == nil { - ctx.state = def(ctx) - } +func (ctx *context) GetProtocolState() interface{} { return ctx.state } @@ -34,11 +31,20 @@ func (ctx *context) SetProtocolState(st interface{}) { ctx.state = st } +func (ctx *context) IsProtocolStart() bool { + return ctx.state == nil +} + func (ctx *context) EndInnerProtocol(st protocol.Status, mf func(p *radius.Packet) *radius.Packet) { if ctx.endStatus != protocol.StatusUnknown { return } ctx.endStatus = st + if mf == nil { + mf = func(p *radius.Packet) *radius.Packet { + return p + } + } ctx.endModifier = mf } diff --git a/internal/outpost/radius/eap/handler.go b/internal/outpost/radius/eap/handler.go index 49377663e2..2f965d719f 100644 --- a/internal/outpost/radius/eap/handler.go +++ b/internal/outpost/radius/eap/handler.go @@ -4,11 +4,12 @@ import ( "crypto/hmac" "crypto/md5" "encoding/base64" + "fmt" "github.com/gorilla/securecookie" log "github.com/sirupsen/logrus" + "goauthentik.io/internal/outpost/radius/eap/legacy_nak" "goauthentik.io/internal/outpost/radius/eap/protocol" - "goauthentik.io/internal/outpost/radius/eap/tls" "layeh.com/radius" "layeh.com/radius/rfc2865" "layeh.com/radius/rfc2869" @@ -22,66 +23,42 @@ func sendErrorResponse(w radius.ResponseWriter, r *radius.Request) { } } -func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Request) { +func (p *Packet) HandleRadiusPacket(w radius.ResponseWriter, r *radius.Request) { rst := rfc2865.State_GetString(r.Packet) if rst == "" { rst = base64.StdEncoding.EncodeToString(securecookie.GenerateRandomKey(12)) } - st := stm.GetEAPState(rst) - if st == nil { - log.Debug("EAP: blank state") - st = BlankState(stm.GetEAPSettings()) - } - if len(st.ChallengesToOffer) < 1 { - log.Error("No more challenges to offer") - sendErrorResponse(w, r) - return - } - nextChallengeToOffer := st.ChallengesToOffer[0] + p.state = rst - ctx := &context{ - req: r, - state: st.TypeState[nextChallengeToOffer], - log: log.WithField("type", nextChallengeToOffer), - settings: stm.GetEAPSettings().ProtocolSettings[nextChallengeToOffer], - } - - res := p.GetChallengeForType(ctx, nextChallengeToOffer) - st.TypeState[nextChallengeToOffer] = ctx.GetProtocolState(nil) - stm.SetEAPState(rst, st) - - rres := r.Response(radius.CodeAccessChallenge) - switch ctx.endStatus { - case protocol.StatusSuccess: - res.code = CodeSuccess - res.id -= 1 - rres = ctx.endModifier(rres) - st.ChallengesToOffer = st.ChallengesToOffer[1:] - if len(st.ChallengesToOffer) < 1 { + rp, err := p.handleInner(r) + rres := r.Response(radius.CodeAccessReject) + if err == nil { + rres = p.endModifier(rres) + switch rp.code { + case CodeFailure: + rres.Code = radius.CodeAccessReject + case CodeSuccess: rres.Code = radius.CodeAccessAccept } - case protocol.StatusError: - res.code = CodeFailure - res.id -= 1 - st.ChallengesToOffer = st.ChallengesToOffer[1:] - rres = ctx.endModifier(rres) - if len(st.ChallengesToOffer) < 1 { - rres.Code = radius.CodeAccessReject - } - case protocol.StatusUnknown: + } else { + rres.Code = radius.CodeAccessReject + log.WithError(err).Debug("Rejecting request") } - rfc2865.State_SetString(rres, rst) - eapEncoded, err := res.Encode() + + rfc2865.State_SetString(rres, p.state) + eapEncoded, err := rp.Encode() if err != nil { log.WithError(err).Warning("failed to encode response") sendErrorResponse(w, r) + return } - log.WithField("length", len(eapEncoded)).Debug("EAP: encapsulated challenge") + log.WithField("length", len(eapEncoded)).WithField("type", fmt.Sprintf("%T", rp.Payload)).Debug("EAP: encapsulated challenge") rfc2869.EAPMessage_Set(rres, eapEncoded) err = p.setMessageAuthenticator(rres) if err != nil { log.WithError(err).Warning("failed to send message authenticator") sendErrorResponse(w, r) + return } err = w.Write(rres) if err != nil { @@ -89,21 +66,75 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Req } } -func (p *Packet) GetChallengeForType(ctx *context, t protocol.Type) *Packet { +func (p *Packet) handleInner(r *radius.Request) (*Packet, error) { + st := p.stm.GetEAPState(p.state) + if st == nil { + log.Debug("EAP: blank state") + st = BlankState(p.stm.GetEAPSettings()) + } + + nextChallengeToOffer, err := st.GetNextProtocol() + if err != nil { + return &Packet{ + code: CodeFailure, + id: p.id, + }, err + } + + if _, ok := p.Payload.(*legacy_nak.Payload); ok { + log.Debug("EAP: received NAK, trying next protocol") + st.ProtocolIndex += 1 + p.stm.SetEAPState(p.state, st) + return p.handleInner(r) + } + + np, _ := emptyPayload(p.stm, nextChallengeToOffer) + + ctx := &context{ + req: r, + state: st.TypeState[np.Type()], + log: log.WithField("type", fmt.Sprintf("%T", np)), + settings: p.stm.GetEAPSettings().ProtocolSettings[np.Type()], + } + ctx.log.Debug("EAP: Passing to protocol") + + res := p.GetChallengeForType(ctx, np) + st.TypeState[np.Type()] = ctx.GetProtocolState() + p.stm.SetEAPState(p.state, st) + + if ctx.endModifier != nil { + p.endModifier = ctx.endModifier + } + + switch ctx.endStatus { + case protocol.StatusSuccess: + res.code = CodeSuccess + res.id -= 1 + case protocol.StatusError: + res.code = CodeFailure + res.id -= 1 + case protocol.StatusNextProtocol: + ctx.log.Debug("EAP: Protocol ended, starting next protocol") + st.ProtocolIndex += 1 + p.stm.SetEAPState(p.state, st) + return p.handleInner(r) + case protocol.StatusUnknown: + } + return res, nil +} + +func (p *Packet) GetChallengeForType(ctx *context, np protocol.Payload) *Packet { res := &Packet{ code: CodeRequest, id: p.id + 1, - msgType: t, + msgType: np.Type(), } var payload any - switch t { - case tls.TypeTLS: - if _, ok := p.Payload.(*tls.Payload); !ok { - p.Payload = &tls.Payload{} - p.Payload.Decode(p.rawPayload) - } - payload = p.Payload.(*tls.Payload).Handle(ctx) + if ctx.IsProtocolStart() { + p.Payload = np + p.Payload.Decode(p.rawPayload) } + payload = p.Payload.Handle(ctx) if payload != nil { res.Payload = payload.(protocol.Payload) } diff --git a/internal/outpost/radius/eap/identity/payload.go b/internal/outpost/radius/eap/identity/payload.go new file mode 100644 index 0000000000..d9eda677b4 --- /dev/null +++ b/internal/outpost/radius/eap/identity/payload.go @@ -0,0 +1,37 @@ +package identity + +import "goauthentik.io/internal/outpost/radius/eap/protocol" + +const TypeIdentity protocol.Type = 1 + +func Protocol() protocol.Payload { + return &Payload{} +} + +type Payload struct { + Identity string +} + +func (ip *Payload) Type() protocol.Type { + return TypeIdentity +} + +func (ip *Payload) Decode(raw []byte) error { + ip.Identity = string(raw) + return nil +} + +func (ip *Payload) Encode() ([]byte, error) { + return []byte{}, nil +} + +func (ip *Payload) Handle(ctx protocol.Context) protocol.Payload { + if ctx.IsProtocolStart() { + ctx.EndInnerProtocol(protocol.StatusNextProtocol, nil) + } + return nil +} + +func (ip *Payload) Offerable() bool { + return false +} diff --git a/internal/outpost/radius/eap/legacy_nak/payload.go b/internal/outpost/radius/eap/legacy_nak/payload.go new file mode 100644 index 0000000000..697adb6973 --- /dev/null +++ b/internal/outpost/radius/eap/legacy_nak/payload.go @@ -0,0 +1,37 @@ +package legacy_nak + +import "goauthentik.io/internal/outpost/radius/eap/protocol" + +const TypeLegacyNAK protocol.Type = 3 + +func Protocol() protocol.Payload { + return &Payload{} +} + +type Payload struct { + DesiredType protocol.Type +} + +func (ln *Payload) Type() protocol.Type { + return TypeLegacyNAK +} + +func (ln *Payload) Decode(raw []byte) error { + ln.DesiredType = protocol.Type(raw[0]) + return nil +} + +func (ln *Payload) Encode() ([]byte, error) { + return []byte{byte(ln.DesiredType)}, nil +} + +func (ln *Payload) Handle(ctx protocol.Context) protocol.Payload { + if ctx.IsProtocolStart() { + ctx.EndInnerProtocol(protocol.StatusError, nil) + } + return nil +} + +func (ln *Payload) Offerable() bool { + return false +} diff --git a/internal/outpost/radius/eap/packet.go b/internal/outpost/radius/eap/packet.go index ebeb45bb2c..46e6b7bf4c 100644 --- a/internal/outpost/radius/eap/packet.go +++ b/internal/outpost/radius/eap/packet.go @@ -3,11 +3,12 @@ package eap import ( "encoding/binary" "errors" + "fmt" log "github.com/sirupsen/logrus" "goauthentik.io/internal/outpost/radius/eap/debug" "goauthentik.io/internal/outpost/radius/eap/protocol" - "goauthentik.io/internal/outpost/radius/eap/tls" + "layeh.com/radius" ) type Code uint8 @@ -26,22 +27,30 @@ type Packet struct { msgType protocol.Type rawPayload []byte Payload protocol.Payload + + stm StateManager + state string + endModifier func(p *radius.Packet) *radius.Packet } type PayloadWriter struct{} -func emptyPayload(t protocol.Type) protocol.Payload { - switch t { - case protocol.TypeIdentity: - return &IdentityPayload{} - case tls.TypeTLS: - return &tls.Payload{} +func emptyPayload(stm StateManager, t protocol.Type) (protocol.Payload, error) { + for _, cons := range stm.GetEAPSettings().Protocols { + if np := cons(); np.Type() == t { + return np, nil + } } - return nil + return nil, fmt.Errorf("unsupported EAP type %d", t) } -func Decode(raw []byte) (*Packet, error) { - packet := &Packet{} +func Decode(stm StateManager, raw []byte) (*Packet, error) { + packet := &Packet{ + stm: stm, + endModifier: func(p *radius.Packet) *radius.Packet { + return p + }, + } packet.code = Code(raw[0]) packet.id = raw[1] packet.length = binary.BigEndian.Uint16(raw[2:]) @@ -51,10 +60,14 @@ func Decode(raw []byte) (*Packet, error) { if len(raw) > 4 && (packet.code == CodeRequest || packet.code == CodeResponse) { packet.msgType = protocol.Type(raw[4]) } - packet.Payload = emptyPayload(packet.msgType) + p, err := emptyPayload(stm, packet.msgType) + if err != nil { + return nil, err + } + packet.Payload = p packet.rawPayload = raw[5:] - log.WithField("raw", debug.FormatBytes(raw)).Debug("EAP: decode raw") - err := packet.Payload.Decode(raw[5:]) + log.WithField("raw", debug.FormatBytes(raw)).WithField("payload", fmt.Sprintf("%T", packet.Payload)).Debug("EAP: decode raw") + err = packet.Payload.Decode(raw[5:]) if err != nil { return nil, err } diff --git a/internal/outpost/radius/eap/payload_identity.go b/internal/outpost/radius/eap/payload_identity.go deleted file mode 100644 index 2ba5279e2e..0000000000 --- a/internal/outpost/radius/eap/payload_identity.go +++ /dev/null @@ -1,14 +0,0 @@ -package eap - -type IdentityPayload struct { - Identity string -} - -func (ip *IdentityPayload) Decode(raw []byte) error { - ip.Identity = string(raw) - return nil -} - -func (ip *IdentityPayload) Encode() ([]byte, error) { - return []byte{}, nil -} diff --git a/internal/outpost/radius/eap/protocol/context.go b/internal/outpost/radius/eap/protocol/context.go index ca0b732b82..f08b18df4f 100644 --- a/internal/outpost/radius/eap/protocol/context.go +++ b/internal/outpost/radius/eap/protocol/context.go @@ -11,15 +11,17 @@ const ( StatusUnknown Status = iota StatusSuccess StatusError + StatusNextProtocol ) type Context interface { Packet() *radius.Request ProtocolSettings() interface{} - GetProtocolState(def func(Context) interface{}) interface{} + GetProtocolState() interface{} SetProtocolState(interface{}) + IsProtocolStart() bool EndInnerProtocol(Status, func(p *radius.Packet) *radius.Packet) Log() *log.Entry diff --git a/internal/outpost/radius/eap/protocol/packet.go b/internal/outpost/radius/eap/protocol/packet.go index e3e66459c6..8b698d768e 100644 --- a/internal/outpost/radius/eap/protocol/packet.go +++ b/internal/outpost/radius/eap/protocol/packet.go @@ -3,11 +3,9 @@ package protocol type Payload interface { Decode(raw []byte) error Encode() ([]byte, error) + Handle(ctx Context) Payload + Type() Type + Offerable() bool } type Type uint8 - -const ( - TypeIdentity Type = 1 - TypeMD5Challenge Type = 4 -) diff --git a/internal/outpost/radius/eap/state.go b/internal/outpost/radius/eap/state.go index bf0e12ec8c..ed8793ba32 100644 --- a/internal/outpost/radius/eap/state.go +++ b/internal/outpost/radius/eap/state.go @@ -1,13 +1,17 @@ package eap import ( + "errors" "slices" "goauthentik.io/internal/outpost/radius/eap/protocol" ) +type ProtocolConstructor func() protocol.Payload + type Settings struct { - ProtocolsToOffer []protocol.Type + Protocols []ProtocolConstructor + ProtocolPriority []protocol.Type ProtocolSettings map[protocol.Type]interface{} } @@ -18,13 +22,23 @@ type StateManager interface { } type State struct { - ChallengesToOffer []protocol.Type - TypeState map[protocol.Type]any + Protocols []ProtocolConstructor + ProtocolIndex int + ProtocolPriority []protocol.Type + TypeState map[protocol.Type]any +} + +func (st *State) GetNextProtocol() (protocol.Type, error) { + if st.ProtocolIndex >= len(st.ProtocolPriority) { + return protocol.Type(0), errors.New("no more protocols to offer") + } + return st.ProtocolPriority[st.ProtocolIndex], nil } func BlankState(settings Settings) *State { return &State{ - ChallengesToOffer: slices.Clone(settings.ProtocolsToOffer), - TypeState: map[protocol.Type]any{}, + Protocols: slices.Clone(settings.Protocols), + ProtocolPriority: slices.Clone(settings.ProtocolPriority), + TypeState: map[protocol.Type]any{}, } } diff --git a/internal/outpost/radius/eap/tls/payload.go b/internal/outpost/radius/eap/tls/payload.go index 49509969bd..a24c44dc5e 100644 --- a/internal/outpost/radius/eap/tls/payload.go +++ b/internal/outpost/radius/eap/tls/payload.go @@ -21,6 +21,10 @@ const staleConnectionTimeout = 10 const TypeTLS protocol.Type = 13 +func Protocol() protocol.Payload { + return &Payload{} +} + type Payload struct { Flags Flag Length uint32 @@ -29,6 +33,14 @@ type Payload struct { st *State } +func (p *Payload) Type() protocol.Type { + return TypeTLS +} + +func (p *Payload) Offerable() bool { + return true +} + func (p *Payload) Decode(raw []byte) error { p.Flags = Flag(raw[0]) raw = raw[1:] @@ -65,15 +77,16 @@ func (p *Payload) Encode() ([]byte, error) { } func (p *Payload) Handle(ctx protocol.Context) protocol.Payload { - p.st = ctx.GetProtocolState(NewState).(*State) - defer ctx.SetProtocolState(p.st) - if !p.st.HasStarted { - ctx.Log().Debug("TLS: handshake starting") - p.st.HasStarted = true + defer func() { + ctx.SetProtocolState(p.st) + }() + if ctx.IsProtocolStart() { + p.st = NewState(ctx).(*State) return &Payload{ Flags: FlagTLSStart, } } + p.st = ctx.GetProtocolState().(*State) if p.st.TLS == nil { p.tlsInit(ctx) diff --git a/internal/outpost/radius/eap/tls/state.go b/internal/outpost/radius/eap/tls/state.go index 628679cc52..ccb9cd9577 100644 --- a/internal/outpost/radius/eap/tls/state.go +++ b/internal/outpost/radius/eap/tls/state.go @@ -8,7 +8,6 @@ import ( ) type State struct { - HasStarted bool RemainingChunks [][]byte HandshakeDone bool FinalStatus protocol.Status diff --git a/internal/outpost/radius/handle_access_request.go b/internal/outpost/radius/handle_access_request.go index c9b2a0927c..9fe2cebc95 100644 --- a/internal/outpost/radius/handle_access_request.go +++ b/internal/outpost/radius/handle_access_request.go @@ -12,6 +12,8 @@ import ( log "github.com/sirupsen/logrus" "goauthentik.io/internal/outpost/flow" "goauthentik.io/internal/outpost/radius/eap" + "goauthentik.io/internal/outpost/radius/eap/identity" + "goauthentik.io/internal/outpost/radius/eap/legacy_nak" "goauthentik.io/internal/outpost/radius/eap/protocol" "goauthentik.io/internal/outpost/radius/eap/tls" "goauthentik.io/internal/outpost/radius/metrics" @@ -111,12 +113,12 @@ func (rs *RadiusServer) Handle_AccessRequest_PAP(w radius.ResponseWriter, r *Rad func (rs *RadiusServer) Handle_AccessRequest_EAP(w radius.ResponseWriter, r *RadiusRequest) { er := rfc2869.EAPMessage_Get(r.Packet) - ep, err := eap.Decode(er) + ep, err := eap.Decode(r.pi, er) if err != nil { rs.log.WithError(err).Warning("failed to parse EAP packet") return } - ep.Handle(r.pi, w, r.Request) + ep.HandleRadiusPacket(w, r.Request) } func (pi *ProviderInstance) GetEAPState(key string) *eap.State { @@ -128,22 +130,30 @@ func (pi *ProviderInstance) SetEAPState(key string, state *eap.State) { } func (pi *ProviderInstance) GetEAPSettings() eap.Settings { + protocols := []eap.ProtocolConstructor{ + identity.Protocol, + legacy_nak.Protocol, + } + certId := pi.certId if certId == "" { return eap.Settings{ - ProtocolsToOffer: []protocol.Type{}, + Protocols: protocols, } } cert := pi.s.cryptoStore.Get(certId) if cert == nil { return eap.Settings{ - ProtocolsToOffer: []protocol.Type{}, + Protocols: protocols, } } return eap.Settings{ - ProtocolsToOffer: []protocol.Type{tls.TypeTLS}, + Protocols: append(protocols, tls.Protocol), + ProtocolPriority: []protocol.Type{ + tls.TypeTLS, + }, ProtocolSettings: map[protocol.Type]interface{}{ tls.TypeTLS: tls.Settings{ Config: &ttls.Config{