@ -9,7 +9,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/internal/outpost/ak"
|
||||
"goauthentik.io/internal/outpost/radius/eap"
|
||||
"goauthentik.io/internal/outpost/radius/eap/protocol"
|
||||
)
|
||||
|
||||
func parseCIDRs(raw string) []*net.IPNet {
|
||||
@ -45,7 +45,7 @@ func (rs *RadiusServer) Refresh() error {
|
||||
providers := make(map[int32]*ProviderInstance)
|
||||
for _, provider := range apiProviders {
|
||||
existing, ok := rs.providers[provider.Pk]
|
||||
state := map[string]*eap.State{}
|
||||
state := map[string]*protocol.State{}
|
||||
if ok {
|
||||
state = existing.eapState
|
||||
}
|
||||
|
@ -16,13 +16,16 @@ type context struct {
|
||||
endModifier func(p *radius.Packet) *radius.Packet
|
||||
}
|
||||
|
||||
func (ctx *context) RootPayload() protocol.Payload { return ctx.rootPayload }
|
||||
func (ctx *context) Packet() *radius.Request { return ctx.req }
|
||||
func (ctx *context) ProtocolSettings() interface{} { return ctx.settings }
|
||||
func (ctx *context) GetProtocolState(p protocol.Type) interface{} { return ctx.typeState[p] }
|
||||
func (ctx *context) SetProtocolState(p protocol.Type, st interface{}) { ctx.typeState[p] = st }
|
||||
func (ctx *context) IsProtocolStart(p protocol.Type) bool { return ctx.typeState[p] == nil }
|
||||
func (ctx *context) Log() *log.Entry { return ctx.log }
|
||||
func (ctx *context) RootPayload() protocol.Payload { return ctx.rootPayload }
|
||||
func (ctx *context) Packet() *radius.Request { return ctx.req }
|
||||
func (ctx *context) ProtocolSettings() any { return ctx.settings }
|
||||
func (ctx *context) GetProtocolState(p protocol.Type) any { return ctx.typeState[p] }
|
||||
func (ctx *context) SetProtocolState(p protocol.Type, st any) { ctx.typeState[p] = st }
|
||||
func (ctx *context) IsProtocolStart(p protocol.Type) bool { return ctx.typeState[p] == nil }
|
||||
func (ctx *context) Log() *log.Entry { return ctx.log }
|
||||
func (ctx *context) HandleInnerEAP(protocol.Payload) protocol.Payload {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ctx *context) ForInnerProtocol(p protocol.Type) protocol.Context {
|
||||
return &context{
|
||||
|
@ -77,7 +77,7 @@ func (p *Packet) handleInner(r *radius.Request) (*eap.Payload, error) {
|
||||
st := p.stm.GetEAPState(p.state)
|
||||
if st == nil {
|
||||
log.Debug("Root-EAP: blank state")
|
||||
st = BlankState(p.stm.GetEAPSettings())
|
||||
st = protocol.BlankState(p.stm.GetEAPSettings())
|
||||
}
|
||||
|
||||
nextChallengeToOffer, err := st.GetNextProtocol()
|
||||
|
@ -10,12 +10,12 @@ import (
|
||||
|
||||
type Packet struct {
|
||||
eap *eap.Payload
|
||||
stm StateManager
|
||||
stm protocol.StateManager
|
||||
state string
|
||||
endModifier func(p *radius.Packet) *radius.Packet
|
||||
}
|
||||
|
||||
func emptyPayload(stm StateManager, t protocol.Type) (protocol.Payload, protocol.Type, error) {
|
||||
func emptyPayload(stm protocol.StateManager, t protocol.Type) (protocol.Payload, protocol.Type, error) {
|
||||
for _, cons := range stm.GetEAPSettings().Protocols {
|
||||
np := cons()
|
||||
if np.Type() == t {
|
||||
@ -31,7 +31,7 @@ func emptyPayload(stm StateManager, t protocol.Type) (protocol.Payload, protocol
|
||||
return nil, protocol.Type(0), fmt.Errorf("unsupported EAP type %d", t)
|
||||
}
|
||||
|
||||
func Decode(stm StateManager, raw []byte) (*Packet, error) {
|
||||
func Decode(stm protocol.StateManager, raw []byte) (*Packet, error) {
|
||||
packet := &Packet{
|
||||
eap: &eap.Payload{},
|
||||
stm: stm,
|
||||
|
@ -14,6 +14,10 @@ const (
|
||||
StatusNextProtocol
|
||||
)
|
||||
|
||||
type StateProtocol interface {
|
||||
Payload
|
||||
}
|
||||
|
||||
type Context interface {
|
||||
Packet() *radius.Request
|
||||
RootPayload() Payload
|
||||
@ -24,6 +28,7 @@ type Context interface {
|
||||
SetProtocolState(p Type, s interface{})
|
||||
IsProtocolStart(p Type) bool
|
||||
|
||||
HandleInnerEAP(Payload) Payload
|
||||
EndInnerProtocol(Status, func(p *radius.Packet) *radius.Packet)
|
||||
|
||||
Log() *log.Entry
|
||||
|
@ -24,30 +24,30 @@ type Payload struct {
|
||||
RawPayload []byte
|
||||
}
|
||||
|
||||
func (ip *Payload) Type() protocol.Type {
|
||||
func (p *Payload) Type() protocol.Type {
|
||||
return TypeEAP
|
||||
}
|
||||
|
||||
func (ip *Payload) Offerable() bool {
|
||||
func (p *Payload) Offerable() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (packet *Payload) Decode(raw []byte) error {
|
||||
packet.Code = protocol.Code(raw[0])
|
||||
packet.ID = raw[1]
|
||||
packet.Length = binary.BigEndian.Uint16(raw[2:])
|
||||
if packet.Length != uint16(len(raw)) {
|
||||
return fmt.Errorf("mismatched packet length; got %d, expected %d", packet.Length, uint16(len(raw)))
|
||||
func (p *Payload) Decode(raw []byte) error {
|
||||
p.Code = protocol.Code(raw[0])
|
||||
p.ID = raw[1]
|
||||
p.Length = binary.BigEndian.Uint16(raw[2:])
|
||||
if p.Length != uint16(len(raw)) {
|
||||
return fmt.Errorf("mismatched packet length; got %d, expected %d", p.Length, uint16(len(raw)))
|
||||
}
|
||||
if len(raw) > 4 && (packet.Code == protocol.CodeRequest || packet.Code == protocol.CodeResponse) {
|
||||
packet.MsgType = protocol.Type(raw[4])
|
||||
if len(raw) > 4 && (p.Code == protocol.CodeRequest || p.Code == protocol.CodeResponse) {
|
||||
p.MsgType = protocol.Type(raw[4])
|
||||
}
|
||||
log.WithField("raw", debug.FormatBytes(raw)).WithField("payload", fmt.Sprintf("%T", packet.Payload)).Trace("EAP: decode raw")
|
||||
packet.RawPayload = raw[5:]
|
||||
if packet.Payload == nil {
|
||||
log.WithField("raw", debug.FormatBytes(raw)).WithField("payload", fmt.Sprintf("%T", p.Payload)).Trace("EAP: decode raw")
|
||||
p.RawPayload = raw[5:]
|
||||
if p.Payload == nil {
|
||||
return nil
|
||||
}
|
||||
err := packet.Payload.Decode(raw[5:])
|
||||
err := p.Payload.Decode(raw[5:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -12,26 +12,26 @@ type Payload struct {
|
||||
Identity string
|
||||
}
|
||||
|
||||
func (ip *Payload) Type() protocol.Type {
|
||||
func (p *Payload) Type() protocol.Type {
|
||||
return TypeIdentity
|
||||
}
|
||||
|
||||
func (ip *Payload) Decode(raw []byte) error {
|
||||
ip.Identity = string(raw)
|
||||
func (p *Payload) Decode(raw []byte) error {
|
||||
p.Identity = string(raw)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ip *Payload) Encode() ([]byte, error) {
|
||||
func (p *Payload) Encode() ([]byte, error) {
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
func (ip *Payload) Handle(ctx protocol.Context) protocol.Payload {
|
||||
func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
|
||||
if ctx.IsProtocolStart(TypeIdentity) {
|
||||
ctx.EndInnerProtocol(protocol.StatusNextProtocol, nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ip *Payload) Offerable() bool {
|
||||
func (p *Payload) Offerable() bool {
|
||||
return false
|
||||
}
|
||||
|
@ -12,26 +12,26 @@ type Payload struct {
|
||||
DesiredType protocol.Type
|
||||
}
|
||||
|
||||
func (ln *Payload) Type() protocol.Type {
|
||||
func (p *Payload) Type() protocol.Type {
|
||||
return TypeLegacyNAK
|
||||
}
|
||||
|
||||
func (ln *Payload) Decode(raw []byte) error {
|
||||
ln.DesiredType = protocol.Type(raw[0])
|
||||
func (p *Payload) Decode(raw []byte) error {
|
||||
p.DesiredType = protocol.Type(raw[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ln *Payload) Encode() ([]byte, error) {
|
||||
return []byte{byte(ln.DesiredType)}, nil
|
||||
func (p *Payload) Encode() ([]byte, error) {
|
||||
return []byte{byte(p.DesiredType)}, nil
|
||||
}
|
||||
|
||||
func (ln *Payload) Handle(ctx protocol.Context) protocol.Payload {
|
||||
func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
|
||||
if ctx.IsProtocolStart(TypeLegacyNAK) {
|
||||
ctx.EndInnerProtocol(protocol.StatusError, nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ln *Payload) Offerable() bool {
|
||||
func (p *Payload) Offerable() bool {
|
||||
return false
|
||||
}
|
||||
|
@ -96,6 +96,7 @@ func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
|
||||
ID: rootEap.ID + 1,
|
||||
}
|
||||
}
|
||||
|
||||
return ep
|
||||
}
|
||||
|
||||
|
@ -1,36 +1,34 @@
|
||||
package eap
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"slices"
|
||||
|
||||
"goauthentik.io/internal/outpost/radius/eap/protocol"
|
||||
)
|
||||
|
||||
type ProtocolConstructor func() protocol.Payload
|
||||
|
||||
type Settings struct {
|
||||
Protocols []ProtocolConstructor
|
||||
ProtocolPriority []protocol.Type
|
||||
ProtocolSettings map[protocol.Type]interface{}
|
||||
}
|
||||
|
||||
type StateManager interface {
|
||||
GetEAPSettings() Settings
|
||||
GetEAPState(string) *State
|
||||
SetEAPState(string, *State)
|
||||
}
|
||||
|
||||
type ProtocolConstructor func() Payload
|
||||
|
||||
type Settings struct {
|
||||
Protocols []ProtocolConstructor
|
||||
ProtocolPriority []Type
|
||||
ProtocolSettings map[Type]interface{}
|
||||
}
|
||||
|
||||
type State struct {
|
||||
Protocols []ProtocolConstructor
|
||||
ProtocolIndex int
|
||||
ProtocolPriority []protocol.Type
|
||||
TypeState map[protocol.Type]any
|
||||
ProtocolPriority []Type
|
||||
TypeState map[Type]any
|
||||
}
|
||||
|
||||
func (st *State) GetNextProtocol() (protocol.Type, error) {
|
||||
func (st *State) GetNextProtocol() (Type, error) {
|
||||
if st.ProtocolIndex >= len(st.ProtocolPriority) {
|
||||
return protocol.Type(0), errors.New("no more protocols to offer")
|
||||
return Type(0), errors.New("no more protocols to offer")
|
||||
}
|
||||
return st.ProtocolPriority[st.ProtocolIndex], nil
|
||||
}
|
||||
@ -39,6 +37,6 @@ func BlankState(settings Settings) *State {
|
||||
return &State{
|
||||
Protocols: slices.Clone(settings.Protocols),
|
||||
ProtocolPriority: slices.Clone(settings.ProtocolPriority),
|
||||
TypeState: map[protocol.Type]any{},
|
||||
TypeState: map[Type]any{},
|
||||
}
|
||||
}
|
@ -122,35 +122,35 @@ func (rs *RadiusServer) Handle_AccessRequest_EAP(w radius.ResponseWriter, r *Rad
|
||||
ep.HandleRadiusPacket(w, r.Request)
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) GetEAPState(key string) *eap.State {
|
||||
func (pi *ProviderInstance) GetEAPState(key string) *protocol.State {
|
||||
return pi.eapState[key]
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) SetEAPState(key string, state *eap.State) {
|
||||
func (pi *ProviderInstance) SetEAPState(key string, state *protocol.State) {
|
||||
pi.eapState[key] = state
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) GetEAPSettings() eap.Settings {
|
||||
protocols := []eap.ProtocolConstructor{
|
||||
func (pi *ProviderInstance) GetEAPSettings() protocol.Settings {
|
||||
protocols := []protocol.ProtocolConstructor{
|
||||
identity.Protocol,
|
||||
legacy_nak.Protocol,
|
||||
}
|
||||
|
||||
certId := pi.certId
|
||||
if certId == "" {
|
||||
return eap.Settings{
|
||||
return protocol.Settings{
|
||||
Protocols: protocols,
|
||||
}
|
||||
}
|
||||
|
||||
cert := pi.s.cryptoStore.Get(certId)
|
||||
if cert == nil {
|
||||
return eap.Settings{
|
||||
return protocol.Settings{
|
||||
Protocols: protocols,
|
||||
}
|
||||
}
|
||||
|
||||
return eap.Settings{
|
||||
return protocol.Settings{
|
||||
Protocols: append(protocols, tls.Protocol, peap.Protocol),
|
||||
ProtocolPriority: []protocol.Type{
|
||||
tls.TypeTLS,
|
||||
|
@ -9,7 +9,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/internal/config"
|
||||
"goauthentik.io/internal/outpost/ak"
|
||||
"goauthentik.io/internal/outpost/radius/eap"
|
||||
"goauthentik.io/internal/outpost/radius/eap/protocol"
|
||||
"goauthentik.io/internal/outpost/radius/metrics"
|
||||
|
||||
"layeh.com/radius"
|
||||
@ -26,7 +26,7 @@ type ProviderInstance struct {
|
||||
certId string
|
||||
s *RadiusServer
|
||||
log *log.Entry
|
||||
eapState map[string]*eap.State
|
||||
eapState map[string]*protocol.State
|
||||
}
|
||||
|
||||
type RadiusServer struct {
|
||||
|
Reference in New Issue
Block a user