refactor v1, start support for more protocols and implement nak
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
@ -23,10 +23,7 @@ func (ctx context) ProtocolSettings() interface{} {
|
||||
return ctx.settings
|
||||
}
|
||||
|
||||
func (ctx *context) GetProtocolState(def func(protocol.Context) interface{}) interface{} {
|
||||
if ctx.state == nil {
|
||||
ctx.state = def(ctx)
|
||||
}
|
||||
func (ctx *context) GetProtocolState() interface{} {
|
||||
return ctx.state
|
||||
}
|
||||
|
||||
@ -34,11 +31,20 @@ func (ctx *context) SetProtocolState(st interface{}) {
|
||||
ctx.state = st
|
||||
}
|
||||
|
||||
func (ctx *context) IsProtocolStart() bool {
|
||||
return ctx.state == nil
|
||||
}
|
||||
|
||||
func (ctx *context) EndInnerProtocol(st protocol.Status, mf func(p *radius.Packet) *radius.Packet) {
|
||||
if ctx.endStatus != protocol.StatusUnknown {
|
||||
return
|
||||
}
|
||||
ctx.endStatus = st
|
||||
if mf == nil {
|
||||
mf = func(p *radius.Packet) *radius.Packet {
|
||||
return p
|
||||
}
|
||||
}
|
||||
ctx.endModifier = mf
|
||||
}
|
||||
|
||||
|
||||
@ -4,11 +4,12 @@ import (
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
||||
"github.com/gorilla/securecookie"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/internal/outpost/radius/eap/legacy_nak"
|
||||
"goauthentik.io/internal/outpost/radius/eap/protocol"
|
||||
"goauthentik.io/internal/outpost/radius/eap/tls"
|
||||
"layeh.com/radius"
|
||||
"layeh.com/radius/rfc2865"
|
||||
"layeh.com/radius/rfc2869"
|
||||
@ -22,66 +23,42 @@ func sendErrorResponse(w radius.ResponseWriter, r *radius.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Request) {
|
||||
func (p *Packet) HandleRadiusPacket(w radius.ResponseWriter, r *radius.Request) {
|
||||
rst := rfc2865.State_GetString(r.Packet)
|
||||
if rst == "" {
|
||||
rst = base64.StdEncoding.EncodeToString(securecookie.GenerateRandomKey(12))
|
||||
}
|
||||
st := stm.GetEAPState(rst)
|
||||
if st == nil {
|
||||
log.Debug("EAP: blank state")
|
||||
st = BlankState(stm.GetEAPSettings())
|
||||
}
|
||||
if len(st.ChallengesToOffer) < 1 {
|
||||
log.Error("No more challenges to offer")
|
||||
sendErrorResponse(w, r)
|
||||
return
|
||||
}
|
||||
nextChallengeToOffer := st.ChallengesToOffer[0]
|
||||
p.state = rst
|
||||
|
||||
ctx := &context{
|
||||
req: r,
|
||||
state: st.TypeState[nextChallengeToOffer],
|
||||
log: log.WithField("type", nextChallengeToOffer),
|
||||
settings: stm.GetEAPSettings().ProtocolSettings[nextChallengeToOffer],
|
||||
}
|
||||
|
||||
res := p.GetChallengeForType(ctx, nextChallengeToOffer)
|
||||
st.TypeState[nextChallengeToOffer] = ctx.GetProtocolState(nil)
|
||||
stm.SetEAPState(rst, st)
|
||||
|
||||
rres := r.Response(radius.CodeAccessChallenge)
|
||||
switch ctx.endStatus {
|
||||
case protocol.StatusSuccess:
|
||||
res.code = CodeSuccess
|
||||
res.id -= 1
|
||||
rres = ctx.endModifier(rres)
|
||||
st.ChallengesToOffer = st.ChallengesToOffer[1:]
|
||||
if len(st.ChallengesToOffer) < 1 {
|
||||
rp, err := p.handleInner(r)
|
||||
rres := r.Response(radius.CodeAccessReject)
|
||||
if err == nil {
|
||||
rres = p.endModifier(rres)
|
||||
switch rp.code {
|
||||
case CodeFailure:
|
||||
rres.Code = radius.CodeAccessReject
|
||||
case CodeSuccess:
|
||||
rres.Code = radius.CodeAccessAccept
|
||||
}
|
||||
case protocol.StatusError:
|
||||
res.code = CodeFailure
|
||||
res.id -= 1
|
||||
st.ChallengesToOffer = st.ChallengesToOffer[1:]
|
||||
rres = ctx.endModifier(rres)
|
||||
if len(st.ChallengesToOffer) < 1 {
|
||||
} else {
|
||||
rres.Code = radius.CodeAccessReject
|
||||
log.WithError(err).Debug("Rejecting request")
|
||||
}
|
||||
case protocol.StatusUnknown:
|
||||
}
|
||||
rfc2865.State_SetString(rres, rst)
|
||||
eapEncoded, err := res.Encode()
|
||||
|
||||
rfc2865.State_SetString(rres, p.state)
|
||||
eapEncoded, err := rp.Encode()
|
||||
if err != nil {
|
||||
log.WithError(err).Warning("failed to encode response")
|
||||
sendErrorResponse(w, r)
|
||||
return
|
||||
}
|
||||
log.WithField("length", len(eapEncoded)).Debug("EAP: encapsulated challenge")
|
||||
log.WithField("length", len(eapEncoded)).WithField("type", fmt.Sprintf("%T", rp.Payload)).Debug("EAP: encapsulated challenge")
|
||||
rfc2869.EAPMessage_Set(rres, eapEncoded)
|
||||
err = p.setMessageAuthenticator(rres)
|
||||
if err != nil {
|
||||
log.WithError(err).Warning("failed to send message authenticator")
|
||||
sendErrorResponse(w, r)
|
||||
return
|
||||
}
|
||||
err = w.Write(rres)
|
||||
if err != nil {
|
||||
@ -89,21 +66,75 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Req
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Packet) GetChallengeForType(ctx *context, t protocol.Type) *Packet {
|
||||
func (p *Packet) handleInner(r *radius.Request) (*Packet, error) {
|
||||
st := p.stm.GetEAPState(p.state)
|
||||
if st == nil {
|
||||
log.Debug("EAP: blank state")
|
||||
st = BlankState(p.stm.GetEAPSettings())
|
||||
}
|
||||
|
||||
nextChallengeToOffer, err := st.GetNextProtocol()
|
||||
if err != nil {
|
||||
return &Packet{
|
||||
code: CodeFailure,
|
||||
id: p.id,
|
||||
}, err
|
||||
}
|
||||
|
||||
if _, ok := p.Payload.(*legacy_nak.Payload); ok {
|
||||
log.Debug("EAP: received NAK, trying next protocol")
|
||||
st.ProtocolIndex += 1
|
||||
p.stm.SetEAPState(p.state, st)
|
||||
return p.handleInner(r)
|
||||
}
|
||||
|
||||
np, _ := emptyPayload(p.stm, nextChallengeToOffer)
|
||||
|
||||
ctx := &context{
|
||||
req: r,
|
||||
state: st.TypeState[np.Type()],
|
||||
log: log.WithField("type", fmt.Sprintf("%T", np)),
|
||||
settings: p.stm.GetEAPSettings().ProtocolSettings[np.Type()],
|
||||
}
|
||||
ctx.log.Debug("EAP: Passing to protocol")
|
||||
|
||||
res := p.GetChallengeForType(ctx, np)
|
||||
st.TypeState[np.Type()] = ctx.GetProtocolState()
|
||||
p.stm.SetEAPState(p.state, st)
|
||||
|
||||
if ctx.endModifier != nil {
|
||||
p.endModifier = ctx.endModifier
|
||||
}
|
||||
|
||||
switch ctx.endStatus {
|
||||
case protocol.StatusSuccess:
|
||||
res.code = CodeSuccess
|
||||
res.id -= 1
|
||||
case protocol.StatusError:
|
||||
res.code = CodeFailure
|
||||
res.id -= 1
|
||||
case protocol.StatusNextProtocol:
|
||||
ctx.log.Debug("EAP: Protocol ended, starting next protocol")
|
||||
st.ProtocolIndex += 1
|
||||
p.stm.SetEAPState(p.state, st)
|
||||
return p.handleInner(r)
|
||||
case protocol.StatusUnknown:
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (p *Packet) GetChallengeForType(ctx *context, np protocol.Payload) *Packet {
|
||||
res := &Packet{
|
||||
code: CodeRequest,
|
||||
id: p.id + 1,
|
||||
msgType: t,
|
||||
msgType: np.Type(),
|
||||
}
|
||||
var payload any
|
||||
switch t {
|
||||
case tls.TypeTLS:
|
||||
if _, ok := p.Payload.(*tls.Payload); !ok {
|
||||
p.Payload = &tls.Payload{}
|
||||
if ctx.IsProtocolStart() {
|
||||
p.Payload = np
|
||||
p.Payload.Decode(p.rawPayload)
|
||||
}
|
||||
payload = p.Payload.(*tls.Payload).Handle(ctx)
|
||||
}
|
||||
payload = p.Payload.Handle(ctx)
|
||||
if payload != nil {
|
||||
res.Payload = payload.(protocol.Payload)
|
||||
}
|
||||
|
||||
37
internal/outpost/radius/eap/identity/payload.go
Normal file
37
internal/outpost/radius/eap/identity/payload.go
Normal file
@ -0,0 +1,37 @@
|
||||
package identity
|
||||
|
||||
import "goauthentik.io/internal/outpost/radius/eap/protocol"
|
||||
|
||||
const TypeIdentity protocol.Type = 1
|
||||
|
||||
func Protocol() protocol.Payload {
|
||||
return &Payload{}
|
||||
}
|
||||
|
||||
type Payload struct {
|
||||
Identity string
|
||||
}
|
||||
|
||||
func (ip *Payload) Type() protocol.Type {
|
||||
return TypeIdentity
|
||||
}
|
||||
|
||||
func (ip *Payload) Decode(raw []byte) error {
|
||||
ip.Identity = string(raw)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ip *Payload) Encode() ([]byte, error) {
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
func (ip *Payload) Handle(ctx protocol.Context) protocol.Payload {
|
||||
if ctx.IsProtocolStart() {
|
||||
ctx.EndInnerProtocol(protocol.StatusNextProtocol, nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ip *Payload) Offerable() bool {
|
||||
return false
|
||||
}
|
||||
37
internal/outpost/radius/eap/legacy_nak/payload.go
Normal file
37
internal/outpost/radius/eap/legacy_nak/payload.go
Normal file
@ -0,0 +1,37 @@
|
||||
package legacy_nak
|
||||
|
||||
import "goauthentik.io/internal/outpost/radius/eap/protocol"
|
||||
|
||||
const TypeLegacyNAK protocol.Type = 3
|
||||
|
||||
func Protocol() protocol.Payload {
|
||||
return &Payload{}
|
||||
}
|
||||
|
||||
type Payload struct {
|
||||
DesiredType protocol.Type
|
||||
}
|
||||
|
||||
func (ln *Payload) Type() protocol.Type {
|
||||
return TypeLegacyNAK
|
||||
}
|
||||
|
||||
func (ln *Payload) Decode(raw []byte) error {
|
||||
ln.DesiredType = protocol.Type(raw[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ln *Payload) Encode() ([]byte, error) {
|
||||
return []byte{byte(ln.DesiredType)}, nil
|
||||
}
|
||||
|
||||
func (ln *Payload) Handle(ctx protocol.Context) protocol.Payload {
|
||||
if ctx.IsProtocolStart() {
|
||||
ctx.EndInnerProtocol(protocol.StatusError, nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ln *Payload) Offerable() bool {
|
||||
return false
|
||||
}
|
||||
@ -3,11 +3,12 @@ package eap
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/internal/outpost/radius/eap/debug"
|
||||
"goauthentik.io/internal/outpost/radius/eap/protocol"
|
||||
"goauthentik.io/internal/outpost/radius/eap/tls"
|
||||
"layeh.com/radius"
|
||||
)
|
||||
|
||||
type Code uint8
|
||||
@ -26,22 +27,30 @@ type Packet struct {
|
||||
msgType protocol.Type
|
||||
rawPayload []byte
|
||||
Payload protocol.Payload
|
||||
|
||||
stm StateManager
|
||||
state string
|
||||
endModifier func(p *radius.Packet) *radius.Packet
|
||||
}
|
||||
|
||||
type PayloadWriter struct{}
|
||||
|
||||
func emptyPayload(t protocol.Type) protocol.Payload {
|
||||
switch t {
|
||||
case protocol.TypeIdentity:
|
||||
return &IdentityPayload{}
|
||||
case tls.TypeTLS:
|
||||
return &tls.Payload{}
|
||||
func emptyPayload(stm StateManager, t protocol.Type) (protocol.Payload, error) {
|
||||
for _, cons := range stm.GetEAPSettings().Protocols {
|
||||
if np := cons(); np.Type() == t {
|
||||
return np, nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported EAP type %d", t)
|
||||
}
|
||||
|
||||
func Decode(raw []byte) (*Packet, error) {
|
||||
packet := &Packet{}
|
||||
func Decode(stm StateManager, raw []byte) (*Packet, error) {
|
||||
packet := &Packet{
|
||||
stm: stm,
|
||||
endModifier: func(p *radius.Packet) *radius.Packet {
|
||||
return p
|
||||
},
|
||||
}
|
||||
packet.code = Code(raw[0])
|
||||
packet.id = raw[1]
|
||||
packet.length = binary.BigEndian.Uint16(raw[2:])
|
||||
@ -51,10 +60,14 @@ func Decode(raw []byte) (*Packet, error) {
|
||||
if len(raw) > 4 && (packet.code == CodeRequest || packet.code == CodeResponse) {
|
||||
packet.msgType = protocol.Type(raw[4])
|
||||
}
|
||||
packet.Payload = emptyPayload(packet.msgType)
|
||||
p, err := emptyPayload(stm, packet.msgType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
packet.Payload = p
|
||||
packet.rawPayload = raw[5:]
|
||||
log.WithField("raw", debug.FormatBytes(raw)).Debug("EAP: decode raw")
|
||||
err := packet.Payload.Decode(raw[5:])
|
||||
log.WithField("raw", debug.FormatBytes(raw)).WithField("payload", fmt.Sprintf("%T", packet.Payload)).Debug("EAP: decode raw")
|
||||
err = packet.Payload.Decode(raw[5:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -1,14 +0,0 @@
|
||||
package eap
|
||||
|
||||
type IdentityPayload struct {
|
||||
Identity string
|
||||
}
|
||||
|
||||
func (ip *IdentityPayload) Decode(raw []byte) error {
|
||||
ip.Identity = string(raw)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ip *IdentityPayload) Encode() ([]byte, error) {
|
||||
return []byte{}, nil
|
||||
}
|
||||
@ -11,15 +11,17 @@ const (
|
||||
StatusUnknown Status = iota
|
||||
StatusSuccess
|
||||
StatusError
|
||||
StatusNextProtocol
|
||||
)
|
||||
|
||||
type Context interface {
|
||||
Packet() *radius.Request
|
||||
|
||||
ProtocolSettings() interface{}
|
||||
GetProtocolState(def func(Context) interface{}) interface{}
|
||||
GetProtocolState() interface{}
|
||||
SetProtocolState(interface{})
|
||||
|
||||
IsProtocolStart() bool
|
||||
EndInnerProtocol(Status, func(p *radius.Packet) *radius.Packet)
|
||||
|
||||
Log() *log.Entry
|
||||
|
||||
@ -3,11 +3,9 @@ package protocol
|
||||
type Payload interface {
|
||||
Decode(raw []byte) error
|
||||
Encode() ([]byte, error)
|
||||
Handle(ctx Context) Payload
|
||||
Type() Type
|
||||
Offerable() bool
|
||||
}
|
||||
|
||||
type Type uint8
|
||||
|
||||
const (
|
||||
TypeIdentity Type = 1
|
||||
TypeMD5Challenge Type = 4
|
||||
)
|
||||
|
||||
@ -1,13 +1,17 @@
|
||||
package eap
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"slices"
|
||||
|
||||
"goauthentik.io/internal/outpost/radius/eap/protocol"
|
||||
)
|
||||
|
||||
type ProtocolConstructor func() protocol.Payload
|
||||
|
||||
type Settings struct {
|
||||
ProtocolsToOffer []protocol.Type
|
||||
Protocols []ProtocolConstructor
|
||||
ProtocolPriority []protocol.Type
|
||||
ProtocolSettings map[protocol.Type]interface{}
|
||||
}
|
||||
|
||||
@ -18,13 +22,23 @@ type StateManager interface {
|
||||
}
|
||||
|
||||
type State struct {
|
||||
ChallengesToOffer []protocol.Type
|
||||
Protocols []ProtocolConstructor
|
||||
ProtocolIndex int
|
||||
ProtocolPriority []protocol.Type
|
||||
TypeState map[protocol.Type]any
|
||||
}
|
||||
|
||||
func (st *State) GetNextProtocol() (protocol.Type, error) {
|
||||
if st.ProtocolIndex >= len(st.ProtocolPriority) {
|
||||
return protocol.Type(0), errors.New("no more protocols to offer")
|
||||
}
|
||||
return st.ProtocolPriority[st.ProtocolIndex], nil
|
||||
}
|
||||
|
||||
func BlankState(settings Settings) *State {
|
||||
return &State{
|
||||
ChallengesToOffer: slices.Clone(settings.ProtocolsToOffer),
|
||||
Protocols: slices.Clone(settings.Protocols),
|
||||
ProtocolPriority: slices.Clone(settings.ProtocolPriority),
|
||||
TypeState: map[protocol.Type]any{},
|
||||
}
|
||||
}
|
||||
|
||||
@ -21,6 +21,10 @@ const staleConnectionTimeout = 10
|
||||
|
||||
const TypeTLS protocol.Type = 13
|
||||
|
||||
func Protocol() protocol.Payload {
|
||||
return &Payload{}
|
||||
}
|
||||
|
||||
type Payload struct {
|
||||
Flags Flag
|
||||
Length uint32
|
||||
@ -29,6 +33,14 @@ type Payload struct {
|
||||
st *State
|
||||
}
|
||||
|
||||
func (p *Payload) Type() protocol.Type {
|
||||
return TypeTLS
|
||||
}
|
||||
|
||||
func (p *Payload) Offerable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *Payload) Decode(raw []byte) error {
|
||||
p.Flags = Flag(raw[0])
|
||||
raw = raw[1:]
|
||||
@ -65,15 +77,16 @@ func (p *Payload) Encode() ([]byte, error) {
|
||||
}
|
||||
|
||||
func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
|
||||
p.st = ctx.GetProtocolState(NewState).(*State)
|
||||
defer ctx.SetProtocolState(p.st)
|
||||
if !p.st.HasStarted {
|
||||
ctx.Log().Debug("TLS: handshake starting")
|
||||
p.st.HasStarted = true
|
||||
defer func() {
|
||||
ctx.SetProtocolState(p.st)
|
||||
}()
|
||||
if ctx.IsProtocolStart() {
|
||||
p.st = NewState(ctx).(*State)
|
||||
return &Payload{
|
||||
Flags: FlagTLSStart,
|
||||
}
|
||||
}
|
||||
p.st = ctx.GetProtocolState().(*State)
|
||||
|
||||
if p.st.TLS == nil {
|
||||
p.tlsInit(ctx)
|
||||
|
||||
@ -8,7 +8,6 @@ import (
|
||||
)
|
||||
|
||||
type State struct {
|
||||
HasStarted bool
|
||||
RemainingChunks [][]byte
|
||||
HandshakeDone bool
|
||||
FinalStatus protocol.Status
|
||||
|
||||
@ -12,6 +12,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/internal/outpost/flow"
|
||||
"goauthentik.io/internal/outpost/radius/eap"
|
||||
"goauthentik.io/internal/outpost/radius/eap/identity"
|
||||
"goauthentik.io/internal/outpost/radius/eap/legacy_nak"
|
||||
"goauthentik.io/internal/outpost/radius/eap/protocol"
|
||||
"goauthentik.io/internal/outpost/radius/eap/tls"
|
||||
"goauthentik.io/internal/outpost/radius/metrics"
|
||||
@ -111,12 +113,12 @@ func (rs *RadiusServer) Handle_AccessRequest_PAP(w radius.ResponseWriter, r *Rad
|
||||
|
||||
func (rs *RadiusServer) Handle_AccessRequest_EAP(w radius.ResponseWriter, r *RadiusRequest) {
|
||||
er := rfc2869.EAPMessage_Get(r.Packet)
|
||||
ep, err := eap.Decode(er)
|
||||
ep, err := eap.Decode(r.pi, er)
|
||||
if err != nil {
|
||||
rs.log.WithError(err).Warning("failed to parse EAP packet")
|
||||
return
|
||||
}
|
||||
ep.Handle(r.pi, w, r.Request)
|
||||
ep.HandleRadiusPacket(w, r.Request)
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) GetEAPState(key string) *eap.State {
|
||||
@ -128,22 +130,30 @@ func (pi *ProviderInstance) SetEAPState(key string, state *eap.State) {
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) GetEAPSettings() eap.Settings {
|
||||
protocols := []eap.ProtocolConstructor{
|
||||
identity.Protocol,
|
||||
legacy_nak.Protocol,
|
||||
}
|
||||
|
||||
certId := pi.certId
|
||||
if certId == "" {
|
||||
return eap.Settings{
|
||||
ProtocolsToOffer: []protocol.Type{},
|
||||
Protocols: protocols,
|
||||
}
|
||||
}
|
||||
|
||||
cert := pi.s.cryptoStore.Get(certId)
|
||||
if cert == nil {
|
||||
return eap.Settings{
|
||||
ProtocolsToOffer: []protocol.Type{},
|
||||
Protocols: protocols,
|
||||
}
|
||||
}
|
||||
|
||||
return eap.Settings{
|
||||
ProtocolsToOffer: []protocol.Type{tls.TypeTLS},
|
||||
Protocols: append(protocols, tls.Protocol),
|
||||
ProtocolPriority: []protocol.Type{
|
||||
tls.TypeTLS,
|
||||
},
|
||||
ProtocolSettings: map[protocol.Type]interface{}{
|
||||
tls.TypeTLS: tls.Settings{
|
||||
Config: &ttls.Config{
|
||||
|
||||
Reference in New Issue
Block a user