start reworking response modification

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer
2025-05-24 19:49:41 +02:00
parent 9045f5ba73
commit f5eb827d14
11 changed files with 103 additions and 87 deletions

View File

@ -16,8 +16,7 @@ type context struct {
settings interface{} settings interface{}
parent *context parent *context
endStatus protocol.Status endStatus protocol.Status
endModifier func(p *radius.Packet) *radius.Packet handleInner func(protocol.Payload, protocol.StateManager, protocol.Context) (protocol.Payload, error)
handleInner func(protocol.Payload, protocol.StateManager) (protocol.Payload, error)
} }
func (ctx *context) RootPayload() protocol.Payload { return ctx.rootPayload } func (ctx *context) RootPayload() protocol.Payload { return ctx.rootPayload }
@ -28,13 +27,10 @@ func (ctx *context) SetProtocolState(p protocol.Type, st any) { ctx.typeState[p]
func (ctx *context) IsProtocolStart(p protocol.Type) bool { return ctx.typeState[p] == nil } func (ctx *context) IsProtocolStart(p protocol.Type) bool { return ctx.typeState[p] == nil }
func (ctx *context) Log() *log.Entry { return ctx.log } func (ctx *context) Log() *log.Entry { return ctx.log }
func (ctx *context) HandleInnerEAP(p protocol.Payload, st protocol.StateManager) (protocol.Payload, error) { func (ctx *context) HandleInnerEAP(p protocol.Payload, st protocol.StateManager) (protocol.Payload, error) {
return ctx.handleInner(p, st) return ctx.handleInner(p, st, ctx)
} }
func (ctx *context) Inner(p protocol.Payload, t protocol.Type, pmf func(p *radius.Packet) *radius.Packet) protocol.Context { func (ctx *context) Inner(p protocol.Payload, t protocol.Type) protocol.Context {
if ctx.endModifier == nil { nctx := &context{
ctx.endModifier = pmf
}
return &context{
req: ctx.req, req: ctx.req,
rootPayload: ctx.rootPayload, rootPayload: ctx.rootPayload,
typeState: ctx.typeState, typeState: ctx.typeState,
@ -43,29 +39,17 @@ func (ctx *context) Inner(p protocol.Payload, t protocol.Type, pmf func(p *radiu
parent: ctx, parent: ctx,
handleInner: ctx.handleInner, handleInner: ctx.handleInner,
} }
nctx.log.Debug("Creating inner context")
return nctx
} }
func (ctx *context) EndInnerProtocol(st protocol.Status, mf func(p *radius.Packet) *radius.Packet) { func (ctx *context) EndInnerProtocol(st protocol.Status) {
ctx.log.Info("Ending protocol") ctx.log.Info("Ending protocol")
if ctx.parent != nil { if ctx.parent != nil {
ctx.parent.EndInnerProtocol(st, mf) ctx.parent.EndInnerProtocol(st)
return return
} }
if ctx.endStatus != protocol.StatusUnknown { if ctx.endStatus != protocol.StatusUnknown {
return return
} }
ctx.endStatus = st ctx.endStatus = st
if mf != nil {
ctx.endModifier = mf
}
}
func (ctx *context) callEndModifier(p *radius.Packet) *radius.Packet {
if ctx.parent != nil {
p = ctx.parent.callEndModifier(p)
}
if ctx.endModifier != nil {
ctx.log.Debug("Running end modifier")
p = ctx.endModifier(p)
}
return p
} }

View File

@ -39,7 +39,6 @@ func (p *Packet) HandleRadiusPacket(w radius.ResponseWriter, r *radius.Request)
rres := r.Response(radius.CodeAccessReject) rres := r.Response(radius.CodeAccessReject)
if err == nil { if err == nil {
rres = p.endModifier(rres)
switch rp.eap.Code { switch rp.eap.Code {
case protocol.CodeRequest: case protocol.CodeRequest:
rres.Code = radius.CodeAccessChallenge rres.Code = radius.CodeAccessChallenge
@ -52,6 +51,13 @@ func (p *Packet) HandleRadiusPacket(w radius.ResponseWriter, r *radius.Request)
rres.Code = radius.CodeAccessReject rres.Code = radius.CodeAccessReject
log.WithError(err).Debug("Rejecting request") log.WithError(err).Debug("Rejecting request")
} }
for _, mod := range p.responseModifiers {
err := mod.ModifyRADIUSResponse(rres, r.Packet)
if err != nil {
log.WithError(err).Warning("Root-EAP: failed to modify response packet")
break
}
}
rfc2865.State_SetString(rres, p.state) rfc2865.State_SetString(rres, p.state)
eapEncoded, err := rp.Encode() eapEncoded, err := rp.Encode()
@ -106,7 +112,8 @@ func (p *Packet) handleEAP(pp protocol.Payload, stm protocol.StateManager, paren
var ctx *context var ctx *context
if parentContext != nil { if parentContext != nil {
ctx = parentContext.Inner(np, t, nil).(*context) ctx = parentContext.Inner(np, t).(*context)
ctx.settings = stm.GetEAPSettings().ProtocolSettings[np.Type()]
} else { } else {
ctx = &context{ ctx = &context{
req: p.r, req: p.r,
@ -115,8 +122,9 @@ func (p *Packet) handleEAP(pp protocol.Payload, stm protocol.StateManager, paren
log: log.WithField("type", fmt.Sprintf("%T", np)).WithField("code", t), log: log.WithField("type", fmt.Sprintf("%T", np)).WithField("code", t),
settings: stm.GetEAPSettings().ProtocolSettings[t], settings: stm.GetEAPSettings().ProtocolSettings[t],
} }
ctx.handleInner = func(pp protocol.Payload, sm protocol.StateManager) (protocol.Payload, error) { ctx.handleInner = func(pp protocol.Payload, sm protocol.StateManager, ctx protocol.Context) (protocol.Payload, error) {
return p.handleEAP(pp, sm, ctx.Inner(pp, pp.Type(), nil).(*context)) // cctx := ctx.Inner(np, np.Type(), nil).(*context)
return p.handleEAP(pp, sm, ctx.(*context))
} }
} }
if !np.Offerable() { if !np.Offerable() {
@ -141,8 +149,9 @@ func (p *Packet) handleEAP(pp protocol.Payload, stm protocol.StateManager, paren
stm.SetEAPState(p.state, st) stm.SetEAPState(p.state, st)
if ctx.endModifier != nil { if rm, ok := np.(protocol.ResponseModifier); ok {
p.endModifier = ctx.callEndModifier ctx.log.Debug("Root-EAP: Registered response modifier")
p.responseModifiers = append(p.responseModifiers, rm)
} }
switch ctx.endStatus { switch ctx.endStatus {

View File

@ -7,11 +7,11 @@ import (
) )
type Packet struct { type Packet struct {
r *radius.Request r *radius.Request
eap *eap.Payload eap *eap.Payload
stm protocol.StateManager stm protocol.StateManager
state string state string
endModifier func(p *radius.Packet) *radius.Packet responseModifiers []protocol.ResponseModifier
} }
func Decode(stm protocol.StateManager, raw []byte) (*Packet, error) { func Decode(stm protocol.StateManager, raw []byte) (*Packet, error) {
@ -19,10 +19,8 @@ func Decode(stm protocol.StateManager, raw []byte) (*Packet, error) {
eap: &eap.Payload{ eap: &eap.Payload{
Settings: stm.GetEAPSettings(), Settings: stm.GetEAPSettings(),
}, },
stm: stm, stm: stm,
endModifier: func(p *radius.Packet) *radius.Packet { responseModifiers: []protocol.ResponseModifier{},
return p
},
} }
err := packet.eap.Decode(raw) err := packet.eap.Decode(raw)
if err != nil { if err != nil {

View File

@ -25,8 +25,8 @@ type Context interface {
IsProtocolStart(p Type) bool IsProtocolStart(p Type) bool
HandleInnerEAP(Payload, StateManager) (Payload, error) HandleInnerEAP(Payload, StateManager) (Payload, error)
Inner(Payload, Type, func(p *radius.Packet) *radius.Packet) Context Inner(Payload, Type) Context
EndInnerProtocol(Status, func(p *radius.Packet) *radius.Packet) EndInnerProtocol(Status)
Log() *log.Entry Log() *log.Entry
} }

View File

@ -31,7 +31,7 @@ func (p *Payload) Encode() ([]byte, error) {
func (p *Payload) Handle(ctx protocol.Context) protocol.Payload { func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
if ctx.IsProtocolStart(TypeIdentity) { if ctx.IsProtocolStart(TypeIdentity) {
ctx.EndInnerProtocol(protocol.StatusNextProtocol, nil) ctx.EndInnerProtocol(protocol.StatusNextProtocol)
} }
return nil return nil
} }

View File

@ -31,7 +31,7 @@ func (p *Payload) Encode() ([]byte, error) {
func (p *Payload) Handle(ctx protocol.Context) protocol.Payload { func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
if ctx.IsProtocolStart(TypeLegacyNAK) { if ctx.IsProtocolStart(TypeLegacyNAK) {
ctx.EndInnerProtocol(protocol.StatusError, nil) ctx.EndInnerProtocol(protocol.StatusError)
} }
return nil return nil
} }

View File

@ -160,20 +160,29 @@ func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
p.st.IsProtocolEnded = true p.st.IsProtocolEnded = true
return ep return ep
} else if p.st.IsProtocolEnded { } else if p.st.IsProtocolEnded {
ctx.EndInnerProtocol(protocol.StatusSuccess, func(r *radius.Packet) *radius.Packet { ctx.EndInnerProtocol(protocol.StatusSuccess)
if len(microsoft.MSMPPERecvKey_Get(r, ctx.Packet().Packet)) < 1 {
microsoft.MSMPPERecvKey_Set(r, p.st.AuthResponse.RecvKey)
}
if len(microsoft.MSMPPESendKey_Get(r, ctx.Packet().Packet)) < 1 {
microsoft.MSMPPESendKey_Set(r, p.st.AuthResponse.SendKey)
}
return r
})
return &Payload{} return &Payload{}
} }
return response return response
} }
func (p *Payload) ModifyRADIUSResponse(r *radius.Packet, q *radius.Packet) error {
if p.st == nil || p.st.AuthResponse == nil {
return nil
}
if r.Code != radius.CodeAccessAccept {
return nil
}
log.Debug("MSCHAPv2: Radius modifier")
if len(microsoft.MSMPPERecvKey_Get(r, q)) < 1 {
microsoft.MSMPPERecvKey_Set(r, p.st.AuthResponse.RecvKey)
}
if len(microsoft.MSMPPESendKey_Get(r, q)) < 1 {
microsoft.MSMPPESendKey_Set(r, p.st.AuthResponse.SendKey)
}
return nil
}
func (p *Payload) Offerable() bool { func (p *Payload) Offerable() bool {
return true return true
} }

View File

@ -1,5 +1,18 @@
package protocol package protocol
import "layeh.com/radius"
type Type uint8
type Code uint8
const (
CodeRequest Code = 1
CodeResponse Code = 2
CodeSuccess Code = 3
CodeFailure Code = 4
)
type Payload interface { type Payload interface {
Decode(raw []byte) error Decode(raw []byte) error
Encode() ([]byte, error) Encode() ([]byte, error)
@ -13,13 +26,6 @@ type Inner interface {
HasInner() Payload HasInner() Payload
} }
type Type uint8 type ResponseModifier interface {
ModifyRADIUSResponse(r *radius.Packet, q *radius.Packet) error
type Code uint8 }
const (
CodeRequest Code = 1
CodeResponse Code = 2
CodeSuccess Code = 3
CodeFailure Code = 4
)

View File

@ -2,8 +2,6 @@ package tls
import ( import (
"goauthentik.io/internal/outpost/radius/eap/protocol" "goauthentik.io/internal/outpost/radius/eap/protocol"
"layeh.com/radius"
"layeh.com/radius/vendors/microsoft"
) )
func (p *Payload) innerHandler(ctx protocol.Context) { func (p *Payload) innerHandler(ctx protocol.Context) {
@ -13,7 +11,7 @@ func (p *Payload) innerHandler(ctx protocol.Context) {
n, err := p.st.TLS.Read(d) n, err := p.st.TLS.Read(d)
if err != nil { if err != nil {
ctx.Log().WithError(err).Warning("TLS: Failed to read from TLS connection") ctx.Log().WithError(err).Warning("TLS: Failed to read from TLS connection")
ctx.EndInnerProtocol(protocol.StatusError, nil) ctx.EndInnerProtocol(protocol.StatusError)
return return
} }
// Truncate data to the size we read // Truncate data to the size we read
@ -22,25 +20,20 @@ func (p *Payload) innerHandler(ctx protocol.Context) {
err := p.Inner.Decode(d) err := p.Inner.Decode(d)
if err != nil { if err != nil {
ctx.Log().WithError(err).Warning("TLS: failed to decode inner protocol") ctx.Log().WithError(err).Warning("TLS: failed to decode inner protocol")
ctx.EndInnerProtocol(protocol.StatusError, nil) ctx.EndInnerProtocol(protocol.StatusError)
return return
} }
pl := p.Inner.Handle(ctx.Inner(p.Inner, p.Inner.Type(), func(r *radius.Packet) *radius.Packet { pl := p.Inner.Handle(ctx.Inner(p.Inner, p.Inner.Type()))
ctx.Log().Debug("TLS: Adding MPPE Keys")
microsoft.MSMPPERecvKey_Set(r, p.st.MPPEKey[:32])
microsoft.MSMPPESendKey_Set(r, p.st.MPPEKey[64:64+32])
return r
}))
enc, err := pl.Encode() enc, err := pl.Encode()
if err != nil { if err != nil {
ctx.Log().WithError(err).Warning("TLS: failed to encode inner protocol") ctx.Log().WithError(err).Warning("TLS: failed to encode inner protocol")
ctx.EndInnerProtocol(protocol.StatusError, nil) ctx.EndInnerProtocol(protocol.StatusError)
return return
} }
_, err = p.st.TLS.Write(enc) _, err = p.st.TLS.Write(enc)
if err != nil { if err != nil {
ctx.Log().WithError(err).Warning("TLS: failed to write to TLS") ctx.Log().WithError(err).Warning("TLS: failed to write to TLS")
ctx.EndInnerProtocol(protocol.StatusError, nil) ctx.EndInnerProtocol(protocol.StatusError)
return return
} }
} }

View File

@ -144,17 +144,32 @@ func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
retry.MaxDelay(100*time.Millisecond), retry.MaxDelay(100*time.Millisecond),
retry.Attempts(0), retry.Attempts(0),
) )
ctx.EndInnerProtocol(pst, func(r *radius.Packet) *radius.Packet { ctx.EndInnerProtocol(pst)
ctx.Log().Debug("TLS: Adding MPPE Keys")
microsoft.MSMPPERecvKey_Set(r, p.st.MPPEKey[:32])
microsoft.MSMPPESendKey_Set(r, p.st.MPPEKey[64:64+32])
return r
})
return nil return nil
} }
return p.startChunkedTransfer(p.st.Conn.OutboundData()) return p.startChunkedTransfer(p.st.Conn.OutboundData())
} }
func (p *Payload) ModifyRADIUSResponse(r *radius.Packet, q *radius.Packet) error {
if r.Code != radius.CodeAccessAccept {
return nil
}
if p.st == nil || !p.st.HandshakeDone {
return nil
}
log.Debug("TLS: Adding MPPE Keys")
// TLS overrides other protocols' MPPE keys
if len(microsoft.MSMPPERecvKey_Get(r, q)) > 0 {
microsoft.MSMPPERecvKey_Del(r)
}
if len(microsoft.MSMPPESendKey_Get(r, q)) > 0 {
microsoft.MSMPPESendKey_Del(r)
}
microsoft.MSMPPERecvKey_Set(r, p.st.MPPEKey[:32])
microsoft.MSMPPESendKey_Set(r, p.st.MPPEKey[64:64+32])
return nil
}
func (p *Payload) tlsInit(ctx protocol.Context) { func (p *Payload) tlsInit(ctx protocol.Context) {
ctx.Log().Debug("TLS: no TLS connection in state yet, starting connection") ctx.Log().Debug("TLS: no TLS connection in state yet, starting connection")
p.st.Context, p.st.ContextCancel = context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second) p.st.Context, p.st.ContextCancel = context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second)
@ -181,9 +196,7 @@ func (p *Payload) tlsInit(ctx protocol.Context) {
if err != nil { if err != nil {
ctx.Log().WithError(err).Debug("TLS: Handshake error") ctx.Log().WithError(err).Debug("TLS: Handshake error")
p.st.FinalStatus = protocol.StatusError p.st.FinalStatus = protocol.StatusError
ctx.EndInnerProtocol(protocol.StatusError, func(p *radius.Packet) *radius.Packet { ctx.EndInnerProtocol(protocol.StatusError)
return p
})
return return
} }
ctx.Log().Debug("TLS: handshake done") ctx.Log().Debug("TLS: handshake done")

View File

@ -153,10 +153,7 @@ func (pi *ProviderInstance) GetEAPSettings() protocol.Settings {
return protocol.Settings{ return protocol.Settings{
Protocols: append(protocols, tls.Protocol, peap.Protocol), Protocols: append(protocols, tls.Protocol, peap.Protocol),
ProtocolPriority: []protocol.Type{ ProtocolPriority: []protocol.Type{tls.TypeTLS, peap.TypePEAP},
tls.TypeTLS,
peap.TypePEAP,
},
ProtocolSettings: map[protocol.Type]interface{}{ ProtocolSettings: map[protocol.Type]interface{}{
tls.TypeTLS: tls.Settings{ tls.TypeTLS: tls.Settings{
Config: &ttls.Config{ Config: &ttls.Config{
@ -197,6 +194,13 @@ func (pi *ProviderInstance) GetEAPSettings() protocol.Settings {
InnerProtocols: protocol.Settings{ InnerProtocols: protocol.Settings{
Protocols: append(protocols, mschapv2.Protocol), Protocols: append(protocols, mschapv2.Protocol),
ProtocolPriority: []protocol.Type{mschapv2.TypeMSCHAPv2}, ProtocolPriority: []protocol.Type{mschapv2.TypeMSCHAPv2},
ProtocolSettings: map[protocol.Type]interface{}{
mschapv2.TypeMSCHAPv2: mschapv2.Settings{
AuthenticateRequest: mschapv2.DebugStaticCredentials(
[]byte("foo"), []byte("bar"),
),
},
},
}, },
}, },
}, },