Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer
2025-05-16 14:32:47 +02:00
parent ac88784089
commit 318443f270
5 changed files with 38 additions and 31 deletions

View File

@ -1,29 +1,32 @@
package eap
import (
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/radius/eap/protocol"
"layeh.com/radius"
)
type context[TState any, TSettings any] struct {
type context struct {
state interface{}
log *log.Entry
}
func (ctx context[TState, TSettings]) ProtocolSettings() TSettings {
return 0
}
func (ctx context[TState, TSettings]) GetProtocolState(def func(context[TState, TSettings]) TState) TState {
func (ctx context) ProtocolSettings() interface{} {
return nil
}
func (ctx context[TState, TSettings]) SetProtocolState(TState) {
func (ctx context) GetProtocolState(def func(protocol.Context) interface{}) interface{} {
return ctx.state
}
func (ctx context) SetProtocolState(st interface{}) {
ctx.state = st
}
func (ctx context) EndInnerProtocol(func(p *radius.Packet) *radius.Packet) {
}
func (ctx context[TState, TSettings]) EndInnerProtocol(func(p *radius.Packet) *radius.Packet) {
}
func (ctx context[TState, TSettings]) Log() *logrus.Entry {
return nil
func (ctx context) Log() *log.Entry {
return ctx.log
}

View File

@ -29,10 +29,14 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac
}
nextChallengeToOffer := st.ChallengesToOffer[0]
ctx := context{}
ctx := context{
state: st.TypeState[nextChallengeToOffer],
log: log.WithField("type", nextChallengeToOffer),
}
res, newState := p.GetChallengeForType(ctx, nextChallengeToOffer)
stm.SetEAPState(rst, newState)
res := p.GetChallengeForType(ctx, nextChallengeToOffer)
st.TypeState[nextChallengeToOffer] = ctx.GetProtocolState(nil)
stm.SetEAPState(rst, st)
rres := r.Response(radius.CodeAccessChallenge)
if p, ok := res.Payload.(protocol.EmptyPayload); ok {
@ -55,7 +59,7 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac
}
}
func (p *Packet) GetChallengeForType(ctx context[any, any], t Type) *Packet {
func (p *Packet) GetChallengeForType(ctx context, t Type) *Packet {
res := &Packet{
code: CodeRequest,
id: p.id + 1,
@ -72,9 +76,9 @@ func (p *Packet) GetChallengeForType(ctx context[any, any], t Type) *Packet {
// this
payload = p.Payload.(*tls.Payload).Handle(ctx)
}
st.TypeState[t] = tst
// st.TypeState[t] = tst
res.Payload = payload.(protocol.Payload)
return res, st
return res
}
func (p *Packet) setMessageAuthenticator(rp *radius.Packet) {

View File

@ -5,12 +5,12 @@ import (
"layeh.com/radius"
)
type Context[TState any, TSettings any] interface {
type Context interface {
// GlobalState()
ProtocolSettings() TSettings
GetProtocolState(def func(Context[TState, TSettings]) TState) TState
SetProtocolState(TState)
ProtocolSettings() interface{}
GetProtocolState(def func(Context) interface{}) interface{}
SetProtocolState(interface{})
EndInnerProtocol(func(p *radius.Packet) *radius.Packet)

View File

@ -61,10 +61,8 @@ func (p *Payload) Encode() ([]byte, error) {
return buff, nil
}
type tctx = protocol.Context[*State, Settings]
func (p *Payload) Handle(ctx tctx) protocol.Payload {
p.st = ctx.GetProtocolState(NewState)
func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
p.st = ctx.GetProtocolState(NewState).(*State)
defer ctx.SetProtocolState(p.st)
if !p.st.HasStarted {
log.Debug("TLS: handshake starting")
@ -115,11 +113,11 @@ func (p *Payload) Handle(ctx tctx) protocol.Payload {
return p.startChunkedTransfer(p.st.Conn.OutboundData())
}
func (p *Payload) tlsInit(ctx tctx) {
func (p *Payload) tlsInit(ctx protocol.Context) {
log.Debug("TLS: no TLS connection in state yet, starting connection")
p.st.Context, p.st.ContextCancel = context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second)
p.st.Conn = NewBuffConn(p.Data, p.st.Context)
cfg := ctx.ProtocolSettings().Config.Clone()
cfg := ctx.ProtocolSettings().(Settings).Config.Clone()
cfg.GetConfigForClient = func(chi *tls.ClientHelloInfo) (*tls.Config, error) {
log.Debugf("TLS: ClientHello: %+v\n", chi)
p.st.ClientHello = chi

View File

@ -3,6 +3,8 @@ package tls
import (
"context"
"crypto/tls"
"goauthentik.io/internal/outpost/radius/eap/protocol"
)
type State struct {
@ -18,7 +20,7 @@ type State struct {
ContextCancel context.CancelFunc
}
func NewState(c tctx) *State {
func NewState(c protocol.Context) interface{} {
c.Log().Debug("TLS: new state")
return &State{
RemainingChunks: make([][]byte, 0),