more refactor

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer
2025-05-23 17:59:00 +02:00
parent 4571f5e644
commit 8da54d5811
12 changed files with 72 additions and 65 deletions

View File

@ -9,7 +9,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/ak" "goauthentik.io/internal/outpost/ak"
"goauthentik.io/internal/outpost/radius/eap" "goauthentik.io/internal/outpost/radius/eap/protocol"
) )
func parseCIDRs(raw string) []*net.IPNet { func parseCIDRs(raw string) []*net.IPNet {
@ -45,7 +45,7 @@ func (rs *RadiusServer) Refresh() error {
providers := make(map[int32]*ProviderInstance) providers := make(map[int32]*ProviderInstance)
for _, provider := range apiProviders { for _, provider := range apiProviders {
existing, ok := rs.providers[provider.Pk] existing, ok := rs.providers[provider.Pk]
state := map[string]*eap.State{} state := map[string]*protocol.State{}
if ok { if ok {
state = existing.eapState state = existing.eapState
} }

View File

@ -16,13 +16,16 @@ type context struct {
endModifier func(p *radius.Packet) *radius.Packet endModifier func(p *radius.Packet) *radius.Packet
} }
func (ctx *context) RootPayload() protocol.Payload { return ctx.rootPayload } func (ctx *context) RootPayload() protocol.Payload { return ctx.rootPayload }
func (ctx *context) Packet() *radius.Request { return ctx.req } func (ctx *context) Packet() *radius.Request { return ctx.req }
func (ctx *context) ProtocolSettings() interface{} { return ctx.settings } func (ctx *context) ProtocolSettings() any { return ctx.settings }
func (ctx *context) GetProtocolState(p protocol.Type) interface{} { return ctx.typeState[p] } func (ctx *context) GetProtocolState(p protocol.Type) any { return ctx.typeState[p] }
func (ctx *context) SetProtocolState(p protocol.Type, st interface{}) { ctx.typeState[p] = st } func (ctx *context) SetProtocolState(p protocol.Type, st any) { ctx.typeState[p] = st }
func (ctx *context) IsProtocolStart(p protocol.Type) bool { return ctx.typeState[p] == nil } func (ctx *context) IsProtocolStart(p protocol.Type) bool { return ctx.typeState[p] == nil }
func (ctx *context) Log() *log.Entry { return ctx.log } func (ctx *context) Log() *log.Entry { return ctx.log }
func (ctx *context) HandleInnerEAP(protocol.Payload) protocol.Payload {
return nil
}
func (ctx *context) ForInnerProtocol(p protocol.Type) protocol.Context { func (ctx *context) ForInnerProtocol(p protocol.Type) protocol.Context {
return &context{ return &context{

View File

@ -77,7 +77,7 @@ func (p *Packet) handleInner(r *radius.Request) (*eap.Payload, error) {
st := p.stm.GetEAPState(p.state) st := p.stm.GetEAPState(p.state)
if st == nil { if st == nil {
log.Debug("Root-EAP: blank state") log.Debug("Root-EAP: blank state")
st = BlankState(p.stm.GetEAPSettings()) st = protocol.BlankState(p.stm.GetEAPSettings())
} }
nextChallengeToOffer, err := st.GetNextProtocol() nextChallengeToOffer, err := st.GetNextProtocol()

View File

@ -10,12 +10,12 @@ import (
type Packet struct { type Packet struct {
eap *eap.Payload eap *eap.Payload
stm StateManager stm protocol.StateManager
state string state string
endModifier func(p *radius.Packet) *radius.Packet endModifier func(p *radius.Packet) *radius.Packet
} }
func emptyPayload(stm StateManager, t protocol.Type) (protocol.Payload, protocol.Type, error) { func emptyPayload(stm protocol.StateManager, t protocol.Type) (protocol.Payload, protocol.Type, error) {
for _, cons := range stm.GetEAPSettings().Protocols { for _, cons := range stm.GetEAPSettings().Protocols {
np := cons() np := cons()
if np.Type() == t { if np.Type() == t {
@ -31,7 +31,7 @@ func emptyPayload(stm StateManager, t protocol.Type) (protocol.Payload, protocol
return nil, protocol.Type(0), fmt.Errorf("unsupported EAP type %d", t) return nil, protocol.Type(0), fmt.Errorf("unsupported EAP type %d", t)
} }
func Decode(stm StateManager, raw []byte) (*Packet, error) { func Decode(stm protocol.StateManager, raw []byte) (*Packet, error) {
packet := &Packet{ packet := &Packet{
eap: &eap.Payload{}, eap: &eap.Payload{},
stm: stm, stm: stm,

View File

@ -14,6 +14,10 @@ const (
StatusNextProtocol StatusNextProtocol
) )
type StateProtocol interface {
Payload
}
type Context interface { type Context interface {
Packet() *radius.Request Packet() *radius.Request
RootPayload() Payload RootPayload() Payload
@ -24,6 +28,7 @@ type Context interface {
SetProtocolState(p Type, s interface{}) SetProtocolState(p Type, s interface{})
IsProtocolStart(p Type) bool IsProtocolStart(p Type) bool
HandleInnerEAP(Payload) Payload
EndInnerProtocol(Status, func(p *radius.Packet) *radius.Packet) EndInnerProtocol(Status, func(p *radius.Packet) *radius.Packet)
Log() *log.Entry Log() *log.Entry

View File

@ -24,30 +24,30 @@ type Payload struct {
RawPayload []byte RawPayload []byte
} }
func (ip *Payload) Type() protocol.Type { func (p *Payload) Type() protocol.Type {
return TypeEAP return TypeEAP
} }
func (ip *Payload) Offerable() bool { func (p *Payload) Offerable() bool {
return false return false
} }
func (packet *Payload) Decode(raw []byte) error { func (p *Payload) Decode(raw []byte) error {
packet.Code = protocol.Code(raw[0]) p.Code = protocol.Code(raw[0])
packet.ID = raw[1] p.ID = raw[1]
packet.Length = binary.BigEndian.Uint16(raw[2:]) p.Length = binary.BigEndian.Uint16(raw[2:])
if packet.Length != uint16(len(raw)) { if p.Length != uint16(len(raw)) {
return fmt.Errorf("mismatched packet length; got %d, expected %d", packet.Length, uint16(len(raw))) return fmt.Errorf("mismatched packet length; got %d, expected %d", p.Length, uint16(len(raw)))
} }
if len(raw) > 4 && (packet.Code == protocol.CodeRequest || packet.Code == protocol.CodeResponse) { if len(raw) > 4 && (p.Code == protocol.CodeRequest || p.Code == protocol.CodeResponse) {
packet.MsgType = protocol.Type(raw[4]) p.MsgType = protocol.Type(raw[4])
} }
log.WithField("raw", debug.FormatBytes(raw)).WithField("payload", fmt.Sprintf("%T", packet.Payload)).Trace("EAP: decode raw") log.WithField("raw", debug.FormatBytes(raw)).WithField("payload", fmt.Sprintf("%T", p.Payload)).Trace("EAP: decode raw")
packet.RawPayload = raw[5:] p.RawPayload = raw[5:]
if packet.Payload == nil { if p.Payload == nil {
return nil return nil
} }
err := packet.Payload.Decode(raw[5:]) err := p.Payload.Decode(raw[5:])
if err != nil { if err != nil {
return err return err
} }

View File

@ -12,26 +12,26 @@ type Payload struct {
Identity string Identity string
} }
func (ip *Payload) Type() protocol.Type { func (p *Payload) Type() protocol.Type {
return TypeIdentity return TypeIdentity
} }
func (ip *Payload) Decode(raw []byte) error { func (p *Payload) Decode(raw []byte) error {
ip.Identity = string(raw) p.Identity = string(raw)
return nil return nil
} }
func (ip *Payload) Encode() ([]byte, error) { func (p *Payload) Encode() ([]byte, error) {
return []byte{}, nil return []byte{}, nil
} }
func (ip *Payload) Handle(ctx protocol.Context) protocol.Payload { func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
if ctx.IsProtocolStart(TypeIdentity) { if ctx.IsProtocolStart(TypeIdentity) {
ctx.EndInnerProtocol(protocol.StatusNextProtocol, nil) ctx.EndInnerProtocol(protocol.StatusNextProtocol, nil)
} }
return nil return nil
} }
func (ip *Payload) Offerable() bool { func (p *Payload) Offerable() bool {
return false return false
} }

View File

@ -12,26 +12,26 @@ type Payload struct {
DesiredType protocol.Type DesiredType protocol.Type
} }
func (ln *Payload) Type() protocol.Type { func (p *Payload) Type() protocol.Type {
return TypeLegacyNAK return TypeLegacyNAK
} }
func (ln *Payload) Decode(raw []byte) error { func (p *Payload) Decode(raw []byte) error {
ln.DesiredType = protocol.Type(raw[0]) p.DesiredType = protocol.Type(raw[0])
return nil return nil
} }
func (ln *Payload) Encode() ([]byte, error) { func (p *Payload) Encode() ([]byte, error) {
return []byte{byte(ln.DesiredType)}, nil return []byte{byte(p.DesiredType)}, nil
} }
func (ln *Payload) Handle(ctx protocol.Context) protocol.Payload { func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
if ctx.IsProtocolStart(TypeLegacyNAK) { if ctx.IsProtocolStart(TypeLegacyNAK) {
ctx.EndInnerProtocol(protocol.StatusError, nil) ctx.EndInnerProtocol(protocol.StatusError, nil)
} }
return nil return nil
} }
func (ln *Payload) Offerable() bool { func (p *Payload) Offerable() bool {
return false return false
} }

View File

@ -96,6 +96,7 @@ func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
ID: rootEap.ID + 1, ID: rootEap.ID + 1,
} }
} }
return ep return ep
} }

View File

@ -1,36 +1,34 @@
package eap package protocol
import ( import (
"errors" "errors"
"slices" "slices"
"goauthentik.io/internal/outpost/radius/eap/protocol"
) )
type ProtocolConstructor func() protocol.Payload
type Settings struct {
Protocols []ProtocolConstructor
ProtocolPriority []protocol.Type
ProtocolSettings map[protocol.Type]interface{}
}
type StateManager interface { type StateManager interface {
GetEAPSettings() Settings GetEAPSettings() Settings
GetEAPState(string) *State GetEAPState(string) *State
SetEAPState(string, *State) SetEAPState(string, *State)
} }
type ProtocolConstructor func() Payload
type Settings struct {
Protocols []ProtocolConstructor
ProtocolPriority []Type
ProtocolSettings map[Type]interface{}
}
type State struct { type State struct {
Protocols []ProtocolConstructor Protocols []ProtocolConstructor
ProtocolIndex int ProtocolIndex int
ProtocolPriority []protocol.Type ProtocolPriority []Type
TypeState map[protocol.Type]any TypeState map[Type]any
} }
func (st *State) GetNextProtocol() (protocol.Type, error) { func (st *State) GetNextProtocol() (Type, error) {
if st.ProtocolIndex >= len(st.ProtocolPriority) { if st.ProtocolIndex >= len(st.ProtocolPriority) {
return protocol.Type(0), errors.New("no more protocols to offer") return Type(0), errors.New("no more protocols to offer")
} }
return st.ProtocolPriority[st.ProtocolIndex], nil return st.ProtocolPriority[st.ProtocolIndex], nil
} }
@ -39,6 +37,6 @@ func BlankState(settings Settings) *State {
return &State{ return &State{
Protocols: slices.Clone(settings.Protocols), Protocols: slices.Clone(settings.Protocols),
ProtocolPriority: slices.Clone(settings.ProtocolPriority), ProtocolPriority: slices.Clone(settings.ProtocolPriority),
TypeState: map[protocol.Type]any{}, TypeState: map[Type]any{},
} }
} }

View File

@ -122,35 +122,35 @@ func (rs *RadiusServer) Handle_AccessRequest_EAP(w radius.ResponseWriter, r *Rad
ep.HandleRadiusPacket(w, r.Request) ep.HandleRadiusPacket(w, r.Request)
} }
func (pi *ProviderInstance) GetEAPState(key string) *eap.State { func (pi *ProviderInstance) GetEAPState(key string) *protocol.State {
return pi.eapState[key] return pi.eapState[key]
} }
func (pi *ProviderInstance) SetEAPState(key string, state *eap.State) { func (pi *ProviderInstance) SetEAPState(key string, state *protocol.State) {
pi.eapState[key] = state pi.eapState[key] = state
} }
func (pi *ProviderInstance) GetEAPSettings() eap.Settings { func (pi *ProviderInstance) GetEAPSettings() protocol.Settings {
protocols := []eap.ProtocolConstructor{ protocols := []protocol.ProtocolConstructor{
identity.Protocol, identity.Protocol,
legacy_nak.Protocol, legacy_nak.Protocol,
} }
certId := pi.certId certId := pi.certId
if certId == "" { if certId == "" {
return eap.Settings{ return protocol.Settings{
Protocols: protocols, Protocols: protocols,
} }
} }
cert := pi.s.cryptoStore.Get(certId) cert := pi.s.cryptoStore.Get(certId)
if cert == nil { if cert == nil {
return eap.Settings{ return protocol.Settings{
Protocols: protocols, Protocols: protocols,
} }
} }
return eap.Settings{ return protocol.Settings{
Protocols: append(protocols, tls.Protocol, peap.Protocol), Protocols: append(protocols, tls.Protocol, peap.Protocol),
ProtocolPriority: []protocol.Type{ ProtocolPriority: []protocol.Type{
tls.TypeTLS, tls.TypeTLS,

View File

@ -9,7 +9,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"goauthentik.io/internal/config" "goauthentik.io/internal/config"
"goauthentik.io/internal/outpost/ak" "goauthentik.io/internal/outpost/ak"
"goauthentik.io/internal/outpost/radius/eap" "goauthentik.io/internal/outpost/radius/eap/protocol"
"goauthentik.io/internal/outpost/radius/metrics" "goauthentik.io/internal/outpost/radius/metrics"
"layeh.com/radius" "layeh.com/radius"
@ -26,7 +26,7 @@ type ProviderInstance struct {
certId string certId string
s *RadiusServer s *RadiusServer
log *log.Entry log *log.Entry
eapState map[string]*eap.State eapState map[string]*protocol.State
} }
type RadiusServer struct { type RadiusServer struct {