diff --git a/internal/outpost/radius/eap/context.go b/internal/outpost/radius/eap/context.go index 92fd9a58f6..6c7e28cf22 100644 --- a/internal/outpost/radius/eap/context.go +++ b/internal/outpost/radius/eap/context.go @@ -16,8 +16,7 @@ type context struct { settings interface{} parent *context endStatus protocol.Status - endModifier func(p *radius.Packet) *radius.Packet - handleInner func(protocol.Payload, protocol.StateManager) (protocol.Payload, error) + handleInner func(protocol.Payload, protocol.StateManager, protocol.Context) (protocol.Payload, error) } func (ctx *context) RootPayload() protocol.Payload { return ctx.rootPayload } @@ -28,13 +27,10 @@ func (ctx *context) SetProtocolState(p protocol.Type, st any) { ctx.typeState[p] func (ctx *context) IsProtocolStart(p protocol.Type) bool { return ctx.typeState[p] == nil } func (ctx *context) Log() *log.Entry { return ctx.log } func (ctx *context) HandleInnerEAP(p protocol.Payload, st protocol.StateManager) (protocol.Payload, error) { - return ctx.handleInner(p, st) + return ctx.handleInner(p, st, ctx) } -func (ctx *context) Inner(p protocol.Payload, t protocol.Type, pmf func(p *radius.Packet) *radius.Packet) protocol.Context { - if ctx.endModifier == nil { - ctx.endModifier = pmf - } - return &context{ +func (ctx *context) Inner(p protocol.Payload, t protocol.Type) protocol.Context { + nctx := &context{ req: ctx.req, rootPayload: ctx.rootPayload, typeState: ctx.typeState, @@ -43,29 +39,17 @@ func (ctx *context) Inner(p protocol.Payload, t protocol.Type, pmf func(p *radiu parent: ctx, handleInner: ctx.handleInner, } + nctx.log.Debug("Creating inner context") + return nctx } -func (ctx *context) EndInnerProtocol(st protocol.Status, mf func(p *radius.Packet) *radius.Packet) { +func (ctx *context) EndInnerProtocol(st protocol.Status) { ctx.log.Info("Ending protocol") if ctx.parent != nil { - ctx.parent.EndInnerProtocol(st, mf) + ctx.parent.EndInnerProtocol(st) return } if ctx.endStatus != protocol.StatusUnknown { return } ctx.endStatus = st - if mf != nil { - ctx.endModifier = mf - } -} - -func (ctx *context) callEndModifier(p *radius.Packet) *radius.Packet { - if ctx.parent != nil { - p = ctx.parent.callEndModifier(p) - } - if ctx.endModifier != nil { - ctx.log.Debug("Running end modifier") - p = ctx.endModifier(p) - } - return p } diff --git a/internal/outpost/radius/eap/handler.go b/internal/outpost/radius/eap/handler.go index 8219a66896..ba841647b6 100644 --- a/internal/outpost/radius/eap/handler.go +++ b/internal/outpost/radius/eap/handler.go @@ -39,7 +39,6 @@ func (p *Packet) HandleRadiusPacket(w radius.ResponseWriter, r *radius.Request) rres := r.Response(radius.CodeAccessReject) if err == nil { - rres = p.endModifier(rres) switch rp.eap.Code { case protocol.CodeRequest: rres.Code = radius.CodeAccessChallenge @@ -52,6 +51,13 @@ func (p *Packet) HandleRadiusPacket(w radius.ResponseWriter, r *radius.Request) rres.Code = radius.CodeAccessReject log.WithError(err).Debug("Rejecting request") } + for _, mod := range p.responseModifiers { + err := mod.ModifyRADIUSResponse(rres, r.Packet) + if err != nil { + log.WithError(err).Warning("Root-EAP: failed to modify response packet") + break + } + } rfc2865.State_SetString(rres, p.state) eapEncoded, err := rp.Encode() @@ -106,7 +112,8 @@ func (p *Packet) handleEAP(pp protocol.Payload, stm protocol.StateManager, paren var ctx *context if parentContext != nil { - ctx = parentContext.Inner(np, t, nil).(*context) + ctx = parentContext.Inner(np, t).(*context) + ctx.settings = stm.GetEAPSettings().ProtocolSettings[np.Type()] } else { ctx = &context{ req: p.r, @@ -115,8 +122,9 @@ func (p *Packet) handleEAP(pp protocol.Payload, stm protocol.StateManager, paren log: log.WithField("type", fmt.Sprintf("%T", np)).WithField("code", t), settings: stm.GetEAPSettings().ProtocolSettings[t], } - ctx.handleInner = func(pp protocol.Payload, sm protocol.StateManager) (protocol.Payload, error) { - return p.handleEAP(pp, sm, ctx.Inner(pp, pp.Type(), nil).(*context)) + ctx.handleInner = func(pp protocol.Payload, sm protocol.StateManager, ctx protocol.Context) (protocol.Payload, error) { + // cctx := ctx.Inner(np, np.Type(), nil).(*context) + return p.handleEAP(pp, sm, ctx.(*context)) } } if !np.Offerable() { @@ -141,8 +149,9 @@ func (p *Packet) handleEAP(pp protocol.Payload, stm protocol.StateManager, paren stm.SetEAPState(p.state, st) - if ctx.endModifier != nil { - p.endModifier = ctx.callEndModifier + if rm, ok := np.(protocol.ResponseModifier); ok { + ctx.log.Debug("Root-EAP: Registered response modifier") + p.responseModifiers = append(p.responseModifiers, rm) } switch ctx.endStatus { diff --git a/internal/outpost/radius/eap/packet.go b/internal/outpost/radius/eap/packet.go index b9ca23c5c1..80f2d343ed 100644 --- a/internal/outpost/radius/eap/packet.go +++ b/internal/outpost/radius/eap/packet.go @@ -7,11 +7,11 @@ import ( ) type Packet struct { - r *radius.Request - eap *eap.Payload - stm protocol.StateManager - state string - endModifier func(p *radius.Packet) *radius.Packet + r *radius.Request + eap *eap.Payload + stm protocol.StateManager + state string + responseModifiers []protocol.ResponseModifier } func Decode(stm protocol.StateManager, raw []byte) (*Packet, error) { @@ -19,10 +19,8 @@ func Decode(stm protocol.StateManager, raw []byte) (*Packet, error) { eap: &eap.Payload{ Settings: stm.GetEAPSettings(), }, - stm: stm, - endModifier: func(p *radius.Packet) *radius.Packet { - return p - }, + stm: stm, + responseModifiers: []protocol.ResponseModifier{}, } err := packet.eap.Decode(raw) if err != nil { diff --git a/internal/outpost/radius/eap/protocol/context.go b/internal/outpost/radius/eap/protocol/context.go index 63a0d09480..1c6e95e68c 100644 --- a/internal/outpost/radius/eap/protocol/context.go +++ b/internal/outpost/radius/eap/protocol/context.go @@ -25,8 +25,8 @@ type Context interface { IsProtocolStart(p Type) bool HandleInnerEAP(Payload, StateManager) (Payload, error) - Inner(Payload, Type, func(p *radius.Packet) *radius.Packet) Context - EndInnerProtocol(Status, func(p *radius.Packet) *radius.Packet) + Inner(Payload, Type) Context + EndInnerProtocol(Status) Log() *log.Entry } diff --git a/internal/outpost/radius/eap/protocol/identity/payload.go b/internal/outpost/radius/eap/protocol/identity/payload.go index bb9dc850fd..f78dbb8634 100644 --- a/internal/outpost/radius/eap/protocol/identity/payload.go +++ b/internal/outpost/radius/eap/protocol/identity/payload.go @@ -31,7 +31,7 @@ func (p *Payload) Encode() ([]byte, error) { func (p *Payload) Handle(ctx protocol.Context) protocol.Payload { if ctx.IsProtocolStart(TypeIdentity) { - ctx.EndInnerProtocol(protocol.StatusNextProtocol, nil) + ctx.EndInnerProtocol(protocol.StatusNextProtocol) } return nil } diff --git a/internal/outpost/radius/eap/protocol/legacy_nak/payload.go b/internal/outpost/radius/eap/protocol/legacy_nak/payload.go index 468e241a01..fcd742f801 100644 --- a/internal/outpost/radius/eap/protocol/legacy_nak/payload.go +++ b/internal/outpost/radius/eap/protocol/legacy_nak/payload.go @@ -31,7 +31,7 @@ func (p *Payload) Encode() ([]byte, error) { func (p *Payload) Handle(ctx protocol.Context) protocol.Payload { if ctx.IsProtocolStart(TypeLegacyNAK) { - ctx.EndInnerProtocol(protocol.StatusError, nil) + ctx.EndInnerProtocol(protocol.StatusError) } return nil } diff --git a/internal/outpost/radius/eap/protocol/mschapv2/payload.go b/internal/outpost/radius/eap/protocol/mschapv2/payload.go index a27d784fdf..aa972e5cd7 100644 --- a/internal/outpost/radius/eap/protocol/mschapv2/payload.go +++ b/internal/outpost/radius/eap/protocol/mschapv2/payload.go @@ -160,20 +160,29 @@ func (p *Payload) Handle(ctx protocol.Context) protocol.Payload { p.st.IsProtocolEnded = true return ep } else if p.st.IsProtocolEnded { - ctx.EndInnerProtocol(protocol.StatusSuccess, func(r *radius.Packet) *radius.Packet { - if len(microsoft.MSMPPERecvKey_Get(r, ctx.Packet().Packet)) < 1 { - microsoft.MSMPPERecvKey_Set(r, p.st.AuthResponse.RecvKey) - } - if len(microsoft.MSMPPESendKey_Get(r, ctx.Packet().Packet)) < 1 { - microsoft.MSMPPESendKey_Set(r, p.st.AuthResponse.SendKey) - } - return r - }) + ctx.EndInnerProtocol(protocol.StatusSuccess) return &Payload{} } return response } +func (p *Payload) ModifyRADIUSResponse(r *radius.Packet, q *radius.Packet) error { + if p.st == nil || p.st.AuthResponse == nil { + return nil + } + if r.Code != radius.CodeAccessAccept { + return nil + } + log.Debug("MSCHAPv2: Radius modifier") + if len(microsoft.MSMPPERecvKey_Get(r, q)) < 1 { + microsoft.MSMPPERecvKey_Set(r, p.st.AuthResponse.RecvKey) + } + if len(microsoft.MSMPPESendKey_Get(r, q)) < 1 { + microsoft.MSMPPESendKey_Set(r, p.st.AuthResponse.SendKey) + } + return nil +} + func (p *Payload) Offerable() bool { return true } diff --git a/internal/outpost/radius/eap/protocol/packet.go b/internal/outpost/radius/eap/protocol/packet.go index a1606864ca..ce5279a8a7 100644 --- a/internal/outpost/radius/eap/protocol/packet.go +++ b/internal/outpost/radius/eap/protocol/packet.go @@ -1,5 +1,18 @@ package protocol +import "layeh.com/radius" + +type Type uint8 + +type Code uint8 + +const ( + CodeRequest Code = 1 + CodeResponse Code = 2 + CodeSuccess Code = 3 + CodeFailure Code = 4 +) + type Payload interface { Decode(raw []byte) error Encode() ([]byte, error) @@ -13,13 +26,6 @@ type Inner interface { HasInner() Payload } -type Type uint8 - -type Code uint8 - -const ( - CodeRequest Code = 1 - CodeResponse Code = 2 - CodeSuccess Code = 3 - CodeFailure Code = 4 -) +type ResponseModifier interface { + ModifyRADIUSResponse(r *radius.Packet, q *radius.Packet) error +} diff --git a/internal/outpost/radius/eap/protocol/tls/inner.go b/internal/outpost/radius/eap/protocol/tls/inner.go index f5eee28a28..193bf63bb1 100644 --- a/internal/outpost/radius/eap/protocol/tls/inner.go +++ b/internal/outpost/radius/eap/protocol/tls/inner.go @@ -2,8 +2,6 @@ package tls import ( "goauthentik.io/internal/outpost/radius/eap/protocol" - "layeh.com/radius" - "layeh.com/radius/vendors/microsoft" ) func (p *Payload) innerHandler(ctx protocol.Context) { @@ -13,7 +11,7 @@ func (p *Payload) innerHandler(ctx protocol.Context) { n, err := p.st.TLS.Read(d) if err != nil { ctx.Log().WithError(err).Warning("TLS: Failed to read from TLS connection") - ctx.EndInnerProtocol(protocol.StatusError, nil) + ctx.EndInnerProtocol(protocol.StatusError) return } // Truncate data to the size we read @@ -22,25 +20,20 @@ func (p *Payload) innerHandler(ctx protocol.Context) { err := p.Inner.Decode(d) if err != nil { ctx.Log().WithError(err).Warning("TLS: failed to decode inner protocol") - ctx.EndInnerProtocol(protocol.StatusError, nil) + ctx.EndInnerProtocol(protocol.StatusError) return } - pl := p.Inner.Handle(ctx.Inner(p.Inner, p.Inner.Type(), func(r *radius.Packet) *radius.Packet { - ctx.Log().Debug("TLS: Adding MPPE Keys") - microsoft.MSMPPERecvKey_Set(r, p.st.MPPEKey[:32]) - microsoft.MSMPPESendKey_Set(r, p.st.MPPEKey[64:64+32]) - return r - })) + pl := p.Inner.Handle(ctx.Inner(p.Inner, p.Inner.Type())) enc, err := pl.Encode() if err != nil { ctx.Log().WithError(err).Warning("TLS: failed to encode inner protocol") - ctx.EndInnerProtocol(protocol.StatusError, nil) + ctx.EndInnerProtocol(protocol.StatusError) return } _, err = p.st.TLS.Write(enc) if err != nil { ctx.Log().WithError(err).Warning("TLS: failed to write to TLS") - ctx.EndInnerProtocol(protocol.StatusError, nil) + ctx.EndInnerProtocol(protocol.StatusError) return } } diff --git a/internal/outpost/radius/eap/protocol/tls/payload.go b/internal/outpost/radius/eap/protocol/tls/payload.go index fb6d95a2c8..2101058084 100644 --- a/internal/outpost/radius/eap/protocol/tls/payload.go +++ b/internal/outpost/radius/eap/protocol/tls/payload.go @@ -144,17 +144,32 @@ func (p *Payload) Handle(ctx protocol.Context) protocol.Payload { retry.MaxDelay(100*time.Millisecond), retry.Attempts(0), ) - ctx.EndInnerProtocol(pst, func(r *radius.Packet) *radius.Packet { - ctx.Log().Debug("TLS: Adding MPPE Keys") - microsoft.MSMPPERecvKey_Set(r, p.st.MPPEKey[:32]) - microsoft.MSMPPESendKey_Set(r, p.st.MPPEKey[64:64+32]) - return r - }) + ctx.EndInnerProtocol(pst) return nil } return p.startChunkedTransfer(p.st.Conn.OutboundData()) } +func (p *Payload) ModifyRADIUSResponse(r *radius.Packet, q *radius.Packet) error { + if r.Code != radius.CodeAccessAccept { + return nil + } + if p.st == nil || !p.st.HandshakeDone { + return nil + } + log.Debug("TLS: Adding MPPE Keys") + // TLS overrides other protocols' MPPE keys + if len(microsoft.MSMPPERecvKey_Get(r, q)) > 0 { + microsoft.MSMPPERecvKey_Del(r) + } + if len(microsoft.MSMPPESendKey_Get(r, q)) > 0 { + microsoft.MSMPPESendKey_Del(r) + } + microsoft.MSMPPERecvKey_Set(r, p.st.MPPEKey[:32]) + microsoft.MSMPPESendKey_Set(r, p.st.MPPEKey[64:64+32]) + return nil +} + func (p *Payload) tlsInit(ctx protocol.Context) { ctx.Log().Debug("TLS: no TLS connection in state yet, starting connection") p.st.Context, p.st.ContextCancel = context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second) @@ -181,9 +196,7 @@ func (p *Payload) tlsInit(ctx protocol.Context) { if err != nil { ctx.Log().WithError(err).Debug("TLS: Handshake error") p.st.FinalStatus = protocol.StatusError - ctx.EndInnerProtocol(protocol.StatusError, func(p *radius.Packet) *radius.Packet { - return p - }) + ctx.EndInnerProtocol(protocol.StatusError) return } ctx.Log().Debug("TLS: handshake done") diff --git a/internal/outpost/radius/handle_access_request.go b/internal/outpost/radius/handle_access_request.go index 72e43ea267..2e1cc7ec43 100644 --- a/internal/outpost/radius/handle_access_request.go +++ b/internal/outpost/radius/handle_access_request.go @@ -153,10 +153,7 @@ func (pi *ProviderInstance) GetEAPSettings() protocol.Settings { return protocol.Settings{ Protocols: append(protocols, tls.Protocol, peap.Protocol), - ProtocolPriority: []protocol.Type{ - tls.TypeTLS, - peap.TypePEAP, - }, + ProtocolPriority: []protocol.Type{tls.TypeTLS, peap.TypePEAP}, ProtocolSettings: map[protocol.Type]interface{}{ tls.TypeTLS: tls.Settings{ Config: &ttls.Config{ @@ -197,6 +194,13 @@ func (pi *ProviderInstance) GetEAPSettings() protocol.Settings { InnerProtocols: protocol.Settings{ Protocols: append(protocols, mschapv2.Protocol), ProtocolPriority: []protocol.Type{mschapv2.TypeMSCHAPv2}, + ProtocolSettings: map[protocol.Type]interface{}{ + mschapv2.TypeMSCHAPv2: mschapv2.Settings{ + AuthenticateRequest: mschapv2.DebugStaticCredentials( + []byte("foo"), []byte("bar"), + ), + }, + }, }, }, },