more refactor

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer
2025-05-23 17:59:00 +02:00
parent 4571f5e644
commit 8da54d5811
12 changed files with 72 additions and 65 deletions

View File

@ -9,7 +9,7 @@ import (
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/ak"
"goauthentik.io/internal/outpost/radius/eap"
"goauthentik.io/internal/outpost/radius/eap/protocol"
)
func parseCIDRs(raw string) []*net.IPNet {
@ -45,7 +45,7 @@ func (rs *RadiusServer) Refresh() error {
providers := make(map[int32]*ProviderInstance)
for _, provider := range apiProviders {
existing, ok := rs.providers[provider.Pk]
state := map[string]*eap.State{}
state := map[string]*protocol.State{}
if ok {
state = existing.eapState
}

View File

@ -16,13 +16,16 @@ type context struct {
endModifier func(p *radius.Packet) *radius.Packet
}
func (ctx *context) RootPayload() protocol.Payload { return ctx.rootPayload }
func (ctx *context) Packet() *radius.Request { return ctx.req }
func (ctx *context) ProtocolSettings() interface{} { return ctx.settings }
func (ctx *context) GetProtocolState(p protocol.Type) interface{} { return ctx.typeState[p] }
func (ctx *context) SetProtocolState(p protocol.Type, st interface{}) { ctx.typeState[p] = st }
func (ctx *context) IsProtocolStart(p protocol.Type) bool { return ctx.typeState[p] == nil }
func (ctx *context) Log() *log.Entry { return ctx.log }
func (ctx *context) RootPayload() protocol.Payload { return ctx.rootPayload }
func (ctx *context) Packet() *radius.Request { return ctx.req }
func (ctx *context) ProtocolSettings() any { return ctx.settings }
func (ctx *context) GetProtocolState(p protocol.Type) any { return ctx.typeState[p] }
func (ctx *context) SetProtocolState(p protocol.Type, st any) { ctx.typeState[p] = st }
func (ctx *context) IsProtocolStart(p protocol.Type) bool { return ctx.typeState[p] == nil }
func (ctx *context) Log() *log.Entry { return ctx.log }
func (ctx *context) HandleInnerEAP(protocol.Payload) protocol.Payload {
return nil
}
func (ctx *context) ForInnerProtocol(p protocol.Type) protocol.Context {
return &context{

View File

@ -77,7 +77,7 @@ func (p *Packet) handleInner(r *radius.Request) (*eap.Payload, error) {
st := p.stm.GetEAPState(p.state)
if st == nil {
log.Debug("Root-EAP: blank state")
st = BlankState(p.stm.GetEAPSettings())
st = protocol.BlankState(p.stm.GetEAPSettings())
}
nextChallengeToOffer, err := st.GetNextProtocol()

View File

@ -10,12 +10,12 @@ import (
type Packet struct {
eap *eap.Payload
stm StateManager
stm protocol.StateManager
state string
endModifier func(p *radius.Packet) *radius.Packet
}
func emptyPayload(stm StateManager, t protocol.Type) (protocol.Payload, protocol.Type, error) {
func emptyPayload(stm protocol.StateManager, t protocol.Type) (protocol.Payload, protocol.Type, error) {
for _, cons := range stm.GetEAPSettings().Protocols {
np := cons()
if np.Type() == t {
@ -31,7 +31,7 @@ func emptyPayload(stm StateManager, t protocol.Type) (protocol.Payload, protocol
return nil, protocol.Type(0), fmt.Errorf("unsupported EAP type %d", t)
}
func Decode(stm StateManager, raw []byte) (*Packet, error) {
func Decode(stm protocol.StateManager, raw []byte) (*Packet, error) {
packet := &Packet{
eap: &eap.Payload{},
stm: stm,

View File

@ -14,6 +14,10 @@ const (
StatusNextProtocol
)
type StateProtocol interface {
Payload
}
type Context interface {
Packet() *radius.Request
RootPayload() Payload
@ -24,6 +28,7 @@ type Context interface {
SetProtocolState(p Type, s interface{})
IsProtocolStart(p Type) bool
HandleInnerEAP(Payload) Payload
EndInnerProtocol(Status, func(p *radius.Packet) *radius.Packet)
Log() *log.Entry

View File

@ -24,30 +24,30 @@ type Payload struct {
RawPayload []byte
}
func (ip *Payload) Type() protocol.Type {
func (p *Payload) Type() protocol.Type {
return TypeEAP
}
func (ip *Payload) Offerable() bool {
func (p *Payload) Offerable() bool {
return false
}
func (packet *Payload) Decode(raw []byte) error {
packet.Code = protocol.Code(raw[0])
packet.ID = raw[1]
packet.Length = binary.BigEndian.Uint16(raw[2:])
if packet.Length != uint16(len(raw)) {
return fmt.Errorf("mismatched packet length; got %d, expected %d", packet.Length, uint16(len(raw)))
func (p *Payload) Decode(raw []byte) error {
p.Code = protocol.Code(raw[0])
p.ID = raw[1]
p.Length = binary.BigEndian.Uint16(raw[2:])
if p.Length != uint16(len(raw)) {
return fmt.Errorf("mismatched packet length; got %d, expected %d", p.Length, uint16(len(raw)))
}
if len(raw) > 4 && (packet.Code == protocol.CodeRequest || packet.Code == protocol.CodeResponse) {
packet.MsgType = protocol.Type(raw[4])
if len(raw) > 4 && (p.Code == protocol.CodeRequest || p.Code == protocol.CodeResponse) {
p.MsgType = protocol.Type(raw[4])
}
log.WithField("raw", debug.FormatBytes(raw)).WithField("payload", fmt.Sprintf("%T", packet.Payload)).Trace("EAP: decode raw")
packet.RawPayload = raw[5:]
if packet.Payload == nil {
log.WithField("raw", debug.FormatBytes(raw)).WithField("payload", fmt.Sprintf("%T", p.Payload)).Trace("EAP: decode raw")
p.RawPayload = raw[5:]
if p.Payload == nil {
return nil
}
err := packet.Payload.Decode(raw[5:])
err := p.Payload.Decode(raw[5:])
if err != nil {
return err
}

View File

@ -12,26 +12,26 @@ type Payload struct {
Identity string
}
func (ip *Payload) Type() protocol.Type {
func (p *Payload) Type() protocol.Type {
return TypeIdentity
}
func (ip *Payload) Decode(raw []byte) error {
ip.Identity = string(raw)
func (p *Payload) Decode(raw []byte) error {
p.Identity = string(raw)
return nil
}
func (ip *Payload) Encode() ([]byte, error) {
func (p *Payload) Encode() ([]byte, error) {
return []byte{}, nil
}
func (ip *Payload) Handle(ctx protocol.Context) protocol.Payload {
func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
if ctx.IsProtocolStart(TypeIdentity) {
ctx.EndInnerProtocol(protocol.StatusNextProtocol, nil)
}
return nil
}
func (ip *Payload) Offerable() bool {
func (p *Payload) Offerable() bool {
return false
}

View File

@ -12,26 +12,26 @@ type Payload struct {
DesiredType protocol.Type
}
func (ln *Payload) Type() protocol.Type {
func (p *Payload) Type() protocol.Type {
return TypeLegacyNAK
}
func (ln *Payload) Decode(raw []byte) error {
ln.DesiredType = protocol.Type(raw[0])
func (p *Payload) Decode(raw []byte) error {
p.DesiredType = protocol.Type(raw[0])
return nil
}
func (ln *Payload) Encode() ([]byte, error) {
return []byte{byte(ln.DesiredType)}, nil
func (p *Payload) Encode() ([]byte, error) {
return []byte{byte(p.DesiredType)}, nil
}
func (ln *Payload) Handle(ctx protocol.Context) protocol.Payload {
func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
if ctx.IsProtocolStart(TypeLegacyNAK) {
ctx.EndInnerProtocol(protocol.StatusError, nil)
}
return nil
}
func (ln *Payload) Offerable() bool {
func (p *Payload) Offerable() bool {
return false
}

View File

@ -96,6 +96,7 @@ func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
ID: rootEap.ID + 1,
}
}
return ep
}

View File

@ -1,36 +1,34 @@
package eap
package protocol
import (
"errors"
"slices"
"goauthentik.io/internal/outpost/radius/eap/protocol"
)
type ProtocolConstructor func() protocol.Payload
type Settings struct {
Protocols []ProtocolConstructor
ProtocolPriority []protocol.Type
ProtocolSettings map[protocol.Type]interface{}
}
type StateManager interface {
GetEAPSettings() Settings
GetEAPState(string) *State
SetEAPState(string, *State)
}
type ProtocolConstructor func() Payload
type Settings struct {
Protocols []ProtocolConstructor
ProtocolPriority []Type
ProtocolSettings map[Type]interface{}
}
type State struct {
Protocols []ProtocolConstructor
ProtocolIndex int
ProtocolPriority []protocol.Type
TypeState map[protocol.Type]any
ProtocolPriority []Type
TypeState map[Type]any
}
func (st *State) GetNextProtocol() (protocol.Type, error) {
func (st *State) GetNextProtocol() (Type, error) {
if st.ProtocolIndex >= len(st.ProtocolPriority) {
return protocol.Type(0), errors.New("no more protocols to offer")
return Type(0), errors.New("no more protocols to offer")
}
return st.ProtocolPriority[st.ProtocolIndex], nil
}
@ -39,6 +37,6 @@ func BlankState(settings Settings) *State {
return &State{
Protocols: slices.Clone(settings.Protocols),
ProtocolPriority: slices.Clone(settings.ProtocolPriority),
TypeState: map[protocol.Type]any{},
TypeState: map[Type]any{},
}
}

View File

@ -122,35 +122,35 @@ func (rs *RadiusServer) Handle_AccessRequest_EAP(w radius.ResponseWriter, r *Rad
ep.HandleRadiusPacket(w, r.Request)
}
func (pi *ProviderInstance) GetEAPState(key string) *eap.State {
func (pi *ProviderInstance) GetEAPState(key string) *protocol.State {
return pi.eapState[key]
}
func (pi *ProviderInstance) SetEAPState(key string, state *eap.State) {
func (pi *ProviderInstance) SetEAPState(key string, state *protocol.State) {
pi.eapState[key] = state
}
func (pi *ProviderInstance) GetEAPSettings() eap.Settings {
protocols := []eap.ProtocolConstructor{
func (pi *ProviderInstance) GetEAPSettings() protocol.Settings {
protocols := []protocol.ProtocolConstructor{
identity.Protocol,
legacy_nak.Protocol,
}
certId := pi.certId
if certId == "" {
return eap.Settings{
return protocol.Settings{
Protocols: protocols,
}
}
cert := pi.s.cryptoStore.Get(certId)
if cert == nil {
return eap.Settings{
return protocol.Settings{
Protocols: protocols,
}
}
return eap.Settings{
return protocol.Settings{
Protocols: append(protocols, tls.Protocol, peap.Protocol),
ProtocolPriority: []protocol.Type{
tls.TypeTLS,

View File

@ -9,7 +9,7 @@ import (
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/config"
"goauthentik.io/internal/outpost/ak"
"goauthentik.io/internal/outpost/radius/eap"
"goauthentik.io/internal/outpost/radius/eap/protocol"
"goauthentik.io/internal/outpost/radius/metrics"
"layeh.com/radius"
@ -26,7 +26,7 @@ type ProviderInstance struct {
certId string
s *RadiusServer
log *log.Entry
eapState map[string]*eap.State
eapState map[string]*protocol.State
}
type RadiusServer struct {