Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer
2025-05-16 14:32:47 +02:00
parent ac88784089
commit 318443f270
5 changed files with 38 additions and 31 deletions

View File

@ -1,29 +1,32 @@
package eap package eap
import ( import (
"github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/radius/eap/protocol"
"layeh.com/radius" "layeh.com/radius"
) )
type context[TState any, TSettings any] struct { type context struct {
state interface{}
log *log.Entry
} }
func (ctx context[TState, TSettings]) ProtocolSettings() TSettings { func (ctx context) ProtocolSettings() interface{} {
return 0
}
func (ctx context[TState, TSettings]) GetProtocolState(def func(context[TState, TSettings]) TState) TState {
return nil return nil
} }
func (ctx context[TState, TSettings]) SetProtocolState(TState) { func (ctx context) GetProtocolState(def func(protocol.Context) interface{}) interface{} {
return ctx.state
}
func (ctx context) SetProtocolState(st interface{}) {
ctx.state = st
}
func (ctx context) EndInnerProtocol(func(p *radius.Packet) *radius.Packet) {
} }
func (ctx context[TState, TSettings]) EndInnerProtocol(func(p *radius.Packet) *radius.Packet) { func (ctx context) Log() *log.Entry {
return ctx.log
}
func (ctx context[TState, TSettings]) Log() *logrus.Entry {
return nil
} }

View File

@ -29,10 +29,14 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac
} }
nextChallengeToOffer := st.ChallengesToOffer[0] nextChallengeToOffer := st.ChallengesToOffer[0]
ctx := context{} ctx := context{
state: st.TypeState[nextChallengeToOffer],
log: log.WithField("type", nextChallengeToOffer),
}
res, newState := p.GetChallengeForType(ctx, nextChallengeToOffer) res := p.GetChallengeForType(ctx, nextChallengeToOffer)
stm.SetEAPState(rst, newState) st.TypeState[nextChallengeToOffer] = ctx.GetProtocolState(nil)
stm.SetEAPState(rst, st)
rres := r.Response(radius.CodeAccessChallenge) rres := r.Response(radius.CodeAccessChallenge)
if p, ok := res.Payload.(protocol.EmptyPayload); ok { if p, ok := res.Payload.(protocol.EmptyPayload); ok {
@ -55,7 +59,7 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac
} }
} }
func (p *Packet) GetChallengeForType(ctx context[any, any], t Type) *Packet { func (p *Packet) GetChallengeForType(ctx context, t Type) *Packet {
res := &Packet{ res := &Packet{
code: CodeRequest, code: CodeRequest,
id: p.id + 1, id: p.id + 1,
@ -72,9 +76,9 @@ func (p *Packet) GetChallengeForType(ctx context[any, any], t Type) *Packet {
// this // this
payload = p.Payload.(*tls.Payload).Handle(ctx) payload = p.Payload.(*tls.Payload).Handle(ctx)
} }
st.TypeState[t] = tst // st.TypeState[t] = tst
res.Payload = payload.(protocol.Payload) res.Payload = payload.(protocol.Payload)
return res, st return res
} }
func (p *Packet) setMessageAuthenticator(rp *radius.Packet) { func (p *Packet) setMessageAuthenticator(rp *radius.Packet) {

View File

@ -5,12 +5,12 @@ import (
"layeh.com/radius" "layeh.com/radius"
) )
type Context[TState any, TSettings any] interface { type Context interface {
// GlobalState() // GlobalState()
ProtocolSettings() TSettings ProtocolSettings() interface{}
GetProtocolState(def func(Context[TState, TSettings]) TState) TState GetProtocolState(def func(Context) interface{}) interface{}
SetProtocolState(TState) SetProtocolState(interface{})
EndInnerProtocol(func(p *radius.Packet) *radius.Packet) EndInnerProtocol(func(p *radius.Packet) *radius.Packet)

View File

@ -61,10 +61,8 @@ func (p *Payload) Encode() ([]byte, error) {
return buff, nil return buff, nil
} }
type tctx = protocol.Context[*State, Settings] func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
p.st = ctx.GetProtocolState(NewState).(*State)
func (p *Payload) Handle(ctx tctx) protocol.Payload {
p.st = ctx.GetProtocolState(NewState)
defer ctx.SetProtocolState(p.st) defer ctx.SetProtocolState(p.st)
if !p.st.HasStarted { if !p.st.HasStarted {
log.Debug("TLS: handshake starting") log.Debug("TLS: handshake starting")
@ -115,11 +113,11 @@ func (p *Payload) Handle(ctx tctx) protocol.Payload {
return p.startChunkedTransfer(p.st.Conn.OutboundData()) return p.startChunkedTransfer(p.st.Conn.OutboundData())
} }
func (p *Payload) tlsInit(ctx tctx) { func (p *Payload) tlsInit(ctx protocol.Context) {
log.Debug("TLS: no TLS connection in state yet, starting connection") log.Debug("TLS: no TLS connection in state yet, starting connection")
p.st.Context, p.st.ContextCancel = context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second) p.st.Context, p.st.ContextCancel = context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second)
p.st.Conn = NewBuffConn(p.Data, p.st.Context) p.st.Conn = NewBuffConn(p.Data, p.st.Context)
cfg := ctx.ProtocolSettings().Config.Clone() cfg := ctx.ProtocolSettings().(Settings).Config.Clone()
cfg.GetConfigForClient = func(chi *tls.ClientHelloInfo) (*tls.Config, error) { cfg.GetConfigForClient = func(chi *tls.ClientHelloInfo) (*tls.Config, error) {
log.Debugf("TLS: ClientHello: %+v\n", chi) log.Debugf("TLS: ClientHello: %+v\n", chi)
p.st.ClientHello = chi p.st.ClientHello = chi

View File

@ -3,6 +3,8 @@ package tls
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"goauthentik.io/internal/outpost/radius/eap/protocol"
) )
type State struct { type State struct {
@ -18,7 +20,7 @@ type State struct {
ContextCancel context.CancelFunc ContextCancel context.CancelFunc
} }
func NewState(c tctx) *State { func NewState(c protocol.Context) interface{} {
c.Log().Debug("TLS: new state") c.Log().Debug("TLS: new state")
return &State{ return &State{
RemainingChunks: make([][]byte, 0), RemainingChunks: make([][]byte, 0),