diff --git a/internal/outpost/radius/eap/context.go b/internal/outpost/radius/eap/context.go index fb833523f1..92fd9a58f6 100644 --- a/internal/outpost/radius/eap/context.go +++ b/internal/outpost/radius/eap/context.go @@ -30,7 +30,10 @@ func (ctx *context) Log() *log.Entry { return ctx.log } func (ctx *context) HandleInnerEAP(p protocol.Payload, st protocol.StateManager) (protocol.Payload, error) { return ctx.handleInner(p, st) } -func (ctx *context) Inner(p protocol.Payload, t protocol.Type) protocol.Context { +func (ctx *context) Inner(p protocol.Payload, t protocol.Type, pmf func(p *radius.Packet) *radius.Packet) protocol.Context { + if ctx.endModifier == nil { + ctx.endModifier = pmf + } return &context{ req: ctx.req, rootPayload: ctx.rootPayload, @@ -51,12 +54,9 @@ func (ctx *context) EndInnerProtocol(st protocol.Status, mf func(p *radius.Packe return } ctx.endStatus = st - if mf == nil { - mf = func(p *radius.Packet) *radius.Packet { - return p - } + if mf != nil { + ctx.endModifier = mf } - ctx.endModifier = mf } func (ctx *context) callEndModifier(p *radius.Packet) *radius.Packet { @@ -64,6 +64,7 @@ func (ctx *context) callEndModifier(p *radius.Packet) *radius.Packet { p = ctx.parent.callEndModifier(p) } if ctx.endModifier != nil { + ctx.log.Debug("Running end modifier") p = ctx.endModifier(p) } return p diff --git a/internal/outpost/radius/eap/handler.go b/internal/outpost/radius/eap/handler.go index a395a3a43d..8219a66896 100644 --- a/internal/outpost/radius/eap/handler.go +++ b/internal/outpost/radius/eap/handler.go @@ -106,7 +106,7 @@ func (p *Packet) handleEAP(pp protocol.Payload, stm protocol.StateManager, paren var ctx *context if parentContext != nil { - ctx = parentContext.Inner(np, t).(*context) + ctx = parentContext.Inner(np, t, nil).(*context) } else { ctx = &context{ req: p.r, @@ -116,7 +116,7 @@ func (p *Packet) handleEAP(pp protocol.Payload, stm protocol.StateManager, paren settings: stm.GetEAPSettings().ProtocolSettings[t], } ctx.handleInner = func(pp protocol.Payload, sm protocol.StateManager) (protocol.Payload, error) { - return p.handleEAP(pp, sm, ctx.Inner(pp, pp.Type()).(*context)) + return p.handleEAP(pp, sm, ctx.Inner(pp, pp.Type(), nil).(*context)) } } if !np.Offerable() { diff --git a/internal/outpost/radius/eap/protocol/context.go b/internal/outpost/radius/eap/protocol/context.go index 2a086d3008..63a0d09480 100644 --- a/internal/outpost/radius/eap/protocol/context.go +++ b/internal/outpost/radius/eap/protocol/context.go @@ -25,7 +25,7 @@ type Context interface { IsProtocolStart(p Type) bool HandleInnerEAP(Payload, StateManager) (Payload, error) - Inner(Payload, Type) Context + Inner(Payload, Type, func(p *radius.Packet) *radius.Packet) Context EndInnerProtocol(Status, func(p *radius.Packet) *radius.Packet) Log() *log.Entry diff --git a/internal/outpost/radius/eap/protocol/tls/inner.go b/internal/outpost/radius/eap/protocol/tls/inner.go index 4aff0be552..f5eee28a28 100644 --- a/internal/outpost/radius/eap/protocol/tls/inner.go +++ b/internal/outpost/radius/eap/protocol/tls/inner.go @@ -2,6 +2,8 @@ package tls import ( "goauthentik.io/internal/outpost/radius/eap/protocol" + "layeh.com/radius" + "layeh.com/radius/vendors/microsoft" ) func (p *Payload) innerHandler(ctx protocol.Context) { @@ -23,7 +25,12 @@ func (p *Payload) innerHandler(ctx protocol.Context) { ctx.EndInnerProtocol(protocol.StatusError, nil) return } - pl := p.Inner.Handle(ctx.Inner(p.Inner, p.Inner.Type())) + pl := p.Inner.Handle(ctx.Inner(p.Inner, p.Inner.Type(), func(r *radius.Packet) *radius.Packet { + ctx.Log().Debug("TLS: Adding MPPE Keys") + microsoft.MSMPPERecvKey_Set(r, p.st.MPPEKey[:32]) + microsoft.MSMPPESendKey_Set(r, p.st.MPPEKey[64:64+32]) + return r + })) enc, err := pl.Encode() if err != nil { ctx.Log().WithError(err).Warning("TLS: failed to encode inner protocol")