diff --git a/internal/outpost/radius/api.go b/internal/outpost/radius/api.go index 92fcdf40e2..06a49deec9 100644 --- a/internal/outpost/radius/api.go +++ b/internal/outpost/radius/api.go @@ -9,7 +9,7 @@ import ( log "github.com/sirupsen/logrus" "goauthentik.io/internal/outpost/ak" - "goauthentik.io/internal/outpost/radius/eap" + "goauthentik.io/internal/outpost/radius/eap/protocol" ) func parseCIDRs(raw string) []*net.IPNet { @@ -45,7 +45,7 @@ func (rs *RadiusServer) Refresh() error { providers := make(map[int32]*ProviderInstance) for _, provider := range apiProviders { existing, ok := rs.providers[provider.Pk] - state := map[string]*eap.State{} + state := map[string]*protocol.State{} if ok { state = existing.eapState } diff --git a/internal/outpost/radius/eap/context.go b/internal/outpost/radius/eap/context.go index 883fb7c816..4a2804ea81 100644 --- a/internal/outpost/radius/eap/context.go +++ b/internal/outpost/radius/eap/context.go @@ -16,13 +16,16 @@ type context struct { endModifier func(p *radius.Packet) *radius.Packet } -func (ctx *context) RootPayload() protocol.Payload { return ctx.rootPayload } -func (ctx *context) Packet() *radius.Request { return ctx.req } -func (ctx *context) ProtocolSettings() interface{} { return ctx.settings } -func (ctx *context) GetProtocolState(p protocol.Type) interface{} { return ctx.typeState[p] } -func (ctx *context) SetProtocolState(p protocol.Type, st interface{}) { ctx.typeState[p] = st } -func (ctx *context) IsProtocolStart(p protocol.Type) bool { return ctx.typeState[p] == nil } -func (ctx *context) Log() *log.Entry { return ctx.log } +func (ctx *context) RootPayload() protocol.Payload { return ctx.rootPayload } +func (ctx *context) Packet() *radius.Request { return ctx.req } +func (ctx *context) ProtocolSettings() any { return ctx.settings } +func (ctx *context) GetProtocolState(p protocol.Type) any { return ctx.typeState[p] } +func (ctx *context) SetProtocolState(p protocol.Type, st any) { ctx.typeState[p] = st } +func (ctx *context) IsProtocolStart(p protocol.Type) bool { return ctx.typeState[p] == nil } +func (ctx *context) Log() *log.Entry { return ctx.log } +func (ctx *context) HandleInnerEAP(protocol.Payload) protocol.Payload { + return nil +} func (ctx *context) ForInnerProtocol(p protocol.Type) protocol.Context { return &context{ diff --git a/internal/outpost/radius/eap/handler.go b/internal/outpost/radius/eap/handler.go index 73cbca9998..55bc53099c 100644 --- a/internal/outpost/radius/eap/handler.go +++ b/internal/outpost/radius/eap/handler.go @@ -77,7 +77,7 @@ func (p *Packet) handleInner(r *radius.Request) (*eap.Payload, error) { st := p.stm.GetEAPState(p.state) if st == nil { log.Debug("Root-EAP: blank state") - st = BlankState(p.stm.GetEAPSettings()) + st = protocol.BlankState(p.stm.GetEAPSettings()) } nextChallengeToOffer, err := st.GetNextProtocol() diff --git a/internal/outpost/radius/eap/packet.go b/internal/outpost/radius/eap/packet.go index 7aec9abb0a..e249c8a2c3 100644 --- a/internal/outpost/radius/eap/packet.go +++ b/internal/outpost/radius/eap/packet.go @@ -10,12 +10,12 @@ import ( type Packet struct { eap *eap.Payload - stm StateManager + stm protocol.StateManager state string endModifier func(p *radius.Packet) *radius.Packet } -func emptyPayload(stm StateManager, t protocol.Type) (protocol.Payload, protocol.Type, error) { +func emptyPayload(stm protocol.StateManager, t protocol.Type) (protocol.Payload, protocol.Type, error) { for _, cons := range stm.GetEAPSettings().Protocols { np := cons() if np.Type() == t { @@ -31,7 +31,7 @@ func emptyPayload(stm StateManager, t protocol.Type) (protocol.Payload, protocol return nil, protocol.Type(0), fmt.Errorf("unsupported EAP type %d", t) } -func Decode(stm StateManager, raw []byte) (*Packet, error) { +func Decode(stm protocol.StateManager, raw []byte) (*Packet, error) { packet := &Packet{ eap: &eap.Payload{}, stm: stm, diff --git a/internal/outpost/radius/eap/protocol/context.go b/internal/outpost/radius/eap/protocol/context.go index 151565289a..a0e3acb206 100644 --- a/internal/outpost/radius/eap/protocol/context.go +++ b/internal/outpost/radius/eap/protocol/context.go @@ -14,6 +14,10 @@ const ( StatusNextProtocol ) +type StateProtocol interface { + Payload +} + type Context interface { Packet() *radius.Request RootPayload() Payload @@ -24,6 +28,7 @@ type Context interface { SetProtocolState(p Type, s interface{}) IsProtocolStart(p Type) bool + HandleInnerEAP(Payload) Payload EndInnerProtocol(Status, func(p *radius.Packet) *radius.Packet) Log() *log.Entry diff --git a/internal/outpost/radius/eap/protocol/eap/payload.go b/internal/outpost/radius/eap/protocol/eap/payload.go index f7e9096f66..6a519276dd 100644 --- a/internal/outpost/radius/eap/protocol/eap/payload.go +++ b/internal/outpost/radius/eap/protocol/eap/payload.go @@ -24,30 +24,30 @@ type Payload struct { RawPayload []byte } -func (ip *Payload) Type() protocol.Type { +func (p *Payload) Type() protocol.Type { return TypeEAP } -func (ip *Payload) Offerable() bool { +func (p *Payload) Offerable() bool { return false } -func (packet *Payload) Decode(raw []byte) error { - packet.Code = protocol.Code(raw[0]) - packet.ID = raw[1] - packet.Length = binary.BigEndian.Uint16(raw[2:]) - if packet.Length != uint16(len(raw)) { - return fmt.Errorf("mismatched packet length; got %d, expected %d", packet.Length, uint16(len(raw))) +func (p *Payload) Decode(raw []byte) error { + p.Code = protocol.Code(raw[0]) + p.ID = raw[1] + p.Length = binary.BigEndian.Uint16(raw[2:]) + if p.Length != uint16(len(raw)) { + return fmt.Errorf("mismatched packet length; got %d, expected %d", p.Length, uint16(len(raw))) } - if len(raw) > 4 && (packet.Code == protocol.CodeRequest || packet.Code == protocol.CodeResponse) { - packet.MsgType = protocol.Type(raw[4]) + if len(raw) > 4 && (p.Code == protocol.CodeRequest || p.Code == protocol.CodeResponse) { + p.MsgType = protocol.Type(raw[4]) } - log.WithField("raw", debug.FormatBytes(raw)).WithField("payload", fmt.Sprintf("%T", packet.Payload)).Trace("EAP: decode raw") - packet.RawPayload = raw[5:] - if packet.Payload == nil { + log.WithField("raw", debug.FormatBytes(raw)).WithField("payload", fmt.Sprintf("%T", p.Payload)).Trace("EAP: decode raw") + p.RawPayload = raw[5:] + if p.Payload == nil { return nil } - err := packet.Payload.Decode(raw[5:]) + err := p.Payload.Decode(raw[5:]) if err != nil { return err } diff --git a/internal/outpost/radius/eap/protocol/identity/payload.go b/internal/outpost/radius/eap/protocol/identity/payload.go index ff06c59f51..3d748eed8a 100644 --- a/internal/outpost/radius/eap/protocol/identity/payload.go +++ b/internal/outpost/radius/eap/protocol/identity/payload.go @@ -12,26 +12,26 @@ type Payload struct { Identity string } -func (ip *Payload) Type() protocol.Type { +func (p *Payload) Type() protocol.Type { return TypeIdentity } -func (ip *Payload) Decode(raw []byte) error { - ip.Identity = string(raw) +func (p *Payload) Decode(raw []byte) error { + p.Identity = string(raw) return nil } -func (ip *Payload) Encode() ([]byte, error) { +func (p *Payload) Encode() ([]byte, error) { return []byte{}, nil } -func (ip *Payload) Handle(ctx protocol.Context) protocol.Payload { +func (p *Payload) Handle(ctx protocol.Context) protocol.Payload { if ctx.IsProtocolStart(TypeIdentity) { ctx.EndInnerProtocol(protocol.StatusNextProtocol, nil) } return nil } -func (ip *Payload) Offerable() bool { +func (p *Payload) Offerable() bool { return false } diff --git a/internal/outpost/radius/eap/protocol/legacy_nak/payload.go b/internal/outpost/radius/eap/protocol/legacy_nak/payload.go index b7ef8a366d..662d8d3239 100644 --- a/internal/outpost/radius/eap/protocol/legacy_nak/payload.go +++ b/internal/outpost/radius/eap/protocol/legacy_nak/payload.go @@ -12,26 +12,26 @@ type Payload struct { DesiredType protocol.Type } -func (ln *Payload) Type() protocol.Type { +func (p *Payload) Type() protocol.Type { return TypeLegacyNAK } -func (ln *Payload) Decode(raw []byte) error { - ln.DesiredType = protocol.Type(raw[0]) +func (p *Payload) Decode(raw []byte) error { + p.DesiredType = protocol.Type(raw[0]) return nil } -func (ln *Payload) Encode() ([]byte, error) { - return []byte{byte(ln.DesiredType)}, nil +func (p *Payload) Encode() ([]byte, error) { + return []byte{byte(p.DesiredType)}, nil } -func (ln *Payload) Handle(ctx protocol.Context) protocol.Payload { +func (p *Payload) Handle(ctx protocol.Context) protocol.Payload { if ctx.IsProtocolStart(TypeLegacyNAK) { ctx.EndInnerProtocol(protocol.StatusError, nil) } return nil } -func (ln *Payload) Offerable() bool { +func (p *Payload) Offerable() bool { return false } diff --git a/internal/outpost/radius/eap/protocol/peap/payload.go b/internal/outpost/radius/eap/protocol/peap/payload.go index 358b181d4d..4f7639e481 100644 --- a/internal/outpost/radius/eap/protocol/peap/payload.go +++ b/internal/outpost/radius/eap/protocol/peap/payload.go @@ -96,6 +96,7 @@ func (p *Payload) Handle(ctx protocol.Context) protocol.Payload { ID: rootEap.ID + 1, } } + return ep } diff --git a/internal/outpost/radius/eap/state.go b/internal/outpost/radius/eap/protocol/state.go similarity index 56% rename from internal/outpost/radius/eap/state.go rename to internal/outpost/radius/eap/protocol/state.go index ed8793ba32..b50bf227c0 100644 --- a/internal/outpost/radius/eap/state.go +++ b/internal/outpost/radius/eap/protocol/state.go @@ -1,36 +1,34 @@ -package eap +package protocol import ( "errors" "slices" - - "goauthentik.io/internal/outpost/radius/eap/protocol" ) -type ProtocolConstructor func() protocol.Payload - -type Settings struct { - Protocols []ProtocolConstructor - ProtocolPriority []protocol.Type - ProtocolSettings map[protocol.Type]interface{} -} - type StateManager interface { GetEAPSettings() Settings GetEAPState(string) *State SetEAPState(string, *State) } +type ProtocolConstructor func() Payload + +type Settings struct { + Protocols []ProtocolConstructor + ProtocolPriority []Type + ProtocolSettings map[Type]interface{} +} + type State struct { Protocols []ProtocolConstructor ProtocolIndex int - ProtocolPriority []protocol.Type - TypeState map[protocol.Type]any + ProtocolPriority []Type + TypeState map[Type]any } -func (st *State) GetNextProtocol() (protocol.Type, error) { +func (st *State) GetNextProtocol() (Type, error) { if st.ProtocolIndex >= len(st.ProtocolPriority) { - return protocol.Type(0), errors.New("no more protocols to offer") + return Type(0), errors.New("no more protocols to offer") } return st.ProtocolPriority[st.ProtocolIndex], nil } @@ -39,6 +37,6 @@ func BlankState(settings Settings) *State { return &State{ Protocols: slices.Clone(settings.Protocols), ProtocolPriority: slices.Clone(settings.ProtocolPriority), - TypeState: map[protocol.Type]any{}, + TypeState: map[Type]any{}, } } diff --git a/internal/outpost/radius/handle_access_request.go b/internal/outpost/radius/handle_access_request.go index 9a2a59dc9c..a999517d0d 100644 --- a/internal/outpost/radius/handle_access_request.go +++ b/internal/outpost/radius/handle_access_request.go @@ -122,35 +122,35 @@ func (rs *RadiusServer) Handle_AccessRequest_EAP(w radius.ResponseWriter, r *Rad ep.HandleRadiusPacket(w, r.Request) } -func (pi *ProviderInstance) GetEAPState(key string) *eap.State { +func (pi *ProviderInstance) GetEAPState(key string) *protocol.State { return pi.eapState[key] } -func (pi *ProviderInstance) SetEAPState(key string, state *eap.State) { +func (pi *ProviderInstance) SetEAPState(key string, state *protocol.State) { pi.eapState[key] = state } -func (pi *ProviderInstance) GetEAPSettings() eap.Settings { - protocols := []eap.ProtocolConstructor{ +func (pi *ProviderInstance) GetEAPSettings() protocol.Settings { + protocols := []protocol.ProtocolConstructor{ identity.Protocol, legacy_nak.Protocol, } certId := pi.certId if certId == "" { - return eap.Settings{ + return protocol.Settings{ Protocols: protocols, } } cert := pi.s.cryptoStore.Get(certId) if cert == nil { - return eap.Settings{ + return protocol.Settings{ Protocols: protocols, } } - return eap.Settings{ + return protocol.Settings{ Protocols: append(protocols, tls.Protocol, peap.Protocol), ProtocolPriority: []protocol.Type{ tls.TypeTLS, diff --git a/internal/outpost/radius/radius.go b/internal/outpost/radius/radius.go index 2a68d96bba..39aa309f82 100644 --- a/internal/outpost/radius/radius.go +++ b/internal/outpost/radius/radius.go @@ -9,7 +9,7 @@ import ( log "github.com/sirupsen/logrus" "goauthentik.io/internal/config" "goauthentik.io/internal/outpost/ak" - "goauthentik.io/internal/outpost/radius/eap" + "goauthentik.io/internal/outpost/radius/eap/protocol" "goauthentik.io/internal/outpost/radius/metrics" "layeh.com/radius" @@ -26,7 +26,7 @@ type ProviderInstance struct { certId string s *RadiusServer log *log.Entry - eapState map[string]*eap.State + eapState map[string]*protocol.State } type RadiusServer struct {