refactor v1, start support for more protocols and implement nak

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer
2025-05-20 22:39:14 +02:00
parent 8cf8f1e199
commit b6686cff14
12 changed files with 252 additions and 106 deletions

View File

@ -23,10 +23,7 @@ func (ctx context) ProtocolSettings() interface{} {
return ctx.settings return ctx.settings
} }
func (ctx *context) GetProtocolState(def func(protocol.Context) interface{}) interface{} { func (ctx *context) GetProtocolState() interface{} {
if ctx.state == nil {
ctx.state = def(ctx)
}
return ctx.state return ctx.state
} }
@ -34,11 +31,20 @@ func (ctx *context) SetProtocolState(st interface{}) {
ctx.state = st ctx.state = st
} }
func (ctx *context) IsProtocolStart() bool {
return ctx.state == nil
}
func (ctx *context) EndInnerProtocol(st protocol.Status, mf func(p *radius.Packet) *radius.Packet) { func (ctx *context) EndInnerProtocol(st protocol.Status, mf func(p *radius.Packet) *radius.Packet) {
if ctx.endStatus != protocol.StatusUnknown { if ctx.endStatus != protocol.StatusUnknown {
return return
} }
ctx.endStatus = st ctx.endStatus = st
if mf == nil {
mf = func(p *radius.Packet) *radius.Packet {
return p
}
}
ctx.endModifier = mf ctx.endModifier = mf
} }

View File

@ -4,11 +4,12 @@ import (
"crypto/hmac" "crypto/hmac"
"crypto/md5" "crypto/md5"
"encoding/base64" "encoding/base64"
"fmt"
"github.com/gorilla/securecookie" "github.com/gorilla/securecookie"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/radius/eap/legacy_nak"
"goauthentik.io/internal/outpost/radius/eap/protocol" "goauthentik.io/internal/outpost/radius/eap/protocol"
"goauthentik.io/internal/outpost/radius/eap/tls"
"layeh.com/radius" "layeh.com/radius"
"layeh.com/radius/rfc2865" "layeh.com/radius/rfc2865"
"layeh.com/radius/rfc2869" "layeh.com/radius/rfc2869"
@ -22,66 +23,42 @@ func sendErrorResponse(w radius.ResponseWriter, r *radius.Request) {
} }
} }
func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Request) { func (p *Packet) HandleRadiusPacket(w radius.ResponseWriter, r *radius.Request) {
rst := rfc2865.State_GetString(r.Packet) rst := rfc2865.State_GetString(r.Packet)
if rst == "" { if rst == "" {
rst = base64.StdEncoding.EncodeToString(securecookie.GenerateRandomKey(12)) rst = base64.StdEncoding.EncodeToString(securecookie.GenerateRandomKey(12))
} }
st := stm.GetEAPState(rst) p.state = rst
if st == nil {
log.Debug("EAP: blank state")
st = BlankState(stm.GetEAPSettings())
}
if len(st.ChallengesToOffer) < 1 {
log.Error("No more challenges to offer")
sendErrorResponse(w, r)
return
}
nextChallengeToOffer := st.ChallengesToOffer[0]
ctx := &context{ rp, err := p.handleInner(r)
req: r, rres := r.Response(radius.CodeAccessReject)
state: st.TypeState[nextChallengeToOffer], if err == nil {
log: log.WithField("type", nextChallengeToOffer), rres = p.endModifier(rres)
settings: stm.GetEAPSettings().ProtocolSettings[nextChallengeToOffer], switch rp.code {
} case CodeFailure:
rres.Code = radius.CodeAccessReject
res := p.GetChallengeForType(ctx, nextChallengeToOffer) case CodeSuccess:
st.TypeState[nextChallengeToOffer] = ctx.GetProtocolState(nil)
stm.SetEAPState(rst, st)
rres := r.Response(radius.CodeAccessChallenge)
switch ctx.endStatus {
case protocol.StatusSuccess:
res.code = CodeSuccess
res.id -= 1
rres = ctx.endModifier(rres)
st.ChallengesToOffer = st.ChallengesToOffer[1:]
if len(st.ChallengesToOffer) < 1 {
rres.Code = radius.CodeAccessAccept rres.Code = radius.CodeAccessAccept
} }
case protocol.StatusError: } else {
res.code = CodeFailure rres.Code = radius.CodeAccessReject
res.id -= 1 log.WithError(err).Debug("Rejecting request")
st.ChallengesToOffer = st.ChallengesToOffer[1:]
rres = ctx.endModifier(rres)
if len(st.ChallengesToOffer) < 1 {
rres.Code = radius.CodeAccessReject
}
case protocol.StatusUnknown:
} }
rfc2865.State_SetString(rres, rst)
eapEncoded, err := res.Encode() rfc2865.State_SetString(rres, p.state)
eapEncoded, err := rp.Encode()
if err != nil { if err != nil {
log.WithError(err).Warning("failed to encode response") log.WithError(err).Warning("failed to encode response")
sendErrorResponse(w, r) sendErrorResponse(w, r)
return
} }
log.WithField("length", len(eapEncoded)).Debug("EAP: encapsulated challenge") log.WithField("length", len(eapEncoded)).WithField("type", fmt.Sprintf("%T", rp.Payload)).Debug("EAP: encapsulated challenge")
rfc2869.EAPMessage_Set(rres, eapEncoded) rfc2869.EAPMessage_Set(rres, eapEncoded)
err = p.setMessageAuthenticator(rres) err = p.setMessageAuthenticator(rres)
if err != nil { if err != nil {
log.WithError(err).Warning("failed to send message authenticator") log.WithError(err).Warning("failed to send message authenticator")
sendErrorResponse(w, r) sendErrorResponse(w, r)
return
} }
err = w.Write(rres) err = w.Write(rres)
if err != nil { if err != nil {
@ -89,21 +66,75 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Req
} }
} }
func (p *Packet) GetChallengeForType(ctx *context, t protocol.Type) *Packet { func (p *Packet) handleInner(r *radius.Request) (*Packet, error) {
st := p.stm.GetEAPState(p.state)
if st == nil {
log.Debug("EAP: blank state")
st = BlankState(p.stm.GetEAPSettings())
}
nextChallengeToOffer, err := st.GetNextProtocol()
if err != nil {
return &Packet{
code: CodeFailure,
id: p.id,
}, err
}
if _, ok := p.Payload.(*legacy_nak.Payload); ok {
log.Debug("EAP: received NAK, trying next protocol")
st.ProtocolIndex += 1
p.stm.SetEAPState(p.state, st)
return p.handleInner(r)
}
np, _ := emptyPayload(p.stm, nextChallengeToOffer)
ctx := &context{
req: r,
state: st.TypeState[np.Type()],
log: log.WithField("type", fmt.Sprintf("%T", np)),
settings: p.stm.GetEAPSettings().ProtocolSettings[np.Type()],
}
ctx.log.Debug("EAP: Passing to protocol")
res := p.GetChallengeForType(ctx, np)
st.TypeState[np.Type()] = ctx.GetProtocolState()
p.stm.SetEAPState(p.state, st)
if ctx.endModifier != nil {
p.endModifier = ctx.endModifier
}
switch ctx.endStatus {
case protocol.StatusSuccess:
res.code = CodeSuccess
res.id -= 1
case protocol.StatusError:
res.code = CodeFailure
res.id -= 1
case protocol.StatusNextProtocol:
ctx.log.Debug("EAP: Protocol ended, starting next protocol")
st.ProtocolIndex += 1
p.stm.SetEAPState(p.state, st)
return p.handleInner(r)
case protocol.StatusUnknown:
}
return res, nil
}
func (p *Packet) GetChallengeForType(ctx *context, np protocol.Payload) *Packet {
res := &Packet{ res := &Packet{
code: CodeRequest, code: CodeRequest,
id: p.id + 1, id: p.id + 1,
msgType: t, msgType: np.Type(),
} }
var payload any var payload any
switch t { if ctx.IsProtocolStart() {
case tls.TypeTLS: p.Payload = np
if _, ok := p.Payload.(*tls.Payload); !ok { p.Payload.Decode(p.rawPayload)
p.Payload = &tls.Payload{}
p.Payload.Decode(p.rawPayload)
}
payload = p.Payload.(*tls.Payload).Handle(ctx)
} }
payload = p.Payload.Handle(ctx)
if payload != nil { if payload != nil {
res.Payload = payload.(protocol.Payload) res.Payload = payload.(protocol.Payload)
} }

View File

@ -0,0 +1,37 @@
package identity
import "goauthentik.io/internal/outpost/radius/eap/protocol"
const TypeIdentity protocol.Type = 1
func Protocol() protocol.Payload {
return &Payload{}
}
type Payload struct {
Identity string
}
func (ip *Payload) Type() protocol.Type {
return TypeIdentity
}
func (ip *Payload) Decode(raw []byte) error {
ip.Identity = string(raw)
return nil
}
func (ip *Payload) Encode() ([]byte, error) {
return []byte{}, nil
}
func (ip *Payload) Handle(ctx protocol.Context) protocol.Payload {
if ctx.IsProtocolStart() {
ctx.EndInnerProtocol(protocol.StatusNextProtocol, nil)
}
return nil
}
func (ip *Payload) Offerable() bool {
return false
}

View File

@ -0,0 +1,37 @@
package legacy_nak
import "goauthentik.io/internal/outpost/radius/eap/protocol"
const TypeLegacyNAK protocol.Type = 3
func Protocol() protocol.Payload {
return &Payload{}
}
type Payload struct {
DesiredType protocol.Type
}
func (ln *Payload) Type() protocol.Type {
return TypeLegacyNAK
}
func (ln *Payload) Decode(raw []byte) error {
ln.DesiredType = protocol.Type(raw[0])
return nil
}
func (ln *Payload) Encode() ([]byte, error) {
return []byte{byte(ln.DesiredType)}, nil
}
func (ln *Payload) Handle(ctx protocol.Context) protocol.Payload {
if ctx.IsProtocolStart() {
ctx.EndInnerProtocol(protocol.StatusError, nil)
}
return nil
}
func (ln *Payload) Offerable() bool {
return false
}

View File

@ -3,11 +3,12 @@ package eap
import ( import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/radius/eap/debug" "goauthentik.io/internal/outpost/radius/eap/debug"
"goauthentik.io/internal/outpost/radius/eap/protocol" "goauthentik.io/internal/outpost/radius/eap/protocol"
"goauthentik.io/internal/outpost/radius/eap/tls" "layeh.com/radius"
) )
type Code uint8 type Code uint8
@ -26,22 +27,30 @@ type Packet struct {
msgType protocol.Type msgType protocol.Type
rawPayload []byte rawPayload []byte
Payload protocol.Payload Payload protocol.Payload
stm StateManager
state string
endModifier func(p *radius.Packet) *radius.Packet
} }
type PayloadWriter struct{} type PayloadWriter struct{}
func emptyPayload(t protocol.Type) protocol.Payload { func emptyPayload(stm StateManager, t protocol.Type) (protocol.Payload, error) {
switch t { for _, cons := range stm.GetEAPSettings().Protocols {
case protocol.TypeIdentity: if np := cons(); np.Type() == t {
return &IdentityPayload{} return np, nil
case tls.TypeTLS: }
return &tls.Payload{}
} }
return nil return nil, fmt.Errorf("unsupported EAP type %d", t)
} }
func Decode(raw []byte) (*Packet, error) { func Decode(stm StateManager, raw []byte) (*Packet, error) {
packet := &Packet{} packet := &Packet{
stm: stm,
endModifier: func(p *radius.Packet) *radius.Packet {
return p
},
}
packet.code = Code(raw[0]) packet.code = Code(raw[0])
packet.id = raw[1] packet.id = raw[1]
packet.length = binary.BigEndian.Uint16(raw[2:]) packet.length = binary.BigEndian.Uint16(raw[2:])
@ -51,10 +60,14 @@ func Decode(raw []byte) (*Packet, error) {
if len(raw) > 4 && (packet.code == CodeRequest || packet.code == CodeResponse) { if len(raw) > 4 && (packet.code == CodeRequest || packet.code == CodeResponse) {
packet.msgType = protocol.Type(raw[4]) packet.msgType = protocol.Type(raw[4])
} }
packet.Payload = emptyPayload(packet.msgType) p, err := emptyPayload(stm, packet.msgType)
if err != nil {
return nil, err
}
packet.Payload = p
packet.rawPayload = raw[5:] packet.rawPayload = raw[5:]
log.WithField("raw", debug.FormatBytes(raw)).Debug("EAP: decode raw") log.WithField("raw", debug.FormatBytes(raw)).WithField("payload", fmt.Sprintf("%T", packet.Payload)).Debug("EAP: decode raw")
err := packet.Payload.Decode(raw[5:]) err = packet.Payload.Decode(raw[5:])
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,14 +0,0 @@
package eap
type IdentityPayload struct {
Identity string
}
func (ip *IdentityPayload) Decode(raw []byte) error {
ip.Identity = string(raw)
return nil
}
func (ip *IdentityPayload) Encode() ([]byte, error) {
return []byte{}, nil
}

View File

@ -11,15 +11,17 @@ const (
StatusUnknown Status = iota StatusUnknown Status = iota
StatusSuccess StatusSuccess
StatusError StatusError
StatusNextProtocol
) )
type Context interface { type Context interface {
Packet() *radius.Request Packet() *radius.Request
ProtocolSettings() interface{} ProtocolSettings() interface{}
GetProtocolState(def func(Context) interface{}) interface{} GetProtocolState() interface{}
SetProtocolState(interface{}) SetProtocolState(interface{})
IsProtocolStart() bool
EndInnerProtocol(Status, func(p *radius.Packet) *radius.Packet) EndInnerProtocol(Status, func(p *radius.Packet) *radius.Packet)
Log() *log.Entry Log() *log.Entry

View File

@ -3,11 +3,9 @@ package protocol
type Payload interface { type Payload interface {
Decode(raw []byte) error Decode(raw []byte) error
Encode() ([]byte, error) Encode() ([]byte, error)
Handle(ctx Context) Payload
Type() Type
Offerable() bool
} }
type Type uint8 type Type uint8
const (
TypeIdentity Type = 1
TypeMD5Challenge Type = 4
)

View File

@ -1,13 +1,17 @@
package eap package eap
import ( import (
"errors"
"slices" "slices"
"goauthentik.io/internal/outpost/radius/eap/protocol" "goauthentik.io/internal/outpost/radius/eap/protocol"
) )
type ProtocolConstructor func() protocol.Payload
type Settings struct { type Settings struct {
ProtocolsToOffer []protocol.Type Protocols []ProtocolConstructor
ProtocolPriority []protocol.Type
ProtocolSettings map[protocol.Type]interface{} ProtocolSettings map[protocol.Type]interface{}
} }
@ -18,13 +22,23 @@ type StateManager interface {
} }
type State struct { type State struct {
ChallengesToOffer []protocol.Type Protocols []ProtocolConstructor
TypeState map[protocol.Type]any ProtocolIndex int
ProtocolPriority []protocol.Type
TypeState map[protocol.Type]any
}
func (st *State) GetNextProtocol() (protocol.Type, error) {
if st.ProtocolIndex >= len(st.ProtocolPriority) {
return protocol.Type(0), errors.New("no more protocols to offer")
}
return st.ProtocolPriority[st.ProtocolIndex], nil
} }
func BlankState(settings Settings) *State { func BlankState(settings Settings) *State {
return &State{ return &State{
ChallengesToOffer: slices.Clone(settings.ProtocolsToOffer), Protocols: slices.Clone(settings.Protocols),
TypeState: map[protocol.Type]any{}, ProtocolPriority: slices.Clone(settings.ProtocolPriority),
TypeState: map[protocol.Type]any{},
} }
} }

View File

@ -21,6 +21,10 @@ const staleConnectionTimeout = 10
const TypeTLS protocol.Type = 13 const TypeTLS protocol.Type = 13
func Protocol() protocol.Payload {
return &Payload{}
}
type Payload struct { type Payload struct {
Flags Flag Flags Flag
Length uint32 Length uint32
@ -29,6 +33,14 @@ type Payload struct {
st *State st *State
} }
func (p *Payload) Type() protocol.Type {
return TypeTLS
}
func (p *Payload) Offerable() bool {
return true
}
func (p *Payload) Decode(raw []byte) error { func (p *Payload) Decode(raw []byte) error {
p.Flags = Flag(raw[0]) p.Flags = Flag(raw[0])
raw = raw[1:] raw = raw[1:]
@ -65,15 +77,16 @@ func (p *Payload) Encode() ([]byte, error) {
} }
func (p *Payload) Handle(ctx protocol.Context) protocol.Payload { func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
p.st = ctx.GetProtocolState(NewState).(*State) defer func() {
defer ctx.SetProtocolState(p.st) ctx.SetProtocolState(p.st)
if !p.st.HasStarted { }()
ctx.Log().Debug("TLS: handshake starting") if ctx.IsProtocolStart() {
p.st.HasStarted = true p.st = NewState(ctx).(*State)
return &Payload{ return &Payload{
Flags: FlagTLSStart, Flags: FlagTLSStart,
} }
} }
p.st = ctx.GetProtocolState().(*State)
if p.st.TLS == nil { if p.st.TLS == nil {
p.tlsInit(ctx) p.tlsInit(ctx)

View File

@ -8,7 +8,6 @@ import (
) )
type State struct { type State struct {
HasStarted bool
RemainingChunks [][]byte RemainingChunks [][]byte
HandshakeDone bool HandshakeDone bool
FinalStatus protocol.Status FinalStatus protocol.Status

View File

@ -12,6 +12,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/flow" "goauthentik.io/internal/outpost/flow"
"goauthentik.io/internal/outpost/radius/eap" "goauthentik.io/internal/outpost/radius/eap"
"goauthentik.io/internal/outpost/radius/eap/identity"
"goauthentik.io/internal/outpost/radius/eap/legacy_nak"
"goauthentik.io/internal/outpost/radius/eap/protocol" "goauthentik.io/internal/outpost/radius/eap/protocol"
"goauthentik.io/internal/outpost/radius/eap/tls" "goauthentik.io/internal/outpost/radius/eap/tls"
"goauthentik.io/internal/outpost/radius/metrics" "goauthentik.io/internal/outpost/radius/metrics"
@ -111,12 +113,12 @@ func (rs *RadiusServer) Handle_AccessRequest_PAP(w radius.ResponseWriter, r *Rad
func (rs *RadiusServer) Handle_AccessRequest_EAP(w radius.ResponseWriter, r *RadiusRequest) { func (rs *RadiusServer) Handle_AccessRequest_EAP(w radius.ResponseWriter, r *RadiusRequest) {
er := rfc2869.EAPMessage_Get(r.Packet) er := rfc2869.EAPMessage_Get(r.Packet)
ep, err := eap.Decode(er) ep, err := eap.Decode(r.pi, er)
if err != nil { if err != nil {
rs.log.WithError(err).Warning("failed to parse EAP packet") rs.log.WithError(err).Warning("failed to parse EAP packet")
return return
} }
ep.Handle(r.pi, w, r.Request) ep.HandleRadiusPacket(w, r.Request)
} }
func (pi *ProviderInstance) GetEAPState(key string) *eap.State { func (pi *ProviderInstance) GetEAPState(key string) *eap.State {
@ -128,22 +130,30 @@ func (pi *ProviderInstance) SetEAPState(key string, state *eap.State) {
} }
func (pi *ProviderInstance) GetEAPSettings() eap.Settings { func (pi *ProviderInstance) GetEAPSettings() eap.Settings {
protocols := []eap.ProtocolConstructor{
identity.Protocol,
legacy_nak.Protocol,
}
certId := pi.certId certId := pi.certId
if certId == "" { if certId == "" {
return eap.Settings{ return eap.Settings{
ProtocolsToOffer: []protocol.Type{}, Protocols: protocols,
} }
} }
cert := pi.s.cryptoStore.Get(certId) cert := pi.s.cryptoStore.Get(certId)
if cert == nil { if cert == nil {
return eap.Settings{ return eap.Settings{
ProtocolsToOffer: []protocol.Type{}, Protocols: protocols,
} }
} }
return eap.Settings{ return eap.Settings{
ProtocolsToOffer: []protocol.Type{tls.TypeTLS}, Protocols: append(protocols, tls.Protocol),
ProtocolPriority: []protocol.Type{
tls.TypeTLS,
},
ProtocolSettings: map[protocol.Type]interface{}{ ProtocolSettings: map[protocol.Type]interface{}{
tls.TypeTLS: tls.Settings{ tls.TypeTLS: tls.Settings{
Config: &ttls.Config{ Config: &ttls.Config{