From 91c87b7c3c9ad85d1cf8325d9e2936aed247618e Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Fri, 16 May 2025 15:16:26 +0200 Subject: [PATCH] ok this works kinda Signed-off-by: Jens Langhammer --- internal/outpost/radius/eap/context.go | 21 +++++++---- internal/outpost/radius/eap/handler.go | 35 +++++++++++++------ internal/outpost/radius/eap/packet.go | 35 ++++++++----------- .../outpost/radius/eap/protocol/context.go | 10 +++++- .../outpost/radius/eap/protocol/packet.go | 7 ++++ internal/outpost/radius/eap/state.go | 16 +++++---- internal/outpost/radius/eap/tls/payload.go | 11 +++--- .../outpost/radius/handle_access_request.go | 7 ++-- 8 files changed, 91 insertions(+), 51 deletions(-) diff --git a/internal/outpost/radius/eap/context.go b/internal/outpost/radius/eap/context.go index 21e86d8fa7..0b6b8a31d7 100644 --- a/internal/outpost/radius/eap/context.go +++ b/internal/outpost/radius/eap/context.go @@ -7,24 +7,31 @@ import ( ) type context struct { - state interface{} - log *log.Entry + state interface{} + log *log.Entry + settings interface{} + endStatus protocol.Status + endModifier func(p *radius.Packet) *radius.Packet } func (ctx context) ProtocolSettings() interface{} { - return nil + return ctx.settings } -func (ctx context) GetProtocolState(def func(protocol.Context) interface{}) interface{} { +func (ctx *context) GetProtocolState(def func(protocol.Context) interface{}) interface{} { + if ctx.state == nil { + ctx.state = def(ctx) + } return ctx.state } -func (ctx context) SetProtocolState(st interface{}) { +func (ctx *context) SetProtocolState(st interface{}) { ctx.state = st } -func (ctx context) EndInnerProtocol(func(p *radius.Packet) *radius.Packet) { - +func (ctx *context) EndInnerProtocol(st protocol.Status, mf func(p *radius.Packet) *radius.Packet) { + ctx.endStatus = st + ctx.endModifier = mf } func (ctx context) Log() *log.Entry { diff --git a/internal/outpost/radius/eap/handler.go b/internal/outpost/radius/eap/handler.go index 2e5523d2c3..1267806aea 100644 --- a/internal/outpost/radius/eap/handler.go +++ b/internal/outpost/radius/eap/handler.go @@ -29,9 +29,10 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac } nextChallengeToOffer := st.ChallengesToOffer[0] - ctx := context{ - state: st.TypeState[nextChallengeToOffer], - log: log.WithField("type", nextChallengeToOffer), + ctx := &context{ + state: st.TypeState[nextChallengeToOffer], + log: log.WithField("type", nextChallengeToOffer), + settings: stm.GetEAPSettings().ProtocolSettings[nextChallengeToOffer], } res := p.GetChallengeForType(ctx, nextChallengeToOffer) @@ -39,11 +40,24 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac stm.SetEAPState(rst, st) rres := r.Response(radius.CodeAccessChallenge) - if p, ok := res.Payload.(protocol.EmptyPayload); ok { - // TODO: This is a bit hacky here + switch ctx.endStatus { + case protocol.StatusSuccess: res.code = CodeSuccess res.id -= 1 - rres = p.ModifyPacket(rres) + rres = ctx.endModifier(rres) + st.ChallengesToOffer = st.ChallengesToOffer[1:] + if len(st.ChallengesToOffer) < 1 { + rres.Code = radius.CodeAccessAccept + } + case protocol.StatusError: + res.code = CodeFailure + res.id -= 1 + st.ChallengesToOffer = st.ChallengesToOffer[1:] + rres = ctx.endModifier(rres) + if len(st.ChallengesToOffer) < 1 { + rres.Code = radius.CodeAccessReject + } + case protocol.StatusUnknown: } rfc2865.State_SetString(rres, rst) eapEncoded, err := res.Encode() @@ -59,7 +73,7 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac } } -func (p *Packet) GetChallengeForType(ctx context, t Type) *Packet { +func (p *Packet) GetChallengeForType(ctx *context, t protocol.Type) *Packet { res := &Packet{ code: CodeRequest, id: p.id + 1, @@ -67,7 +81,7 @@ func (p *Packet) GetChallengeForType(ctx context, t Type) *Packet { } var payload any switch t { - case TypeTLS: + case tls.TypeTLS: // TODO: rewrite this if _, ok := p.Payload.(*tls.Payload); !ok { p.Payload = &tls.Payload{} @@ -76,8 +90,9 @@ func (p *Packet) GetChallengeForType(ctx context, t Type) *Packet { // this payload = p.Payload.(*tls.Payload).Handle(ctx) } - // st.TypeState[t] = tst - res.Payload = payload.(protocol.Payload) + if payload != nil { + res.Payload = payload.(protocol.Payload) + } return res } diff --git a/internal/outpost/radius/eap/packet.go b/internal/outpost/radius/eap/packet.go index 640e27c192..ebeb45bb2c 100644 --- a/internal/outpost/radius/eap/packet.go +++ b/internal/outpost/radius/eap/packet.go @@ -16,32 +16,25 @@ const ( CodeRequest Code = 1 CodeResponse Code = 2 CodeSuccess Code = 3 -) - -type Type uint8 - -const ( - TypeIdentity Type = 1 - TypeMD5Challenge Type = 4 - TypeTLS Type = 13 + CodeFailure Code = 4 ) type Packet struct { code Code id uint8 length uint16 - msgType Type + msgType protocol.Type rawPayload []byte Payload protocol.Payload } type PayloadWriter struct{} -func emptyPayload(t Type) protocol.Payload { +func emptyPayload(t protocol.Type) protocol.Payload { switch t { - case TypeIdentity: + case protocol.TypeIdentity: return &IdentityPayload{} - case TypeTLS: + case tls.TypeTLS: return &tls.Payload{} } return nil @@ -56,7 +49,7 @@ func Decode(raw []byte) (*Packet, error) { return nil, errors.New("mismatched packet length") } if len(raw) > 4 && (packet.code == CodeRequest || packet.code == CodeResponse) { - packet.msgType = Type(raw[4]) + packet.msgType = protocol.Type(raw[4]) } packet.Payload = emptyPayload(packet.msgType) packet.rawPayload = raw[5:] @@ -73,14 +66,16 @@ func (p *Packet) Encode() ([]byte, error) { buff[0] = uint8(p.code) buff[1] = uint8(p.id) - payloadBuffer, err := p.Payload.Encode() - if err != nil { - return buff, err + if p.Payload != nil { + payloadBuffer, err := p.Payload.Encode() + if err != nil { + return buff, err + } + if p.code == CodeRequest || p.code == CodeResponse { + buff = append(buff, uint8(p.msgType)) + } + buff = append(buff, payloadBuffer...) } - if p.code == CodeRequest || p.code == CodeResponse { - buff = append(buff, uint8(p.msgType)) - } - buff = append(buff, payloadBuffer...) binary.BigEndian.PutUint16(buff[2:], uint16(len(buff))) return buff, nil } diff --git a/internal/outpost/radius/eap/protocol/context.go b/internal/outpost/radius/eap/protocol/context.go index 4b111d08fe..425b1faa01 100644 --- a/internal/outpost/radius/eap/protocol/context.go +++ b/internal/outpost/radius/eap/protocol/context.go @@ -5,6 +5,14 @@ import ( "layeh.com/radius" ) +type Status int + +const ( + StatusUnknown Status = iota + StatusSuccess + StatusError +) + type Context interface { // GlobalState() @@ -12,7 +20,7 @@ type Context interface { GetProtocolState(def func(Context) interface{}) interface{} SetProtocolState(interface{}) - EndInnerProtocol(func(p *radius.Packet) *radius.Packet) + EndInnerProtocol(Status, func(p *radius.Packet) *radius.Packet) Log() *log.Entry } diff --git a/internal/outpost/radius/eap/protocol/packet.go b/internal/outpost/radius/eap/protocol/packet.go index 9ff5e955a2..e3e66459c6 100644 --- a/internal/outpost/radius/eap/protocol/packet.go +++ b/internal/outpost/radius/eap/protocol/packet.go @@ -4,3 +4,10 @@ type Payload interface { Decode(raw []byte) error Encode() ([]byte, error) } + +type Type uint8 + +const ( + TypeIdentity Type = 1 + TypeMD5Challenge Type = 4 +) diff --git a/internal/outpost/radius/eap/state.go b/internal/outpost/radius/eap/state.go index cf0d59fa1d..bf0e12ec8c 100644 --- a/internal/outpost/radius/eap/state.go +++ b/internal/outpost/radius/eap/state.go @@ -1,10 +1,14 @@ package eap -import "slices" +import ( + "slices" + + "goauthentik.io/internal/outpost/radius/eap/protocol" +) type Settings struct { - ProtocolsToOffer []Type - ProtocolSettings map[Type]interface{} + ProtocolsToOffer []protocol.Type + ProtocolSettings map[protocol.Type]interface{} } type StateManager interface { @@ -14,13 +18,13 @@ type StateManager interface { } type State struct { - ChallengesToOffer []Type - TypeState map[Type]any + ChallengesToOffer []protocol.Type + TypeState map[protocol.Type]any } func BlankState(settings Settings) *State { return &State{ ChallengesToOffer: slices.Clone(settings.ProtocolsToOffer), - TypeState: map[Type]any{}, + TypeState: map[protocol.Type]any{}, } } diff --git a/internal/outpost/radius/eap/tls/payload.go b/internal/outpost/radius/eap/tls/payload.go index 2919f4de98..91878636e6 100644 --- a/internal/outpost/radius/eap/tls/payload.go +++ b/internal/outpost/radius/eap/tls/payload.go @@ -18,6 +18,8 @@ import ( const maxChunkSize = 1000 const staleConnectionTimeout = 10 +const TypeTLS protocol.Type = 13 + type Payload struct { Flags Flag Length uint32 @@ -102,8 +104,7 @@ func (p *Payload) Handle(ctx protocol.Context) protocol.Payload { } if p.st.Conn.writer.Len() == 0 && p.st.HandshakeDone { defer p.st.ContextCancel() - ctx.EndInnerProtocol(func(r *radius.Packet) *radius.Packet { - r.Code = radius.CodeAccessAccept + ctx.EndInnerProtocol(protocol.StatusSuccess, func(r *radius.Packet) *radius.Packet { microsoft.MSMPPERecvKey_Set(r, p.st.MPPEKey[:32]) microsoft.MSMPPESendKey_Set(r, p.st.MPPEKey[64:64+32]) return r @@ -128,7 +129,9 @@ func (p *Payload) tlsInit(ctx protocol.Context) { err := p.st.TLS.HandshakeContext(p.st.Context) if err != nil { log.WithError(err).Debug("TLS: Handshake error") - // TODO: Send a NAK to the client + ctx.EndInnerProtocol(protocol.StatusError, func(p *radius.Packet) *radius.Packet { + return p + }) return } log.Debug("TLS: handshake done") @@ -150,7 +153,7 @@ func (p *Payload) tlsHandshakeFinished() { case tls.VersionTLS13: log.Debugf("TLS: Version %d (1.3)", cs.Version) label = "EXPORTER_EAP_TLS_Key_Material" - context = []byte{13} + context = []byte{byte(TypeTLS)} } ksm, err := cs.ExportKeyingMaterial(label, context, 64+64) log.Debugf("TLS: ksm % x %v", ksm, err) diff --git a/internal/outpost/radius/handle_access_request.go b/internal/outpost/radius/handle_access_request.go index 67631b56fb..8c515ad4e0 100644 --- a/internal/outpost/radius/handle_access_request.go +++ b/internal/outpost/radius/handle_access_request.go @@ -8,6 +8,7 @@ import ( log "github.com/sirupsen/logrus" "goauthentik.io/internal/outpost/flow" "goauthentik.io/internal/outpost/radius/eap" + "goauthentik.io/internal/outpost/radius/eap/protocol" "goauthentik.io/internal/outpost/radius/eap/tls" "goauthentik.io/internal/outpost/radius/metrics" "layeh.com/radius" @@ -134,9 +135,9 @@ func (pi *ProviderInstance) GetEAPSettings() eap.Settings { } return eap.Settings{ - ProtocolsToOffer: []eap.Type{eap.TypeTLS}, - ProtocolSettings: map[eap.Type]interface{}{ - eap.TypeTLS: tls.Settings{ + ProtocolsToOffer: []protocol.Type{tls.TypeTLS}, + ProtocolSettings: map[protocol.Type]interface{}{ + tls.TypeTLS: tls.Settings{ Config: &ttls.Config{ Certificates: []ttls.Certificate{cert}, ClientAuth: ttls.RequireAnyClientCert,