start reworking response modification

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer
2025-05-24 19:49:41 +02:00
parent 9045f5ba73
commit f5eb827d14
11 changed files with 103 additions and 87 deletions

View File

@ -16,8 +16,7 @@ type context struct {
settings interface{}
parent *context
endStatus protocol.Status
endModifier func(p *radius.Packet) *radius.Packet
handleInner func(protocol.Payload, protocol.StateManager) (protocol.Payload, error)
handleInner func(protocol.Payload, protocol.StateManager, protocol.Context) (protocol.Payload, error)
}
func (ctx *context) RootPayload() protocol.Payload { return ctx.rootPayload }
@ -28,13 +27,10 @@ func (ctx *context) SetProtocolState(p protocol.Type, st any) { ctx.typeState[p]
func (ctx *context) IsProtocolStart(p protocol.Type) bool { return ctx.typeState[p] == nil }
func (ctx *context) Log() *log.Entry { return ctx.log }
func (ctx *context) HandleInnerEAP(p protocol.Payload, st protocol.StateManager) (protocol.Payload, error) {
return ctx.handleInner(p, st)
return ctx.handleInner(p, st, ctx)
}
func (ctx *context) Inner(p protocol.Payload, t protocol.Type, pmf func(p *radius.Packet) *radius.Packet) protocol.Context {
if ctx.endModifier == nil {
ctx.endModifier = pmf
}
return &context{
func (ctx *context) Inner(p protocol.Payload, t protocol.Type) protocol.Context {
nctx := &context{
req: ctx.req,
rootPayload: ctx.rootPayload,
typeState: ctx.typeState,
@ -43,29 +39,17 @@ func (ctx *context) Inner(p protocol.Payload, t protocol.Type, pmf func(p *radiu
parent: ctx,
handleInner: ctx.handleInner,
}
nctx.log.Debug("Creating inner context")
return nctx
}
func (ctx *context) EndInnerProtocol(st protocol.Status, mf func(p *radius.Packet) *radius.Packet) {
func (ctx *context) EndInnerProtocol(st protocol.Status) {
ctx.log.Info("Ending protocol")
if ctx.parent != nil {
ctx.parent.EndInnerProtocol(st, mf)
ctx.parent.EndInnerProtocol(st)
return
}
if ctx.endStatus != protocol.StatusUnknown {
return
}
ctx.endStatus = st
if mf != nil {
ctx.endModifier = mf
}
}
func (ctx *context) callEndModifier(p *radius.Packet) *radius.Packet {
if ctx.parent != nil {
p = ctx.parent.callEndModifier(p)
}
if ctx.endModifier != nil {
ctx.log.Debug("Running end modifier")
p = ctx.endModifier(p)
}
return p
}

View File

@ -39,7 +39,6 @@ func (p *Packet) HandleRadiusPacket(w radius.ResponseWriter, r *radius.Request)
rres := r.Response(radius.CodeAccessReject)
if err == nil {
rres = p.endModifier(rres)
switch rp.eap.Code {
case protocol.CodeRequest:
rres.Code = radius.CodeAccessChallenge
@ -52,6 +51,13 @@ func (p *Packet) HandleRadiusPacket(w radius.ResponseWriter, r *radius.Request)
rres.Code = radius.CodeAccessReject
log.WithError(err).Debug("Rejecting request")
}
for _, mod := range p.responseModifiers {
err := mod.ModifyRADIUSResponse(rres, r.Packet)
if err != nil {
log.WithError(err).Warning("Root-EAP: failed to modify response packet")
break
}
}
rfc2865.State_SetString(rres, p.state)
eapEncoded, err := rp.Encode()
@ -106,7 +112,8 @@ func (p *Packet) handleEAP(pp protocol.Payload, stm protocol.StateManager, paren
var ctx *context
if parentContext != nil {
ctx = parentContext.Inner(np, t, nil).(*context)
ctx = parentContext.Inner(np, t).(*context)
ctx.settings = stm.GetEAPSettings().ProtocolSettings[np.Type()]
} else {
ctx = &context{
req: p.r,
@ -115,8 +122,9 @@ func (p *Packet) handleEAP(pp protocol.Payload, stm protocol.StateManager, paren
log: log.WithField("type", fmt.Sprintf("%T", np)).WithField("code", t),
settings: stm.GetEAPSettings().ProtocolSettings[t],
}
ctx.handleInner = func(pp protocol.Payload, sm protocol.StateManager) (protocol.Payload, error) {
return p.handleEAP(pp, sm, ctx.Inner(pp, pp.Type(), nil).(*context))
ctx.handleInner = func(pp protocol.Payload, sm protocol.StateManager, ctx protocol.Context) (protocol.Payload, error) {
// cctx := ctx.Inner(np, np.Type(), nil).(*context)
return p.handleEAP(pp, sm, ctx.(*context))
}
}
if !np.Offerable() {
@ -141,8 +149,9 @@ func (p *Packet) handleEAP(pp protocol.Payload, stm protocol.StateManager, paren
stm.SetEAPState(p.state, st)
if ctx.endModifier != nil {
p.endModifier = ctx.callEndModifier
if rm, ok := np.(protocol.ResponseModifier); ok {
ctx.log.Debug("Root-EAP: Registered response modifier")
p.responseModifiers = append(p.responseModifiers, rm)
}
switch ctx.endStatus {

View File

@ -11,7 +11,7 @@ type Packet struct {
eap *eap.Payload
stm protocol.StateManager
state string
endModifier func(p *radius.Packet) *radius.Packet
responseModifiers []protocol.ResponseModifier
}
func Decode(stm protocol.StateManager, raw []byte) (*Packet, error) {
@ -20,9 +20,7 @@ func Decode(stm protocol.StateManager, raw []byte) (*Packet, error) {
Settings: stm.GetEAPSettings(),
},
stm: stm,
endModifier: func(p *radius.Packet) *radius.Packet {
return p
},
responseModifiers: []protocol.ResponseModifier{},
}
err := packet.eap.Decode(raw)
if err != nil {

View File

@ -25,8 +25,8 @@ type Context interface {
IsProtocolStart(p Type) bool
HandleInnerEAP(Payload, StateManager) (Payload, error)
Inner(Payload, Type, func(p *radius.Packet) *radius.Packet) Context
EndInnerProtocol(Status, func(p *radius.Packet) *radius.Packet)
Inner(Payload, Type) Context
EndInnerProtocol(Status)
Log() *log.Entry
}

View File

@ -31,7 +31,7 @@ func (p *Payload) Encode() ([]byte, error) {
func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
if ctx.IsProtocolStart(TypeIdentity) {
ctx.EndInnerProtocol(protocol.StatusNextProtocol, nil)
ctx.EndInnerProtocol(protocol.StatusNextProtocol)
}
return nil
}

View File

@ -31,7 +31,7 @@ func (p *Payload) Encode() ([]byte, error) {
func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
if ctx.IsProtocolStart(TypeLegacyNAK) {
ctx.EndInnerProtocol(protocol.StatusError, nil)
ctx.EndInnerProtocol(protocol.StatusError)
}
return nil
}

View File

@ -160,20 +160,29 @@ func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
p.st.IsProtocolEnded = true
return ep
} else if p.st.IsProtocolEnded {
ctx.EndInnerProtocol(protocol.StatusSuccess, func(r *radius.Packet) *radius.Packet {
if len(microsoft.MSMPPERecvKey_Get(r, ctx.Packet().Packet)) < 1 {
microsoft.MSMPPERecvKey_Set(r, p.st.AuthResponse.RecvKey)
}
if len(microsoft.MSMPPESendKey_Get(r, ctx.Packet().Packet)) < 1 {
microsoft.MSMPPESendKey_Set(r, p.st.AuthResponse.SendKey)
}
return r
})
ctx.EndInnerProtocol(protocol.StatusSuccess)
return &Payload{}
}
return response
}
func (p *Payload) ModifyRADIUSResponse(r *radius.Packet, q *radius.Packet) error {
if p.st == nil || p.st.AuthResponse == nil {
return nil
}
if r.Code != radius.CodeAccessAccept {
return nil
}
log.Debug("MSCHAPv2: Radius modifier")
if len(microsoft.MSMPPERecvKey_Get(r, q)) < 1 {
microsoft.MSMPPERecvKey_Set(r, p.st.AuthResponse.RecvKey)
}
if len(microsoft.MSMPPESendKey_Get(r, q)) < 1 {
microsoft.MSMPPESendKey_Set(r, p.st.AuthResponse.SendKey)
}
return nil
}
func (p *Payload) Offerable() bool {
return true
}

View File

@ -1,5 +1,18 @@
package protocol
import "layeh.com/radius"
type Type uint8
type Code uint8
const (
CodeRequest Code = 1
CodeResponse Code = 2
CodeSuccess Code = 3
CodeFailure Code = 4
)
type Payload interface {
Decode(raw []byte) error
Encode() ([]byte, error)
@ -13,13 +26,6 @@ type Inner interface {
HasInner() Payload
}
type Type uint8
type Code uint8
const (
CodeRequest Code = 1
CodeResponse Code = 2
CodeSuccess Code = 3
CodeFailure Code = 4
)
type ResponseModifier interface {
ModifyRADIUSResponse(r *radius.Packet, q *radius.Packet) error
}

View File

@ -2,8 +2,6 @@ package tls
import (
"goauthentik.io/internal/outpost/radius/eap/protocol"
"layeh.com/radius"
"layeh.com/radius/vendors/microsoft"
)
func (p *Payload) innerHandler(ctx protocol.Context) {
@ -13,7 +11,7 @@ func (p *Payload) innerHandler(ctx protocol.Context) {
n, err := p.st.TLS.Read(d)
if err != nil {
ctx.Log().WithError(err).Warning("TLS: Failed to read from TLS connection")
ctx.EndInnerProtocol(protocol.StatusError, nil)
ctx.EndInnerProtocol(protocol.StatusError)
return
}
// Truncate data to the size we read
@ -22,25 +20,20 @@ func (p *Payload) innerHandler(ctx protocol.Context) {
err := p.Inner.Decode(d)
if err != nil {
ctx.Log().WithError(err).Warning("TLS: failed to decode inner protocol")
ctx.EndInnerProtocol(protocol.StatusError, nil)
ctx.EndInnerProtocol(protocol.StatusError)
return
}
pl := p.Inner.Handle(ctx.Inner(p.Inner, p.Inner.Type(), func(r *radius.Packet) *radius.Packet {
ctx.Log().Debug("TLS: Adding MPPE Keys")
microsoft.MSMPPERecvKey_Set(r, p.st.MPPEKey[:32])
microsoft.MSMPPESendKey_Set(r, p.st.MPPEKey[64:64+32])
return r
}))
pl := p.Inner.Handle(ctx.Inner(p.Inner, p.Inner.Type()))
enc, err := pl.Encode()
if err != nil {
ctx.Log().WithError(err).Warning("TLS: failed to encode inner protocol")
ctx.EndInnerProtocol(protocol.StatusError, nil)
ctx.EndInnerProtocol(protocol.StatusError)
return
}
_, err = p.st.TLS.Write(enc)
if err != nil {
ctx.Log().WithError(err).Warning("TLS: failed to write to TLS")
ctx.EndInnerProtocol(protocol.StatusError, nil)
ctx.EndInnerProtocol(protocol.StatusError)
return
}
}

View File

@ -144,17 +144,32 @@ func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
retry.MaxDelay(100*time.Millisecond),
retry.Attempts(0),
)
ctx.EndInnerProtocol(pst, func(r *radius.Packet) *radius.Packet {
ctx.Log().Debug("TLS: Adding MPPE Keys")
microsoft.MSMPPERecvKey_Set(r, p.st.MPPEKey[:32])
microsoft.MSMPPESendKey_Set(r, p.st.MPPEKey[64:64+32])
return r
})
ctx.EndInnerProtocol(pst)
return nil
}
return p.startChunkedTransfer(p.st.Conn.OutboundData())
}
func (p *Payload) ModifyRADIUSResponse(r *radius.Packet, q *radius.Packet) error {
if r.Code != radius.CodeAccessAccept {
return nil
}
if p.st == nil || !p.st.HandshakeDone {
return nil
}
log.Debug("TLS: Adding MPPE Keys")
// TLS overrides other protocols' MPPE keys
if len(microsoft.MSMPPERecvKey_Get(r, q)) > 0 {
microsoft.MSMPPERecvKey_Del(r)
}
if len(microsoft.MSMPPESendKey_Get(r, q)) > 0 {
microsoft.MSMPPESendKey_Del(r)
}
microsoft.MSMPPERecvKey_Set(r, p.st.MPPEKey[:32])
microsoft.MSMPPESendKey_Set(r, p.st.MPPEKey[64:64+32])
return nil
}
func (p *Payload) tlsInit(ctx protocol.Context) {
ctx.Log().Debug("TLS: no TLS connection in state yet, starting connection")
p.st.Context, p.st.ContextCancel = context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second)
@ -181,9 +196,7 @@ func (p *Payload) tlsInit(ctx protocol.Context) {
if err != nil {
ctx.Log().WithError(err).Debug("TLS: Handshake error")
p.st.FinalStatus = protocol.StatusError
ctx.EndInnerProtocol(protocol.StatusError, func(p *radius.Packet) *radius.Packet {
return p
})
ctx.EndInnerProtocol(protocol.StatusError)
return
}
ctx.Log().Debug("TLS: handshake done")

View File

@ -153,10 +153,7 @@ func (pi *ProviderInstance) GetEAPSettings() protocol.Settings {
return protocol.Settings{
Protocols: append(protocols, tls.Protocol, peap.Protocol),
ProtocolPriority: []protocol.Type{
tls.TypeTLS,
peap.TypePEAP,
},
ProtocolPriority: []protocol.Type{tls.TypeTLS, peap.TypePEAP},
ProtocolSettings: map[protocol.Type]interface{}{
tls.TypeTLS: tls.Settings{
Config: &ttls.Config{
@ -197,6 +194,13 @@ func (pi *ProviderInstance) GetEAPSettings() protocol.Settings {
InnerProtocols: protocol.Settings{
Protocols: append(protocols, mschapv2.Protocol),
ProtocolPriority: []protocol.Type{mschapv2.TypeMSCHAPv2},
ProtocolSettings: map[protocol.Type]interface{}{
mschapv2.TypeMSCHAPv2: mschapv2.Settings{
AuthenticateRequest: mschapv2.DebugStaticCredentials(
[]byte("foo"), []byte("bar"),
),
},
},
},
},
},