refactor v1, start support for more protocols and implement nak

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer
2025-05-20 22:39:14 +02:00
parent 8cf8f1e199
commit b6686cff14
12 changed files with 252 additions and 106 deletions

View File

@ -23,10 +23,7 @@ func (ctx context) ProtocolSettings() interface{} {
return ctx.settings
}
func (ctx *context) GetProtocolState(def func(protocol.Context) interface{}) interface{} {
if ctx.state == nil {
ctx.state = def(ctx)
}
func (ctx *context) GetProtocolState() interface{} {
return ctx.state
}
@ -34,11 +31,20 @@ func (ctx *context) SetProtocolState(st interface{}) {
ctx.state = st
}
func (ctx *context) IsProtocolStart() bool {
return ctx.state == nil
}
func (ctx *context) EndInnerProtocol(st protocol.Status, mf func(p *radius.Packet) *radius.Packet) {
if ctx.endStatus != protocol.StatusUnknown {
return
}
ctx.endStatus = st
if mf == nil {
mf = func(p *radius.Packet) *radius.Packet {
return p
}
}
ctx.endModifier = mf
}

View File

@ -4,11 +4,12 @@ import (
"crypto/hmac"
"crypto/md5"
"encoding/base64"
"fmt"
"github.com/gorilla/securecookie"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/radius/eap/legacy_nak"
"goauthentik.io/internal/outpost/radius/eap/protocol"
"goauthentik.io/internal/outpost/radius/eap/tls"
"layeh.com/radius"
"layeh.com/radius/rfc2865"
"layeh.com/radius/rfc2869"
@ -22,66 +23,42 @@ func sendErrorResponse(w radius.ResponseWriter, r *radius.Request) {
}
}
func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Request) {
func (p *Packet) HandleRadiusPacket(w radius.ResponseWriter, r *radius.Request) {
rst := rfc2865.State_GetString(r.Packet)
if rst == "" {
rst = base64.StdEncoding.EncodeToString(securecookie.GenerateRandomKey(12))
}
st := stm.GetEAPState(rst)
if st == nil {
log.Debug("EAP: blank state")
st = BlankState(stm.GetEAPSettings())
}
if len(st.ChallengesToOffer) < 1 {
log.Error("No more challenges to offer")
sendErrorResponse(w, r)
return
}
nextChallengeToOffer := st.ChallengesToOffer[0]
p.state = rst
ctx := &context{
req: r,
state: st.TypeState[nextChallengeToOffer],
log: log.WithField("type", nextChallengeToOffer),
settings: stm.GetEAPSettings().ProtocolSettings[nextChallengeToOffer],
}
res := p.GetChallengeForType(ctx, nextChallengeToOffer)
st.TypeState[nextChallengeToOffer] = ctx.GetProtocolState(nil)
stm.SetEAPState(rst, st)
rres := r.Response(radius.CodeAccessChallenge)
switch ctx.endStatus {
case protocol.StatusSuccess:
res.code = CodeSuccess
res.id -= 1
rres = ctx.endModifier(rres)
st.ChallengesToOffer = st.ChallengesToOffer[1:]
if len(st.ChallengesToOffer) < 1 {
rp, err := p.handleInner(r)
rres := r.Response(radius.CodeAccessReject)
if err == nil {
rres = p.endModifier(rres)
switch rp.code {
case CodeFailure:
rres.Code = radius.CodeAccessReject
case CodeSuccess:
rres.Code = radius.CodeAccessAccept
}
case protocol.StatusError:
res.code = CodeFailure
res.id -= 1
st.ChallengesToOffer = st.ChallengesToOffer[1:]
rres = ctx.endModifier(rres)
if len(st.ChallengesToOffer) < 1 {
rres.Code = radius.CodeAccessReject
}
case protocol.StatusUnknown:
} else {
rres.Code = radius.CodeAccessReject
log.WithError(err).Debug("Rejecting request")
}
rfc2865.State_SetString(rres, rst)
eapEncoded, err := res.Encode()
rfc2865.State_SetString(rres, p.state)
eapEncoded, err := rp.Encode()
if err != nil {
log.WithError(err).Warning("failed to encode response")
sendErrorResponse(w, r)
return
}
log.WithField("length", len(eapEncoded)).Debug("EAP: encapsulated challenge")
log.WithField("length", len(eapEncoded)).WithField("type", fmt.Sprintf("%T", rp.Payload)).Debug("EAP: encapsulated challenge")
rfc2869.EAPMessage_Set(rres, eapEncoded)
err = p.setMessageAuthenticator(rres)
if err != nil {
log.WithError(err).Warning("failed to send message authenticator")
sendErrorResponse(w, r)
return
}
err = w.Write(rres)
if err != nil {
@ -89,21 +66,75 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Req
}
}
func (p *Packet) GetChallengeForType(ctx *context, t protocol.Type) *Packet {
func (p *Packet) handleInner(r *radius.Request) (*Packet, error) {
st := p.stm.GetEAPState(p.state)
if st == nil {
log.Debug("EAP: blank state")
st = BlankState(p.stm.GetEAPSettings())
}
nextChallengeToOffer, err := st.GetNextProtocol()
if err != nil {
return &Packet{
code: CodeFailure,
id: p.id,
}, err
}
if _, ok := p.Payload.(*legacy_nak.Payload); ok {
log.Debug("EAP: received NAK, trying next protocol")
st.ProtocolIndex += 1
p.stm.SetEAPState(p.state, st)
return p.handleInner(r)
}
np, _ := emptyPayload(p.stm, nextChallengeToOffer)
ctx := &context{
req: r,
state: st.TypeState[np.Type()],
log: log.WithField("type", fmt.Sprintf("%T", np)),
settings: p.stm.GetEAPSettings().ProtocolSettings[np.Type()],
}
ctx.log.Debug("EAP: Passing to protocol")
res := p.GetChallengeForType(ctx, np)
st.TypeState[np.Type()] = ctx.GetProtocolState()
p.stm.SetEAPState(p.state, st)
if ctx.endModifier != nil {
p.endModifier = ctx.endModifier
}
switch ctx.endStatus {
case protocol.StatusSuccess:
res.code = CodeSuccess
res.id -= 1
case protocol.StatusError:
res.code = CodeFailure
res.id -= 1
case protocol.StatusNextProtocol:
ctx.log.Debug("EAP: Protocol ended, starting next protocol")
st.ProtocolIndex += 1
p.stm.SetEAPState(p.state, st)
return p.handleInner(r)
case protocol.StatusUnknown:
}
return res, nil
}
func (p *Packet) GetChallengeForType(ctx *context, np protocol.Payload) *Packet {
res := &Packet{
code: CodeRequest,
id: p.id + 1,
msgType: t,
msgType: np.Type(),
}
var payload any
switch t {
case tls.TypeTLS:
if _, ok := p.Payload.(*tls.Payload); !ok {
p.Payload = &tls.Payload{}
p.Payload.Decode(p.rawPayload)
}
payload = p.Payload.(*tls.Payload).Handle(ctx)
if ctx.IsProtocolStart() {
p.Payload = np
p.Payload.Decode(p.rawPayload)
}
payload = p.Payload.Handle(ctx)
if payload != nil {
res.Payload = payload.(protocol.Payload)
}

View File

@ -0,0 +1,37 @@
package identity
import "goauthentik.io/internal/outpost/radius/eap/protocol"
const TypeIdentity protocol.Type = 1
func Protocol() protocol.Payload {
return &Payload{}
}
type Payload struct {
Identity string
}
func (ip *Payload) Type() protocol.Type {
return TypeIdentity
}
func (ip *Payload) Decode(raw []byte) error {
ip.Identity = string(raw)
return nil
}
func (ip *Payload) Encode() ([]byte, error) {
return []byte{}, nil
}
func (ip *Payload) Handle(ctx protocol.Context) protocol.Payload {
if ctx.IsProtocolStart() {
ctx.EndInnerProtocol(protocol.StatusNextProtocol, nil)
}
return nil
}
func (ip *Payload) Offerable() bool {
return false
}

View File

@ -0,0 +1,37 @@
package legacy_nak
import "goauthentik.io/internal/outpost/radius/eap/protocol"
const TypeLegacyNAK protocol.Type = 3
func Protocol() protocol.Payload {
return &Payload{}
}
type Payload struct {
DesiredType protocol.Type
}
func (ln *Payload) Type() protocol.Type {
return TypeLegacyNAK
}
func (ln *Payload) Decode(raw []byte) error {
ln.DesiredType = protocol.Type(raw[0])
return nil
}
func (ln *Payload) Encode() ([]byte, error) {
return []byte{byte(ln.DesiredType)}, nil
}
func (ln *Payload) Handle(ctx protocol.Context) protocol.Payload {
if ctx.IsProtocolStart() {
ctx.EndInnerProtocol(protocol.StatusError, nil)
}
return nil
}
func (ln *Payload) Offerable() bool {
return false
}

View File

@ -3,11 +3,12 @@ package eap
import (
"encoding/binary"
"errors"
"fmt"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/radius/eap/debug"
"goauthentik.io/internal/outpost/radius/eap/protocol"
"goauthentik.io/internal/outpost/radius/eap/tls"
"layeh.com/radius"
)
type Code uint8
@ -26,22 +27,30 @@ type Packet struct {
msgType protocol.Type
rawPayload []byte
Payload protocol.Payload
stm StateManager
state string
endModifier func(p *radius.Packet) *radius.Packet
}
type PayloadWriter struct{}
func emptyPayload(t protocol.Type) protocol.Payload {
switch t {
case protocol.TypeIdentity:
return &IdentityPayload{}
case tls.TypeTLS:
return &tls.Payload{}
func emptyPayload(stm StateManager, t protocol.Type) (protocol.Payload, error) {
for _, cons := range stm.GetEAPSettings().Protocols {
if np := cons(); np.Type() == t {
return np, nil
}
}
return nil
return nil, fmt.Errorf("unsupported EAP type %d", t)
}
func Decode(raw []byte) (*Packet, error) {
packet := &Packet{}
func Decode(stm StateManager, raw []byte) (*Packet, error) {
packet := &Packet{
stm: stm,
endModifier: func(p *radius.Packet) *radius.Packet {
return p
},
}
packet.code = Code(raw[0])
packet.id = raw[1]
packet.length = binary.BigEndian.Uint16(raw[2:])
@ -51,10 +60,14 @@ func Decode(raw []byte) (*Packet, error) {
if len(raw) > 4 && (packet.code == CodeRequest || packet.code == CodeResponse) {
packet.msgType = protocol.Type(raw[4])
}
packet.Payload = emptyPayload(packet.msgType)
p, err := emptyPayload(stm, packet.msgType)
if err != nil {
return nil, err
}
packet.Payload = p
packet.rawPayload = raw[5:]
log.WithField("raw", debug.FormatBytes(raw)).Debug("EAP: decode raw")
err := packet.Payload.Decode(raw[5:])
log.WithField("raw", debug.FormatBytes(raw)).WithField("payload", fmt.Sprintf("%T", packet.Payload)).Debug("EAP: decode raw")
err = packet.Payload.Decode(raw[5:])
if err != nil {
return nil, err
}

View File

@ -1,14 +0,0 @@
package eap
type IdentityPayload struct {
Identity string
}
func (ip *IdentityPayload) Decode(raw []byte) error {
ip.Identity = string(raw)
return nil
}
func (ip *IdentityPayload) Encode() ([]byte, error) {
return []byte{}, nil
}

View File

@ -11,15 +11,17 @@ const (
StatusUnknown Status = iota
StatusSuccess
StatusError
StatusNextProtocol
)
type Context interface {
Packet() *radius.Request
ProtocolSettings() interface{}
GetProtocolState(def func(Context) interface{}) interface{}
GetProtocolState() interface{}
SetProtocolState(interface{})
IsProtocolStart() bool
EndInnerProtocol(Status, func(p *radius.Packet) *radius.Packet)
Log() *log.Entry

View File

@ -3,11 +3,9 @@ package protocol
type Payload interface {
Decode(raw []byte) error
Encode() ([]byte, error)
Handle(ctx Context) Payload
Type() Type
Offerable() bool
}
type Type uint8
const (
TypeIdentity Type = 1
TypeMD5Challenge Type = 4
)

View File

@ -1,13 +1,17 @@
package eap
import (
"errors"
"slices"
"goauthentik.io/internal/outpost/radius/eap/protocol"
)
type ProtocolConstructor func() protocol.Payload
type Settings struct {
ProtocolsToOffer []protocol.Type
Protocols []ProtocolConstructor
ProtocolPriority []protocol.Type
ProtocolSettings map[protocol.Type]interface{}
}
@ -18,13 +22,23 @@ type StateManager interface {
}
type State struct {
ChallengesToOffer []protocol.Type
TypeState map[protocol.Type]any
Protocols []ProtocolConstructor
ProtocolIndex int
ProtocolPriority []protocol.Type
TypeState map[protocol.Type]any
}
func (st *State) GetNextProtocol() (protocol.Type, error) {
if st.ProtocolIndex >= len(st.ProtocolPriority) {
return protocol.Type(0), errors.New("no more protocols to offer")
}
return st.ProtocolPriority[st.ProtocolIndex], nil
}
func BlankState(settings Settings) *State {
return &State{
ChallengesToOffer: slices.Clone(settings.ProtocolsToOffer),
TypeState: map[protocol.Type]any{},
Protocols: slices.Clone(settings.Protocols),
ProtocolPriority: slices.Clone(settings.ProtocolPriority),
TypeState: map[protocol.Type]any{},
}
}

View File

@ -21,6 +21,10 @@ const staleConnectionTimeout = 10
const TypeTLS protocol.Type = 13
func Protocol() protocol.Payload {
return &Payload{}
}
type Payload struct {
Flags Flag
Length uint32
@ -29,6 +33,14 @@ type Payload struct {
st *State
}
func (p *Payload) Type() protocol.Type {
return TypeTLS
}
func (p *Payload) Offerable() bool {
return true
}
func (p *Payload) Decode(raw []byte) error {
p.Flags = Flag(raw[0])
raw = raw[1:]
@ -65,15 +77,16 @@ func (p *Payload) Encode() ([]byte, error) {
}
func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
p.st = ctx.GetProtocolState(NewState).(*State)
defer ctx.SetProtocolState(p.st)
if !p.st.HasStarted {
ctx.Log().Debug("TLS: handshake starting")
p.st.HasStarted = true
defer func() {
ctx.SetProtocolState(p.st)
}()
if ctx.IsProtocolStart() {
p.st = NewState(ctx).(*State)
return &Payload{
Flags: FlagTLSStart,
}
}
p.st = ctx.GetProtocolState().(*State)
if p.st.TLS == nil {
p.tlsInit(ctx)

View File

@ -8,7 +8,6 @@ import (
)
type State struct {
HasStarted bool
RemainingChunks [][]byte
HandshakeDone bool
FinalStatus protocol.Status

View File

@ -12,6 +12,8 @@ import (
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/flow"
"goauthentik.io/internal/outpost/radius/eap"
"goauthentik.io/internal/outpost/radius/eap/identity"
"goauthentik.io/internal/outpost/radius/eap/legacy_nak"
"goauthentik.io/internal/outpost/radius/eap/protocol"
"goauthentik.io/internal/outpost/radius/eap/tls"
"goauthentik.io/internal/outpost/radius/metrics"
@ -111,12 +113,12 @@ func (rs *RadiusServer) Handle_AccessRequest_PAP(w radius.ResponseWriter, r *Rad
func (rs *RadiusServer) Handle_AccessRequest_EAP(w radius.ResponseWriter, r *RadiusRequest) {
er := rfc2869.EAPMessage_Get(r.Packet)
ep, err := eap.Decode(er)
ep, err := eap.Decode(r.pi, er)
if err != nil {
rs.log.WithError(err).Warning("failed to parse EAP packet")
return
}
ep.Handle(r.pi, w, r.Request)
ep.HandleRadiusPacket(w, r.Request)
}
func (pi *ProviderInstance) GetEAPState(key string) *eap.State {
@ -128,22 +130,30 @@ func (pi *ProviderInstance) SetEAPState(key string, state *eap.State) {
}
func (pi *ProviderInstance) GetEAPSettings() eap.Settings {
protocols := []eap.ProtocolConstructor{
identity.Protocol,
legacy_nak.Protocol,
}
certId := pi.certId
if certId == "" {
return eap.Settings{
ProtocolsToOffer: []protocol.Type{},
Protocols: protocols,
}
}
cert := pi.s.cryptoStore.Get(certId)
if cert == nil {
return eap.Settings{
ProtocolsToOffer: []protocol.Type{},
Protocols: protocols,
}
}
return eap.Settings{
ProtocolsToOffer: []protocol.Type{tls.TypeTLS},
Protocols: append(protocols, tls.Protocol),
ProtocolPriority: []protocol.Type{
tls.TypeTLS,
},
ProtocolSettings: map[protocol.Type]interface{}{
tls.TypeTLS: tls.Settings{
Config: &ttls.Config{