start reworking response modification
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
		| @ -16,8 +16,7 @@ type context struct { | ||||
| 	settings    interface{} | ||||
| 	parent      *context | ||||
| 	endStatus   protocol.Status | ||||
| 	endModifier func(p *radius.Packet) *radius.Packet | ||||
| 	handleInner func(protocol.Payload, protocol.StateManager) (protocol.Payload, error) | ||||
| 	handleInner func(protocol.Payload, protocol.StateManager, protocol.Context) (protocol.Payload, error) | ||||
| } | ||||
|  | ||||
| func (ctx *context) RootPayload() protocol.Payload            { return ctx.rootPayload } | ||||
| @ -28,13 +27,10 @@ func (ctx *context) SetProtocolState(p protocol.Type, st any) { ctx.typeState[p] | ||||
| func (ctx *context) IsProtocolStart(p protocol.Type) bool     { return ctx.typeState[p] == nil } | ||||
| func (ctx *context) Log() *log.Entry                          { return ctx.log } | ||||
| func (ctx *context) HandleInnerEAP(p protocol.Payload, st protocol.StateManager) (protocol.Payload, error) { | ||||
| 	return ctx.handleInner(p, st) | ||||
| 	return ctx.handleInner(p, st, ctx) | ||||
| } | ||||
| func (ctx *context) Inner(p protocol.Payload, t protocol.Type, pmf func(p *radius.Packet) *radius.Packet) protocol.Context { | ||||
| 	if ctx.endModifier == nil { | ||||
| 		ctx.endModifier = pmf | ||||
| 	} | ||||
| 	return &context{ | ||||
| func (ctx *context) Inner(p protocol.Payload, t protocol.Type) protocol.Context { | ||||
| 	nctx := &context{ | ||||
| 		req:         ctx.req, | ||||
| 		rootPayload: ctx.rootPayload, | ||||
| 		typeState:   ctx.typeState, | ||||
| @ -43,29 +39,17 @@ func (ctx *context) Inner(p protocol.Payload, t protocol.Type, pmf func(p *radiu | ||||
| 		parent:      ctx, | ||||
| 		handleInner: ctx.handleInner, | ||||
| 	} | ||||
| 	nctx.log.Debug("Creating inner context") | ||||
| 	return nctx | ||||
| } | ||||
| func (ctx *context) EndInnerProtocol(st protocol.Status, mf func(p *radius.Packet) *radius.Packet) { | ||||
| func (ctx *context) EndInnerProtocol(st protocol.Status) { | ||||
| 	ctx.log.Info("Ending protocol") | ||||
| 	if ctx.parent != nil { | ||||
| 		ctx.parent.EndInnerProtocol(st, mf) | ||||
| 		ctx.parent.EndInnerProtocol(st) | ||||
| 		return | ||||
| 	} | ||||
| 	if ctx.endStatus != protocol.StatusUnknown { | ||||
| 		return | ||||
| 	} | ||||
| 	ctx.endStatus = st | ||||
| 	if mf != nil { | ||||
| 		ctx.endModifier = mf | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (ctx *context) callEndModifier(p *radius.Packet) *radius.Packet { | ||||
| 	if ctx.parent != nil { | ||||
| 		p = ctx.parent.callEndModifier(p) | ||||
| 	} | ||||
| 	if ctx.endModifier != nil { | ||||
| 		ctx.log.Debug("Running end modifier") | ||||
| 		p = ctx.endModifier(p) | ||||
| 	} | ||||
| 	return p | ||||
| } | ||||
|  | ||||
| @ -39,7 +39,6 @@ func (p *Packet) HandleRadiusPacket(w radius.ResponseWriter, r *radius.Request) | ||||
|  | ||||
| 	rres := r.Response(radius.CodeAccessReject) | ||||
| 	if err == nil { | ||||
| 		rres = p.endModifier(rres) | ||||
| 		switch rp.eap.Code { | ||||
| 		case protocol.CodeRequest: | ||||
| 			rres.Code = radius.CodeAccessChallenge | ||||
| @ -52,6 +51,13 @@ func (p *Packet) HandleRadiusPacket(w radius.ResponseWriter, r *radius.Request) | ||||
| 		rres.Code = radius.CodeAccessReject | ||||
| 		log.WithError(err).Debug("Rejecting request") | ||||
| 	} | ||||
| 	for _, mod := range p.responseModifiers { | ||||
| 		err := mod.ModifyRADIUSResponse(rres, r.Packet) | ||||
| 		if err != nil { | ||||
| 			log.WithError(err).Warning("Root-EAP: failed to modify response packet") | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	rfc2865.State_SetString(rres, p.state) | ||||
| 	eapEncoded, err := rp.Encode() | ||||
| @ -106,7 +112,8 @@ func (p *Packet) handleEAP(pp protocol.Payload, stm protocol.StateManager, paren | ||||
|  | ||||
| 	var ctx *context | ||||
| 	if parentContext != nil { | ||||
| 		ctx = parentContext.Inner(np, t, nil).(*context) | ||||
| 		ctx = parentContext.Inner(np, t).(*context) | ||||
| 		ctx.settings = stm.GetEAPSettings().ProtocolSettings[np.Type()] | ||||
| 	} else { | ||||
| 		ctx = &context{ | ||||
| 			req:         p.r, | ||||
| @ -115,8 +122,9 @@ func (p *Packet) handleEAP(pp protocol.Payload, stm protocol.StateManager, paren | ||||
| 			log:         log.WithField("type", fmt.Sprintf("%T", np)).WithField("code", t), | ||||
| 			settings:    stm.GetEAPSettings().ProtocolSettings[t], | ||||
| 		} | ||||
| 		ctx.handleInner = func(pp protocol.Payload, sm protocol.StateManager) (protocol.Payload, error) { | ||||
| 			return p.handleEAP(pp, sm, ctx.Inner(pp, pp.Type(), nil).(*context)) | ||||
| 		ctx.handleInner = func(pp protocol.Payload, sm protocol.StateManager, ctx protocol.Context) (protocol.Payload, error) { | ||||
| 			// cctx := ctx.Inner(np, np.Type(), nil).(*context) | ||||
| 			return p.handleEAP(pp, sm, ctx.(*context)) | ||||
| 		} | ||||
| 	} | ||||
| 	if !np.Offerable() { | ||||
| @ -141,8 +149,9 @@ func (p *Packet) handleEAP(pp protocol.Payload, stm protocol.StateManager, paren | ||||
|  | ||||
| 	stm.SetEAPState(p.state, st) | ||||
|  | ||||
| 	if ctx.endModifier != nil { | ||||
| 		p.endModifier = ctx.callEndModifier | ||||
| 	if rm, ok := np.(protocol.ResponseModifier); ok { | ||||
| 		ctx.log.Debug("Root-EAP: Registered response modifier") | ||||
| 		p.responseModifiers = append(p.responseModifiers, rm) | ||||
| 	} | ||||
|  | ||||
| 	switch ctx.endStatus { | ||||
|  | ||||
| @ -7,11 +7,11 @@ import ( | ||||
| ) | ||||
|  | ||||
| type Packet struct { | ||||
| 	r           *radius.Request | ||||
| 	eap         *eap.Payload | ||||
| 	stm         protocol.StateManager | ||||
| 	state       string | ||||
| 	endModifier func(p *radius.Packet) *radius.Packet | ||||
| 	r                 *radius.Request | ||||
| 	eap               *eap.Payload | ||||
| 	stm               protocol.StateManager | ||||
| 	state             string | ||||
| 	responseModifiers []protocol.ResponseModifier | ||||
| } | ||||
|  | ||||
| func Decode(stm protocol.StateManager, raw []byte) (*Packet, error) { | ||||
| @ -19,10 +19,8 @@ func Decode(stm protocol.StateManager, raw []byte) (*Packet, error) { | ||||
| 		eap: &eap.Payload{ | ||||
| 			Settings: stm.GetEAPSettings(), | ||||
| 		}, | ||||
| 		stm: stm, | ||||
| 		endModifier: func(p *radius.Packet) *radius.Packet { | ||||
| 			return p | ||||
| 		}, | ||||
| 		stm:               stm, | ||||
| 		responseModifiers: []protocol.ResponseModifier{}, | ||||
| 	} | ||||
| 	err := packet.eap.Decode(raw) | ||||
| 	if err != nil { | ||||
|  | ||||
| @ -25,8 +25,8 @@ type Context interface { | ||||
| 	IsProtocolStart(p Type) bool | ||||
|  | ||||
| 	HandleInnerEAP(Payload, StateManager) (Payload, error) | ||||
| 	Inner(Payload, Type, func(p *radius.Packet) *radius.Packet) Context | ||||
| 	EndInnerProtocol(Status, func(p *radius.Packet) *radius.Packet) | ||||
| 	Inner(Payload, Type) Context | ||||
| 	EndInnerProtocol(Status) | ||||
|  | ||||
| 	Log() *log.Entry | ||||
| } | ||||
|  | ||||
| @ -31,7 +31,7 @@ func (p *Payload) Encode() ([]byte, error) { | ||||
|  | ||||
| func (p *Payload) Handle(ctx protocol.Context) protocol.Payload { | ||||
| 	if ctx.IsProtocolStart(TypeIdentity) { | ||||
| 		ctx.EndInnerProtocol(protocol.StatusNextProtocol, nil) | ||||
| 		ctx.EndInnerProtocol(protocol.StatusNextProtocol) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| @ -31,7 +31,7 @@ func (p *Payload) Encode() ([]byte, error) { | ||||
|  | ||||
| func (p *Payload) Handle(ctx protocol.Context) protocol.Payload { | ||||
| 	if ctx.IsProtocolStart(TypeLegacyNAK) { | ||||
| 		ctx.EndInnerProtocol(protocol.StatusError, nil) | ||||
| 		ctx.EndInnerProtocol(protocol.StatusError) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| @ -160,20 +160,29 @@ func (p *Payload) Handle(ctx protocol.Context) protocol.Payload { | ||||
| 		p.st.IsProtocolEnded = true | ||||
| 		return ep | ||||
| 	} else if p.st.IsProtocolEnded { | ||||
| 		ctx.EndInnerProtocol(protocol.StatusSuccess, func(r *radius.Packet) *radius.Packet { | ||||
| 			if len(microsoft.MSMPPERecvKey_Get(r, ctx.Packet().Packet)) < 1 { | ||||
| 				microsoft.MSMPPERecvKey_Set(r, p.st.AuthResponse.RecvKey) | ||||
| 			} | ||||
| 			if len(microsoft.MSMPPESendKey_Get(r, ctx.Packet().Packet)) < 1 { | ||||
| 				microsoft.MSMPPESendKey_Set(r, p.st.AuthResponse.SendKey) | ||||
| 			} | ||||
| 			return r | ||||
| 		}) | ||||
| 		ctx.EndInnerProtocol(protocol.StatusSuccess) | ||||
| 		return &Payload{} | ||||
| 	} | ||||
| 	return response | ||||
| } | ||||
|  | ||||
| func (p *Payload) ModifyRADIUSResponse(r *radius.Packet, q *radius.Packet) error { | ||||
| 	if p.st == nil || p.st.AuthResponse == nil { | ||||
| 		return nil | ||||
| 	} | ||||
| 	if r.Code != radius.CodeAccessAccept { | ||||
| 		return nil | ||||
| 	} | ||||
| 	log.Debug("MSCHAPv2: Radius modifier") | ||||
| 	if len(microsoft.MSMPPERecvKey_Get(r, q)) < 1 { | ||||
| 		microsoft.MSMPPERecvKey_Set(r, p.st.AuthResponse.RecvKey) | ||||
| 	} | ||||
| 	if len(microsoft.MSMPPESendKey_Get(r, q)) < 1 { | ||||
| 		microsoft.MSMPPESendKey_Set(r, p.st.AuthResponse.SendKey) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (p *Payload) Offerable() bool { | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| @ -1,5 +1,18 @@ | ||||
| package protocol | ||||
|  | ||||
| import "layeh.com/radius" | ||||
|  | ||||
| type Type uint8 | ||||
|  | ||||
| type Code uint8 | ||||
|  | ||||
| const ( | ||||
| 	CodeRequest  Code = 1 | ||||
| 	CodeResponse Code = 2 | ||||
| 	CodeSuccess  Code = 3 | ||||
| 	CodeFailure  Code = 4 | ||||
| ) | ||||
|  | ||||
| type Payload interface { | ||||
| 	Decode(raw []byte) error | ||||
| 	Encode() ([]byte, error) | ||||
| @ -13,13 +26,6 @@ type Inner interface { | ||||
| 	HasInner() Payload | ||||
| } | ||||
|  | ||||
| type Type uint8 | ||||
|  | ||||
| type Code uint8 | ||||
|  | ||||
| const ( | ||||
| 	CodeRequest  Code = 1 | ||||
| 	CodeResponse Code = 2 | ||||
| 	CodeSuccess  Code = 3 | ||||
| 	CodeFailure  Code = 4 | ||||
| ) | ||||
| type ResponseModifier interface { | ||||
| 	ModifyRADIUSResponse(r *radius.Packet, q *radius.Packet) error | ||||
| } | ||||
|  | ||||
| @ -2,8 +2,6 @@ package tls | ||||
|  | ||||
| import ( | ||||
| 	"goauthentik.io/internal/outpost/radius/eap/protocol" | ||||
| 	"layeh.com/radius" | ||||
| 	"layeh.com/radius/vendors/microsoft" | ||||
| ) | ||||
|  | ||||
| func (p *Payload) innerHandler(ctx protocol.Context) { | ||||
| @ -13,7 +11,7 @@ func (p *Payload) innerHandler(ctx protocol.Context) { | ||||
| 		n, err := p.st.TLS.Read(d) | ||||
| 		if err != nil { | ||||
| 			ctx.Log().WithError(err).Warning("TLS: Failed to read from TLS connection") | ||||
| 			ctx.EndInnerProtocol(protocol.StatusError, nil) | ||||
| 			ctx.EndInnerProtocol(protocol.StatusError) | ||||
| 			return | ||||
| 		} | ||||
| 		// Truncate data to the size we read | ||||
| @ -22,25 +20,20 @@ func (p *Payload) innerHandler(ctx protocol.Context) { | ||||
| 	err := p.Inner.Decode(d) | ||||
| 	if err != nil { | ||||
| 		ctx.Log().WithError(err).Warning("TLS: failed to decode inner protocol") | ||||
| 		ctx.EndInnerProtocol(protocol.StatusError, nil) | ||||
| 		ctx.EndInnerProtocol(protocol.StatusError) | ||||
| 		return | ||||
| 	} | ||||
| 	pl := p.Inner.Handle(ctx.Inner(p.Inner, p.Inner.Type(), func(r *radius.Packet) *radius.Packet { | ||||
| 		ctx.Log().Debug("TLS: Adding MPPE Keys") | ||||
| 		microsoft.MSMPPERecvKey_Set(r, p.st.MPPEKey[:32]) | ||||
| 		microsoft.MSMPPESendKey_Set(r, p.st.MPPEKey[64:64+32]) | ||||
| 		return r | ||||
| 	})) | ||||
| 	pl := p.Inner.Handle(ctx.Inner(p.Inner, p.Inner.Type())) | ||||
| 	enc, err := pl.Encode() | ||||
| 	if err != nil { | ||||
| 		ctx.Log().WithError(err).Warning("TLS: failed to encode inner protocol") | ||||
| 		ctx.EndInnerProtocol(protocol.StatusError, nil) | ||||
| 		ctx.EndInnerProtocol(protocol.StatusError) | ||||
| 		return | ||||
| 	} | ||||
| 	_, err = p.st.TLS.Write(enc) | ||||
| 	if err != nil { | ||||
| 		ctx.Log().WithError(err).Warning("TLS: failed to write to TLS") | ||||
| 		ctx.EndInnerProtocol(protocol.StatusError, nil) | ||||
| 		ctx.EndInnerProtocol(protocol.StatusError) | ||||
| 		return | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -144,17 +144,32 @@ func (p *Payload) Handle(ctx protocol.Context) protocol.Payload { | ||||
| 			retry.MaxDelay(100*time.Millisecond), | ||||
| 			retry.Attempts(0), | ||||
| 		) | ||||
| 		ctx.EndInnerProtocol(pst, func(r *radius.Packet) *radius.Packet { | ||||
| 			ctx.Log().Debug("TLS: Adding MPPE Keys") | ||||
| 			microsoft.MSMPPERecvKey_Set(r, p.st.MPPEKey[:32]) | ||||
| 			microsoft.MSMPPESendKey_Set(r, p.st.MPPEKey[64:64+32]) | ||||
| 			return r | ||||
| 		}) | ||||
| 		ctx.EndInnerProtocol(pst) | ||||
| 		return nil | ||||
| 	} | ||||
| 	return p.startChunkedTransfer(p.st.Conn.OutboundData()) | ||||
| } | ||||
|  | ||||
| func (p *Payload) ModifyRADIUSResponse(r *radius.Packet, q *radius.Packet) error { | ||||
| 	if r.Code != radius.CodeAccessAccept { | ||||
| 		return nil | ||||
| 	} | ||||
| 	if p.st == nil || !p.st.HandshakeDone { | ||||
| 		return nil | ||||
| 	} | ||||
| 	log.Debug("TLS: Adding MPPE Keys") | ||||
| 	// TLS overrides other protocols' MPPE keys | ||||
| 	if len(microsoft.MSMPPERecvKey_Get(r, q)) > 0 { | ||||
| 		microsoft.MSMPPERecvKey_Del(r) | ||||
| 	} | ||||
| 	if len(microsoft.MSMPPESendKey_Get(r, q)) > 0 { | ||||
| 		microsoft.MSMPPESendKey_Del(r) | ||||
| 	} | ||||
| 	microsoft.MSMPPERecvKey_Set(r, p.st.MPPEKey[:32]) | ||||
| 	microsoft.MSMPPESendKey_Set(r, p.st.MPPEKey[64:64+32]) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (p *Payload) tlsInit(ctx protocol.Context) { | ||||
| 	ctx.Log().Debug("TLS: no TLS connection in state yet, starting connection") | ||||
| 	p.st.Context, p.st.ContextCancel = context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second) | ||||
| @ -181,9 +196,7 @@ func (p *Payload) tlsInit(ctx protocol.Context) { | ||||
| 		if err != nil { | ||||
| 			ctx.Log().WithError(err).Debug("TLS: Handshake error") | ||||
| 			p.st.FinalStatus = protocol.StatusError | ||||
| 			ctx.EndInnerProtocol(protocol.StatusError, func(p *radius.Packet) *radius.Packet { | ||||
| 				return p | ||||
| 			}) | ||||
| 			ctx.EndInnerProtocol(protocol.StatusError) | ||||
| 			return | ||||
| 		} | ||||
| 		ctx.Log().Debug("TLS: handshake done") | ||||
|  | ||||
| @ -153,10 +153,7 @@ func (pi *ProviderInstance) GetEAPSettings() protocol.Settings { | ||||
|  | ||||
| 	return protocol.Settings{ | ||||
| 		Protocols: append(protocols, tls.Protocol, peap.Protocol), | ||||
| 		ProtocolPriority: []protocol.Type{ | ||||
| 			tls.TypeTLS, | ||||
| 			peap.TypePEAP, | ||||
| 		}, | ||||
| 		ProtocolPriority: []protocol.Type{tls.TypeTLS, peap.TypePEAP}, | ||||
| 		ProtocolSettings: map[protocol.Type]interface{}{ | ||||
| 			tls.TypeTLS: tls.Settings{ | ||||
| 				Config: &ttls.Config{ | ||||
| @ -197,6 +194,13 @@ func (pi *ProviderInstance) GetEAPSettings() protocol.Settings { | ||||
| 				InnerProtocols: protocol.Settings{ | ||||
| 					Protocols:        append(protocols, mschapv2.Protocol), | ||||
| 					ProtocolPriority: []protocol.Type{mschapv2.TypeMSCHAPv2}, | ||||
| 					ProtocolSettings: map[protocol.Type]interface{}{ | ||||
| 						mschapv2.TypeMSCHAPv2: mschapv2.Settings{ | ||||
| 							AuthenticateRequest: mschapv2.DebugStaticCredentials( | ||||
| 								[]byte("foo"), []byte("bar"), | ||||
| 							), | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	 Jens Langhammer
					Jens Langhammer