start reworking response modification
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
@ -16,8 +16,7 @@ type context struct {
|
||||
settings interface{}
|
||||
parent *context
|
||||
endStatus protocol.Status
|
||||
endModifier func(p *radius.Packet) *radius.Packet
|
||||
handleInner func(protocol.Payload, protocol.StateManager) (protocol.Payload, error)
|
||||
handleInner func(protocol.Payload, protocol.StateManager, protocol.Context) (protocol.Payload, error)
|
||||
}
|
||||
|
||||
func (ctx *context) RootPayload() protocol.Payload { return ctx.rootPayload }
|
||||
@ -28,13 +27,10 @@ func (ctx *context) SetProtocolState(p protocol.Type, st any) { ctx.typeState[p]
|
||||
func (ctx *context) IsProtocolStart(p protocol.Type) bool { return ctx.typeState[p] == nil }
|
||||
func (ctx *context) Log() *log.Entry { return ctx.log }
|
||||
func (ctx *context) HandleInnerEAP(p protocol.Payload, st protocol.StateManager) (protocol.Payload, error) {
|
||||
return ctx.handleInner(p, st)
|
||||
return ctx.handleInner(p, st, ctx)
|
||||
}
|
||||
func (ctx *context) Inner(p protocol.Payload, t protocol.Type, pmf func(p *radius.Packet) *radius.Packet) protocol.Context {
|
||||
if ctx.endModifier == nil {
|
||||
ctx.endModifier = pmf
|
||||
}
|
||||
return &context{
|
||||
func (ctx *context) Inner(p protocol.Payload, t protocol.Type) protocol.Context {
|
||||
nctx := &context{
|
||||
req: ctx.req,
|
||||
rootPayload: ctx.rootPayload,
|
||||
typeState: ctx.typeState,
|
||||
@ -43,29 +39,17 @@ func (ctx *context) Inner(p protocol.Payload, t protocol.Type, pmf func(p *radiu
|
||||
parent: ctx,
|
||||
handleInner: ctx.handleInner,
|
||||
}
|
||||
nctx.log.Debug("Creating inner context")
|
||||
return nctx
|
||||
}
|
||||
func (ctx *context) EndInnerProtocol(st protocol.Status, mf func(p *radius.Packet) *radius.Packet) {
|
||||
func (ctx *context) EndInnerProtocol(st protocol.Status) {
|
||||
ctx.log.Info("Ending protocol")
|
||||
if ctx.parent != nil {
|
||||
ctx.parent.EndInnerProtocol(st, mf)
|
||||
ctx.parent.EndInnerProtocol(st)
|
||||
return
|
||||
}
|
||||
if ctx.endStatus != protocol.StatusUnknown {
|
||||
return
|
||||
}
|
||||
ctx.endStatus = st
|
||||
if mf != nil {
|
||||
ctx.endModifier = mf
|
||||
}
|
||||
}
|
||||
|
||||
func (ctx *context) callEndModifier(p *radius.Packet) *radius.Packet {
|
||||
if ctx.parent != nil {
|
||||
p = ctx.parent.callEndModifier(p)
|
||||
}
|
||||
if ctx.endModifier != nil {
|
||||
ctx.log.Debug("Running end modifier")
|
||||
p = ctx.endModifier(p)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
@ -39,7 +39,6 @@ func (p *Packet) HandleRadiusPacket(w radius.ResponseWriter, r *radius.Request)
|
||||
|
||||
rres := r.Response(radius.CodeAccessReject)
|
||||
if err == nil {
|
||||
rres = p.endModifier(rres)
|
||||
switch rp.eap.Code {
|
||||
case protocol.CodeRequest:
|
||||
rres.Code = radius.CodeAccessChallenge
|
||||
@ -52,6 +51,13 @@ func (p *Packet) HandleRadiusPacket(w radius.ResponseWriter, r *radius.Request)
|
||||
rres.Code = radius.CodeAccessReject
|
||||
log.WithError(err).Debug("Rejecting request")
|
||||
}
|
||||
for _, mod := range p.responseModifiers {
|
||||
err := mod.ModifyRADIUSResponse(rres, r.Packet)
|
||||
if err != nil {
|
||||
log.WithError(err).Warning("Root-EAP: failed to modify response packet")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
rfc2865.State_SetString(rres, p.state)
|
||||
eapEncoded, err := rp.Encode()
|
||||
@ -106,7 +112,8 @@ func (p *Packet) handleEAP(pp protocol.Payload, stm protocol.StateManager, paren
|
||||
|
||||
var ctx *context
|
||||
if parentContext != nil {
|
||||
ctx = parentContext.Inner(np, t, nil).(*context)
|
||||
ctx = parentContext.Inner(np, t).(*context)
|
||||
ctx.settings = stm.GetEAPSettings().ProtocolSettings[np.Type()]
|
||||
} else {
|
||||
ctx = &context{
|
||||
req: p.r,
|
||||
@ -115,8 +122,9 @@ func (p *Packet) handleEAP(pp protocol.Payload, stm protocol.StateManager, paren
|
||||
log: log.WithField("type", fmt.Sprintf("%T", np)).WithField("code", t),
|
||||
settings: stm.GetEAPSettings().ProtocolSettings[t],
|
||||
}
|
||||
ctx.handleInner = func(pp protocol.Payload, sm protocol.StateManager) (protocol.Payload, error) {
|
||||
return p.handleEAP(pp, sm, ctx.Inner(pp, pp.Type(), nil).(*context))
|
||||
ctx.handleInner = func(pp protocol.Payload, sm protocol.StateManager, ctx protocol.Context) (protocol.Payload, error) {
|
||||
// cctx := ctx.Inner(np, np.Type(), nil).(*context)
|
||||
return p.handleEAP(pp, sm, ctx.(*context))
|
||||
}
|
||||
}
|
||||
if !np.Offerable() {
|
||||
@ -141,8 +149,9 @@ func (p *Packet) handleEAP(pp protocol.Payload, stm protocol.StateManager, paren
|
||||
|
||||
stm.SetEAPState(p.state, st)
|
||||
|
||||
if ctx.endModifier != nil {
|
||||
p.endModifier = ctx.callEndModifier
|
||||
if rm, ok := np.(protocol.ResponseModifier); ok {
|
||||
ctx.log.Debug("Root-EAP: Registered response modifier")
|
||||
p.responseModifiers = append(p.responseModifiers, rm)
|
||||
}
|
||||
|
||||
switch ctx.endStatus {
|
||||
|
@ -11,7 +11,7 @@ type Packet struct {
|
||||
eap *eap.Payload
|
||||
stm protocol.StateManager
|
||||
state string
|
||||
endModifier func(p *radius.Packet) *radius.Packet
|
||||
responseModifiers []protocol.ResponseModifier
|
||||
}
|
||||
|
||||
func Decode(stm protocol.StateManager, raw []byte) (*Packet, error) {
|
||||
@ -20,9 +20,7 @@ func Decode(stm protocol.StateManager, raw []byte) (*Packet, error) {
|
||||
Settings: stm.GetEAPSettings(),
|
||||
},
|
||||
stm: stm,
|
||||
endModifier: func(p *radius.Packet) *radius.Packet {
|
||||
return p
|
||||
},
|
||||
responseModifiers: []protocol.ResponseModifier{},
|
||||
}
|
||||
err := packet.eap.Decode(raw)
|
||||
if err != nil {
|
||||
|
@ -25,8 +25,8 @@ type Context interface {
|
||||
IsProtocolStart(p Type) bool
|
||||
|
||||
HandleInnerEAP(Payload, StateManager) (Payload, error)
|
||||
Inner(Payload, Type, func(p *radius.Packet) *radius.Packet) Context
|
||||
EndInnerProtocol(Status, func(p *radius.Packet) *radius.Packet)
|
||||
Inner(Payload, Type) Context
|
||||
EndInnerProtocol(Status)
|
||||
|
||||
Log() *log.Entry
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ func (p *Payload) Encode() ([]byte, error) {
|
||||
|
||||
func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
|
||||
if ctx.IsProtocolStart(TypeIdentity) {
|
||||
ctx.EndInnerProtocol(protocol.StatusNextProtocol, nil)
|
||||
ctx.EndInnerProtocol(protocol.StatusNextProtocol)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ func (p *Payload) Encode() ([]byte, error) {
|
||||
|
||||
func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
|
||||
if ctx.IsProtocolStart(TypeLegacyNAK) {
|
||||
ctx.EndInnerProtocol(protocol.StatusError, nil)
|
||||
ctx.EndInnerProtocol(protocol.StatusError)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -160,20 +160,29 @@ func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
|
||||
p.st.IsProtocolEnded = true
|
||||
return ep
|
||||
} else if p.st.IsProtocolEnded {
|
||||
ctx.EndInnerProtocol(protocol.StatusSuccess, func(r *radius.Packet) *radius.Packet {
|
||||
if len(microsoft.MSMPPERecvKey_Get(r, ctx.Packet().Packet)) < 1 {
|
||||
microsoft.MSMPPERecvKey_Set(r, p.st.AuthResponse.RecvKey)
|
||||
}
|
||||
if len(microsoft.MSMPPESendKey_Get(r, ctx.Packet().Packet)) < 1 {
|
||||
microsoft.MSMPPESendKey_Set(r, p.st.AuthResponse.SendKey)
|
||||
}
|
||||
return r
|
||||
})
|
||||
ctx.EndInnerProtocol(protocol.StatusSuccess)
|
||||
return &Payload{}
|
||||
}
|
||||
return response
|
||||
}
|
||||
|
||||
func (p *Payload) ModifyRADIUSResponse(r *radius.Packet, q *radius.Packet) error {
|
||||
if p.st == nil || p.st.AuthResponse == nil {
|
||||
return nil
|
||||
}
|
||||
if r.Code != radius.CodeAccessAccept {
|
||||
return nil
|
||||
}
|
||||
log.Debug("MSCHAPv2: Radius modifier")
|
||||
if len(microsoft.MSMPPERecvKey_Get(r, q)) < 1 {
|
||||
microsoft.MSMPPERecvKey_Set(r, p.st.AuthResponse.RecvKey)
|
||||
}
|
||||
if len(microsoft.MSMPPESendKey_Get(r, q)) < 1 {
|
||||
microsoft.MSMPPESendKey_Set(r, p.st.AuthResponse.SendKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Payload) Offerable() bool {
|
||||
return true
|
||||
}
|
||||
|
@ -1,5 +1,18 @@
|
||||
package protocol
|
||||
|
||||
import "layeh.com/radius"
|
||||
|
||||
type Type uint8
|
||||
|
||||
type Code uint8
|
||||
|
||||
const (
|
||||
CodeRequest Code = 1
|
||||
CodeResponse Code = 2
|
||||
CodeSuccess Code = 3
|
||||
CodeFailure Code = 4
|
||||
)
|
||||
|
||||
type Payload interface {
|
||||
Decode(raw []byte) error
|
||||
Encode() ([]byte, error)
|
||||
@ -13,13 +26,6 @@ type Inner interface {
|
||||
HasInner() Payload
|
||||
}
|
||||
|
||||
type Type uint8
|
||||
|
||||
type Code uint8
|
||||
|
||||
const (
|
||||
CodeRequest Code = 1
|
||||
CodeResponse Code = 2
|
||||
CodeSuccess Code = 3
|
||||
CodeFailure Code = 4
|
||||
)
|
||||
type ResponseModifier interface {
|
||||
ModifyRADIUSResponse(r *radius.Packet, q *radius.Packet) error
|
||||
}
|
||||
|
@ -2,8 +2,6 @@ package tls
|
||||
|
||||
import (
|
||||
"goauthentik.io/internal/outpost/radius/eap/protocol"
|
||||
"layeh.com/radius"
|
||||
"layeh.com/radius/vendors/microsoft"
|
||||
)
|
||||
|
||||
func (p *Payload) innerHandler(ctx protocol.Context) {
|
||||
@ -13,7 +11,7 @@ func (p *Payload) innerHandler(ctx protocol.Context) {
|
||||
n, err := p.st.TLS.Read(d)
|
||||
if err != nil {
|
||||
ctx.Log().WithError(err).Warning("TLS: Failed to read from TLS connection")
|
||||
ctx.EndInnerProtocol(protocol.StatusError, nil)
|
||||
ctx.EndInnerProtocol(protocol.StatusError)
|
||||
return
|
||||
}
|
||||
// Truncate data to the size we read
|
||||
@ -22,25 +20,20 @@ func (p *Payload) innerHandler(ctx protocol.Context) {
|
||||
err := p.Inner.Decode(d)
|
||||
if err != nil {
|
||||
ctx.Log().WithError(err).Warning("TLS: failed to decode inner protocol")
|
||||
ctx.EndInnerProtocol(protocol.StatusError, nil)
|
||||
ctx.EndInnerProtocol(protocol.StatusError)
|
||||
return
|
||||
}
|
||||
pl := p.Inner.Handle(ctx.Inner(p.Inner, p.Inner.Type(), func(r *radius.Packet) *radius.Packet {
|
||||
ctx.Log().Debug("TLS: Adding MPPE Keys")
|
||||
microsoft.MSMPPERecvKey_Set(r, p.st.MPPEKey[:32])
|
||||
microsoft.MSMPPESendKey_Set(r, p.st.MPPEKey[64:64+32])
|
||||
return r
|
||||
}))
|
||||
pl := p.Inner.Handle(ctx.Inner(p.Inner, p.Inner.Type()))
|
||||
enc, err := pl.Encode()
|
||||
if err != nil {
|
||||
ctx.Log().WithError(err).Warning("TLS: failed to encode inner protocol")
|
||||
ctx.EndInnerProtocol(protocol.StatusError, nil)
|
||||
ctx.EndInnerProtocol(protocol.StatusError)
|
||||
return
|
||||
}
|
||||
_, err = p.st.TLS.Write(enc)
|
||||
if err != nil {
|
||||
ctx.Log().WithError(err).Warning("TLS: failed to write to TLS")
|
||||
ctx.EndInnerProtocol(protocol.StatusError, nil)
|
||||
ctx.EndInnerProtocol(protocol.StatusError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -144,17 +144,32 @@ func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
|
||||
retry.MaxDelay(100*time.Millisecond),
|
||||
retry.Attempts(0),
|
||||
)
|
||||
ctx.EndInnerProtocol(pst, func(r *radius.Packet) *radius.Packet {
|
||||
ctx.Log().Debug("TLS: Adding MPPE Keys")
|
||||
microsoft.MSMPPERecvKey_Set(r, p.st.MPPEKey[:32])
|
||||
microsoft.MSMPPESendKey_Set(r, p.st.MPPEKey[64:64+32])
|
||||
return r
|
||||
})
|
||||
ctx.EndInnerProtocol(pst)
|
||||
return nil
|
||||
}
|
||||
return p.startChunkedTransfer(p.st.Conn.OutboundData())
|
||||
}
|
||||
|
||||
func (p *Payload) ModifyRADIUSResponse(r *radius.Packet, q *radius.Packet) error {
|
||||
if r.Code != radius.CodeAccessAccept {
|
||||
return nil
|
||||
}
|
||||
if p.st == nil || !p.st.HandshakeDone {
|
||||
return nil
|
||||
}
|
||||
log.Debug("TLS: Adding MPPE Keys")
|
||||
// TLS overrides other protocols' MPPE keys
|
||||
if len(microsoft.MSMPPERecvKey_Get(r, q)) > 0 {
|
||||
microsoft.MSMPPERecvKey_Del(r)
|
||||
}
|
||||
if len(microsoft.MSMPPESendKey_Get(r, q)) > 0 {
|
||||
microsoft.MSMPPESendKey_Del(r)
|
||||
}
|
||||
microsoft.MSMPPERecvKey_Set(r, p.st.MPPEKey[:32])
|
||||
microsoft.MSMPPESendKey_Set(r, p.st.MPPEKey[64:64+32])
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Payload) tlsInit(ctx protocol.Context) {
|
||||
ctx.Log().Debug("TLS: no TLS connection in state yet, starting connection")
|
||||
p.st.Context, p.st.ContextCancel = context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second)
|
||||
@ -181,9 +196,7 @@ func (p *Payload) tlsInit(ctx protocol.Context) {
|
||||
if err != nil {
|
||||
ctx.Log().WithError(err).Debug("TLS: Handshake error")
|
||||
p.st.FinalStatus = protocol.StatusError
|
||||
ctx.EndInnerProtocol(protocol.StatusError, func(p *radius.Packet) *radius.Packet {
|
||||
return p
|
||||
})
|
||||
ctx.EndInnerProtocol(protocol.StatusError)
|
||||
return
|
||||
}
|
||||
ctx.Log().Debug("TLS: handshake done")
|
||||
|
@ -153,10 +153,7 @@ func (pi *ProviderInstance) GetEAPSettings() protocol.Settings {
|
||||
|
||||
return protocol.Settings{
|
||||
Protocols: append(protocols, tls.Protocol, peap.Protocol),
|
||||
ProtocolPriority: []protocol.Type{
|
||||
tls.TypeTLS,
|
||||
peap.TypePEAP,
|
||||
},
|
||||
ProtocolPriority: []protocol.Type{tls.TypeTLS, peap.TypePEAP},
|
||||
ProtocolSettings: map[protocol.Type]interface{}{
|
||||
tls.TypeTLS: tls.Settings{
|
||||
Config: &ttls.Config{
|
||||
@ -197,6 +194,13 @@ func (pi *ProviderInstance) GetEAPSettings() protocol.Settings {
|
||||
InnerProtocols: protocol.Settings{
|
||||
Protocols: append(protocols, mschapv2.Protocol),
|
||||
ProtocolPriority: []protocol.Type{mschapv2.TypeMSCHAPv2},
|
||||
ProtocolSettings: map[protocol.Type]interface{}{
|
||||
mschapv2.TypeMSCHAPv2: mschapv2.Settings{
|
||||
AuthenticateRequest: mschapv2.DebugStaticCredentials(
|
||||
[]byte("foo"), []byte("bar"),
|
||||
),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
Reference in New Issue
Block a user