refactor v1, start support for more protocols and implement nak
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
@ -23,10 +23,7 @@ func (ctx context) ProtocolSettings() interface{} {
|
|||||||
return ctx.settings
|
return ctx.settings
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ctx *context) GetProtocolState(def func(protocol.Context) interface{}) interface{} {
|
func (ctx *context) GetProtocolState() interface{} {
|
||||||
if ctx.state == nil {
|
|
||||||
ctx.state = def(ctx)
|
|
||||||
}
|
|
||||||
return ctx.state
|
return ctx.state
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -34,11 +31,20 @@ func (ctx *context) SetProtocolState(st interface{}) {
|
|||||||
ctx.state = st
|
ctx.state = st
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ctx *context) IsProtocolStart() bool {
|
||||||
|
return ctx.state == nil
|
||||||
|
}
|
||||||
|
|
||||||
func (ctx *context) EndInnerProtocol(st protocol.Status, mf func(p *radius.Packet) *radius.Packet) {
|
func (ctx *context) EndInnerProtocol(st protocol.Status, mf func(p *radius.Packet) *radius.Packet) {
|
||||||
if ctx.endStatus != protocol.StatusUnknown {
|
if ctx.endStatus != protocol.StatusUnknown {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ctx.endStatus = st
|
ctx.endStatus = st
|
||||||
|
if mf == nil {
|
||||||
|
mf = func(p *radius.Packet) *radius.Packet {
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
}
|
||||||
ctx.endModifier = mf
|
ctx.endModifier = mf
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -4,11 +4,12 @@ import (
|
|||||||
"crypto/hmac"
|
"crypto/hmac"
|
||||||
"crypto/md5"
|
"crypto/md5"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/gorilla/securecookie"
|
"github.com/gorilla/securecookie"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"goauthentik.io/internal/outpost/radius/eap/legacy_nak"
|
||||||
"goauthentik.io/internal/outpost/radius/eap/protocol"
|
"goauthentik.io/internal/outpost/radius/eap/protocol"
|
||||||
"goauthentik.io/internal/outpost/radius/eap/tls"
|
|
||||||
"layeh.com/radius"
|
"layeh.com/radius"
|
||||||
"layeh.com/radius/rfc2865"
|
"layeh.com/radius/rfc2865"
|
||||||
"layeh.com/radius/rfc2869"
|
"layeh.com/radius/rfc2869"
|
||||||
@ -22,66 +23,42 @@ func sendErrorResponse(w radius.ResponseWriter, r *radius.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Request) {
|
func (p *Packet) HandleRadiusPacket(w radius.ResponseWriter, r *radius.Request) {
|
||||||
rst := rfc2865.State_GetString(r.Packet)
|
rst := rfc2865.State_GetString(r.Packet)
|
||||||
if rst == "" {
|
if rst == "" {
|
||||||
rst = base64.StdEncoding.EncodeToString(securecookie.GenerateRandomKey(12))
|
rst = base64.StdEncoding.EncodeToString(securecookie.GenerateRandomKey(12))
|
||||||
}
|
}
|
||||||
st := stm.GetEAPState(rst)
|
p.state = rst
|
||||||
if st == nil {
|
|
||||||
log.Debug("EAP: blank state")
|
|
||||||
st = BlankState(stm.GetEAPSettings())
|
|
||||||
}
|
|
||||||
if len(st.ChallengesToOffer) < 1 {
|
|
||||||
log.Error("No more challenges to offer")
|
|
||||||
sendErrorResponse(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
nextChallengeToOffer := st.ChallengesToOffer[0]
|
|
||||||
|
|
||||||
ctx := &context{
|
rp, err := p.handleInner(r)
|
||||||
req: r,
|
rres := r.Response(radius.CodeAccessReject)
|
||||||
state: st.TypeState[nextChallengeToOffer],
|
if err == nil {
|
||||||
log: log.WithField("type", nextChallengeToOffer),
|
rres = p.endModifier(rres)
|
||||||
settings: stm.GetEAPSettings().ProtocolSettings[nextChallengeToOffer],
|
switch rp.code {
|
||||||
}
|
case CodeFailure:
|
||||||
|
rres.Code = radius.CodeAccessReject
|
||||||
res := p.GetChallengeForType(ctx, nextChallengeToOffer)
|
case CodeSuccess:
|
||||||
st.TypeState[nextChallengeToOffer] = ctx.GetProtocolState(nil)
|
|
||||||
stm.SetEAPState(rst, st)
|
|
||||||
|
|
||||||
rres := r.Response(radius.CodeAccessChallenge)
|
|
||||||
switch ctx.endStatus {
|
|
||||||
case protocol.StatusSuccess:
|
|
||||||
res.code = CodeSuccess
|
|
||||||
res.id -= 1
|
|
||||||
rres = ctx.endModifier(rres)
|
|
||||||
st.ChallengesToOffer = st.ChallengesToOffer[1:]
|
|
||||||
if len(st.ChallengesToOffer) < 1 {
|
|
||||||
rres.Code = radius.CodeAccessAccept
|
rres.Code = radius.CodeAccessAccept
|
||||||
}
|
}
|
||||||
case protocol.StatusError:
|
} else {
|
||||||
res.code = CodeFailure
|
rres.Code = radius.CodeAccessReject
|
||||||
res.id -= 1
|
log.WithError(err).Debug("Rejecting request")
|
||||||
st.ChallengesToOffer = st.ChallengesToOffer[1:]
|
|
||||||
rres = ctx.endModifier(rres)
|
|
||||||
if len(st.ChallengesToOffer) < 1 {
|
|
||||||
rres.Code = radius.CodeAccessReject
|
|
||||||
}
|
|
||||||
case protocol.StatusUnknown:
|
|
||||||
}
|
}
|
||||||
rfc2865.State_SetString(rres, rst)
|
|
||||||
eapEncoded, err := res.Encode()
|
rfc2865.State_SetString(rres, p.state)
|
||||||
|
eapEncoded, err := rp.Encode()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Warning("failed to encode response")
|
log.WithError(err).Warning("failed to encode response")
|
||||||
sendErrorResponse(w, r)
|
sendErrorResponse(w, r)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
log.WithField("length", len(eapEncoded)).Debug("EAP: encapsulated challenge")
|
log.WithField("length", len(eapEncoded)).WithField("type", fmt.Sprintf("%T", rp.Payload)).Debug("EAP: encapsulated challenge")
|
||||||
rfc2869.EAPMessage_Set(rres, eapEncoded)
|
rfc2869.EAPMessage_Set(rres, eapEncoded)
|
||||||
err = p.setMessageAuthenticator(rres)
|
err = p.setMessageAuthenticator(rres)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Warning("failed to send message authenticator")
|
log.WithError(err).Warning("failed to send message authenticator")
|
||||||
sendErrorResponse(w, r)
|
sendErrorResponse(w, r)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
err = w.Write(rres)
|
err = w.Write(rres)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -89,21 +66,75 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Req
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Packet) GetChallengeForType(ctx *context, t protocol.Type) *Packet {
|
func (p *Packet) handleInner(r *radius.Request) (*Packet, error) {
|
||||||
|
st := p.stm.GetEAPState(p.state)
|
||||||
|
if st == nil {
|
||||||
|
log.Debug("EAP: blank state")
|
||||||
|
st = BlankState(p.stm.GetEAPSettings())
|
||||||
|
}
|
||||||
|
|
||||||
|
nextChallengeToOffer, err := st.GetNextProtocol()
|
||||||
|
if err != nil {
|
||||||
|
return &Packet{
|
||||||
|
code: CodeFailure,
|
||||||
|
id: p.id,
|
||||||
|
}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := p.Payload.(*legacy_nak.Payload); ok {
|
||||||
|
log.Debug("EAP: received NAK, trying next protocol")
|
||||||
|
st.ProtocolIndex += 1
|
||||||
|
p.stm.SetEAPState(p.state, st)
|
||||||
|
return p.handleInner(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
np, _ := emptyPayload(p.stm, nextChallengeToOffer)
|
||||||
|
|
||||||
|
ctx := &context{
|
||||||
|
req: r,
|
||||||
|
state: st.TypeState[np.Type()],
|
||||||
|
log: log.WithField("type", fmt.Sprintf("%T", np)),
|
||||||
|
settings: p.stm.GetEAPSettings().ProtocolSettings[np.Type()],
|
||||||
|
}
|
||||||
|
ctx.log.Debug("EAP: Passing to protocol")
|
||||||
|
|
||||||
|
res := p.GetChallengeForType(ctx, np)
|
||||||
|
st.TypeState[np.Type()] = ctx.GetProtocolState()
|
||||||
|
p.stm.SetEAPState(p.state, st)
|
||||||
|
|
||||||
|
if ctx.endModifier != nil {
|
||||||
|
p.endModifier = ctx.endModifier
|
||||||
|
}
|
||||||
|
|
||||||
|
switch ctx.endStatus {
|
||||||
|
case protocol.StatusSuccess:
|
||||||
|
res.code = CodeSuccess
|
||||||
|
res.id -= 1
|
||||||
|
case protocol.StatusError:
|
||||||
|
res.code = CodeFailure
|
||||||
|
res.id -= 1
|
||||||
|
case protocol.StatusNextProtocol:
|
||||||
|
ctx.log.Debug("EAP: Protocol ended, starting next protocol")
|
||||||
|
st.ProtocolIndex += 1
|
||||||
|
p.stm.SetEAPState(p.state, st)
|
||||||
|
return p.handleInner(r)
|
||||||
|
case protocol.StatusUnknown:
|
||||||
|
}
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Packet) GetChallengeForType(ctx *context, np protocol.Payload) *Packet {
|
||||||
res := &Packet{
|
res := &Packet{
|
||||||
code: CodeRequest,
|
code: CodeRequest,
|
||||||
id: p.id + 1,
|
id: p.id + 1,
|
||||||
msgType: t,
|
msgType: np.Type(),
|
||||||
}
|
}
|
||||||
var payload any
|
var payload any
|
||||||
switch t {
|
if ctx.IsProtocolStart() {
|
||||||
case tls.TypeTLS:
|
p.Payload = np
|
||||||
if _, ok := p.Payload.(*tls.Payload); !ok {
|
p.Payload.Decode(p.rawPayload)
|
||||||
p.Payload = &tls.Payload{}
|
|
||||||
p.Payload.Decode(p.rawPayload)
|
|
||||||
}
|
|
||||||
payload = p.Payload.(*tls.Payload).Handle(ctx)
|
|
||||||
}
|
}
|
||||||
|
payload = p.Payload.Handle(ctx)
|
||||||
if payload != nil {
|
if payload != nil {
|
||||||
res.Payload = payload.(protocol.Payload)
|
res.Payload = payload.(protocol.Payload)
|
||||||
}
|
}
|
||||||
|
|||||||
37
internal/outpost/radius/eap/identity/payload.go
Normal file
37
internal/outpost/radius/eap/identity/payload.go
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
package identity
|
||||||
|
|
||||||
|
import "goauthentik.io/internal/outpost/radius/eap/protocol"
|
||||||
|
|
||||||
|
const TypeIdentity protocol.Type = 1
|
||||||
|
|
||||||
|
func Protocol() protocol.Payload {
|
||||||
|
return &Payload{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Payload struct {
|
||||||
|
Identity string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ip *Payload) Type() protocol.Type {
|
||||||
|
return TypeIdentity
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ip *Payload) Decode(raw []byte) error {
|
||||||
|
ip.Identity = string(raw)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ip *Payload) Encode() ([]byte, error) {
|
||||||
|
return []byte{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ip *Payload) Handle(ctx protocol.Context) protocol.Payload {
|
||||||
|
if ctx.IsProtocolStart() {
|
||||||
|
ctx.EndInnerProtocol(protocol.StatusNextProtocol, nil)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ip *Payload) Offerable() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
37
internal/outpost/radius/eap/legacy_nak/payload.go
Normal file
37
internal/outpost/radius/eap/legacy_nak/payload.go
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
package legacy_nak
|
||||||
|
|
||||||
|
import "goauthentik.io/internal/outpost/radius/eap/protocol"
|
||||||
|
|
||||||
|
const TypeLegacyNAK protocol.Type = 3
|
||||||
|
|
||||||
|
func Protocol() protocol.Payload {
|
||||||
|
return &Payload{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Payload struct {
|
||||||
|
DesiredType protocol.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ln *Payload) Type() protocol.Type {
|
||||||
|
return TypeLegacyNAK
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ln *Payload) Decode(raw []byte) error {
|
||||||
|
ln.DesiredType = protocol.Type(raw[0])
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ln *Payload) Encode() ([]byte, error) {
|
||||||
|
return []byte{byte(ln.DesiredType)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ln *Payload) Handle(ctx protocol.Context) protocol.Payload {
|
||||||
|
if ctx.IsProtocolStart() {
|
||||||
|
ctx.EndInnerProtocol(protocol.StatusError, nil)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ln *Payload) Offerable() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
@ -3,11 +3,12 @@ package eap
|
|||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"goauthentik.io/internal/outpost/radius/eap/debug"
|
"goauthentik.io/internal/outpost/radius/eap/debug"
|
||||||
"goauthentik.io/internal/outpost/radius/eap/protocol"
|
"goauthentik.io/internal/outpost/radius/eap/protocol"
|
||||||
"goauthentik.io/internal/outpost/radius/eap/tls"
|
"layeh.com/radius"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Code uint8
|
type Code uint8
|
||||||
@ -26,22 +27,30 @@ type Packet struct {
|
|||||||
msgType protocol.Type
|
msgType protocol.Type
|
||||||
rawPayload []byte
|
rawPayload []byte
|
||||||
Payload protocol.Payload
|
Payload protocol.Payload
|
||||||
|
|
||||||
|
stm StateManager
|
||||||
|
state string
|
||||||
|
endModifier func(p *radius.Packet) *radius.Packet
|
||||||
}
|
}
|
||||||
|
|
||||||
type PayloadWriter struct{}
|
type PayloadWriter struct{}
|
||||||
|
|
||||||
func emptyPayload(t protocol.Type) protocol.Payload {
|
func emptyPayload(stm StateManager, t protocol.Type) (protocol.Payload, error) {
|
||||||
switch t {
|
for _, cons := range stm.GetEAPSettings().Protocols {
|
||||||
case protocol.TypeIdentity:
|
if np := cons(); np.Type() == t {
|
||||||
return &IdentityPayload{}
|
return np, nil
|
||||||
case tls.TypeTLS:
|
}
|
||||||
return &tls.Payload{}
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil, fmt.Errorf("unsupported EAP type %d", t)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Decode(raw []byte) (*Packet, error) {
|
func Decode(stm StateManager, raw []byte) (*Packet, error) {
|
||||||
packet := &Packet{}
|
packet := &Packet{
|
||||||
|
stm: stm,
|
||||||
|
endModifier: func(p *radius.Packet) *radius.Packet {
|
||||||
|
return p
|
||||||
|
},
|
||||||
|
}
|
||||||
packet.code = Code(raw[0])
|
packet.code = Code(raw[0])
|
||||||
packet.id = raw[1]
|
packet.id = raw[1]
|
||||||
packet.length = binary.BigEndian.Uint16(raw[2:])
|
packet.length = binary.BigEndian.Uint16(raw[2:])
|
||||||
@ -51,10 +60,14 @@ func Decode(raw []byte) (*Packet, error) {
|
|||||||
if len(raw) > 4 && (packet.code == CodeRequest || packet.code == CodeResponse) {
|
if len(raw) > 4 && (packet.code == CodeRequest || packet.code == CodeResponse) {
|
||||||
packet.msgType = protocol.Type(raw[4])
|
packet.msgType = protocol.Type(raw[4])
|
||||||
}
|
}
|
||||||
packet.Payload = emptyPayload(packet.msgType)
|
p, err := emptyPayload(stm, packet.msgType)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
packet.Payload = p
|
||||||
packet.rawPayload = raw[5:]
|
packet.rawPayload = raw[5:]
|
||||||
log.WithField("raw", debug.FormatBytes(raw)).Debug("EAP: decode raw")
|
log.WithField("raw", debug.FormatBytes(raw)).WithField("payload", fmt.Sprintf("%T", packet.Payload)).Debug("EAP: decode raw")
|
||||||
err := packet.Payload.Decode(raw[5:])
|
err = packet.Payload.Decode(raw[5:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,14 +0,0 @@
|
|||||||
package eap
|
|
||||||
|
|
||||||
type IdentityPayload struct {
|
|
||||||
Identity string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ip *IdentityPayload) Decode(raw []byte) error {
|
|
||||||
ip.Identity = string(raw)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ip *IdentityPayload) Encode() ([]byte, error) {
|
|
||||||
return []byte{}, nil
|
|
||||||
}
|
|
||||||
@ -11,15 +11,17 @@ const (
|
|||||||
StatusUnknown Status = iota
|
StatusUnknown Status = iota
|
||||||
StatusSuccess
|
StatusSuccess
|
||||||
StatusError
|
StatusError
|
||||||
|
StatusNextProtocol
|
||||||
)
|
)
|
||||||
|
|
||||||
type Context interface {
|
type Context interface {
|
||||||
Packet() *radius.Request
|
Packet() *radius.Request
|
||||||
|
|
||||||
ProtocolSettings() interface{}
|
ProtocolSettings() interface{}
|
||||||
GetProtocolState(def func(Context) interface{}) interface{}
|
GetProtocolState() interface{}
|
||||||
SetProtocolState(interface{})
|
SetProtocolState(interface{})
|
||||||
|
|
||||||
|
IsProtocolStart() bool
|
||||||
EndInnerProtocol(Status, func(p *radius.Packet) *radius.Packet)
|
EndInnerProtocol(Status, func(p *radius.Packet) *radius.Packet)
|
||||||
|
|
||||||
Log() *log.Entry
|
Log() *log.Entry
|
||||||
|
|||||||
@ -3,11 +3,9 @@ package protocol
|
|||||||
type Payload interface {
|
type Payload interface {
|
||||||
Decode(raw []byte) error
|
Decode(raw []byte) error
|
||||||
Encode() ([]byte, error)
|
Encode() ([]byte, error)
|
||||||
|
Handle(ctx Context) Payload
|
||||||
|
Type() Type
|
||||||
|
Offerable() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type Type uint8
|
type Type uint8
|
||||||
|
|
||||||
const (
|
|
||||||
TypeIdentity Type = 1
|
|
||||||
TypeMD5Challenge Type = 4
|
|
||||||
)
|
|
||||||
|
|||||||
@ -1,13 +1,17 @@
|
|||||||
package eap
|
package eap
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"goauthentik.io/internal/outpost/radius/eap/protocol"
|
"goauthentik.io/internal/outpost/radius/eap/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ProtocolConstructor func() protocol.Payload
|
||||||
|
|
||||||
type Settings struct {
|
type Settings struct {
|
||||||
ProtocolsToOffer []protocol.Type
|
Protocols []ProtocolConstructor
|
||||||
|
ProtocolPriority []protocol.Type
|
||||||
ProtocolSettings map[protocol.Type]interface{}
|
ProtocolSettings map[protocol.Type]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -18,13 +22,23 @@ type StateManager interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type State struct {
|
type State struct {
|
||||||
ChallengesToOffer []protocol.Type
|
Protocols []ProtocolConstructor
|
||||||
TypeState map[protocol.Type]any
|
ProtocolIndex int
|
||||||
|
ProtocolPriority []protocol.Type
|
||||||
|
TypeState map[protocol.Type]any
|
||||||
|
}
|
||||||
|
|
||||||
|
func (st *State) GetNextProtocol() (protocol.Type, error) {
|
||||||
|
if st.ProtocolIndex >= len(st.ProtocolPriority) {
|
||||||
|
return protocol.Type(0), errors.New("no more protocols to offer")
|
||||||
|
}
|
||||||
|
return st.ProtocolPriority[st.ProtocolIndex], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func BlankState(settings Settings) *State {
|
func BlankState(settings Settings) *State {
|
||||||
return &State{
|
return &State{
|
||||||
ChallengesToOffer: slices.Clone(settings.ProtocolsToOffer),
|
Protocols: slices.Clone(settings.Protocols),
|
||||||
TypeState: map[protocol.Type]any{},
|
ProtocolPriority: slices.Clone(settings.ProtocolPriority),
|
||||||
|
TypeState: map[protocol.Type]any{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -21,6 +21,10 @@ const staleConnectionTimeout = 10
|
|||||||
|
|
||||||
const TypeTLS protocol.Type = 13
|
const TypeTLS protocol.Type = 13
|
||||||
|
|
||||||
|
func Protocol() protocol.Payload {
|
||||||
|
return &Payload{}
|
||||||
|
}
|
||||||
|
|
||||||
type Payload struct {
|
type Payload struct {
|
||||||
Flags Flag
|
Flags Flag
|
||||||
Length uint32
|
Length uint32
|
||||||
@ -29,6 +33,14 @@ type Payload struct {
|
|||||||
st *State
|
st *State
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Payload) Type() protocol.Type {
|
||||||
|
return TypeTLS
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Payload) Offerable() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (p *Payload) Decode(raw []byte) error {
|
func (p *Payload) Decode(raw []byte) error {
|
||||||
p.Flags = Flag(raw[0])
|
p.Flags = Flag(raw[0])
|
||||||
raw = raw[1:]
|
raw = raw[1:]
|
||||||
@ -65,15 +77,16 @@ func (p *Payload) Encode() ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
|
func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
|
||||||
p.st = ctx.GetProtocolState(NewState).(*State)
|
defer func() {
|
||||||
defer ctx.SetProtocolState(p.st)
|
ctx.SetProtocolState(p.st)
|
||||||
if !p.st.HasStarted {
|
}()
|
||||||
ctx.Log().Debug("TLS: handshake starting")
|
if ctx.IsProtocolStart() {
|
||||||
p.st.HasStarted = true
|
p.st = NewState(ctx).(*State)
|
||||||
return &Payload{
|
return &Payload{
|
||||||
Flags: FlagTLSStart,
|
Flags: FlagTLSStart,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
p.st = ctx.GetProtocolState().(*State)
|
||||||
|
|
||||||
if p.st.TLS == nil {
|
if p.st.TLS == nil {
|
||||||
p.tlsInit(ctx)
|
p.tlsInit(ctx)
|
||||||
|
|||||||
@ -8,7 +8,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type State struct {
|
type State struct {
|
||||||
HasStarted bool
|
|
||||||
RemainingChunks [][]byte
|
RemainingChunks [][]byte
|
||||||
HandshakeDone bool
|
HandshakeDone bool
|
||||||
FinalStatus protocol.Status
|
FinalStatus protocol.Status
|
||||||
|
|||||||
@ -12,6 +12,8 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"goauthentik.io/internal/outpost/flow"
|
"goauthentik.io/internal/outpost/flow"
|
||||||
"goauthentik.io/internal/outpost/radius/eap"
|
"goauthentik.io/internal/outpost/radius/eap"
|
||||||
|
"goauthentik.io/internal/outpost/radius/eap/identity"
|
||||||
|
"goauthentik.io/internal/outpost/radius/eap/legacy_nak"
|
||||||
"goauthentik.io/internal/outpost/radius/eap/protocol"
|
"goauthentik.io/internal/outpost/radius/eap/protocol"
|
||||||
"goauthentik.io/internal/outpost/radius/eap/tls"
|
"goauthentik.io/internal/outpost/radius/eap/tls"
|
||||||
"goauthentik.io/internal/outpost/radius/metrics"
|
"goauthentik.io/internal/outpost/radius/metrics"
|
||||||
@ -111,12 +113,12 @@ func (rs *RadiusServer) Handle_AccessRequest_PAP(w radius.ResponseWriter, r *Rad
|
|||||||
|
|
||||||
func (rs *RadiusServer) Handle_AccessRequest_EAP(w radius.ResponseWriter, r *RadiusRequest) {
|
func (rs *RadiusServer) Handle_AccessRequest_EAP(w radius.ResponseWriter, r *RadiusRequest) {
|
||||||
er := rfc2869.EAPMessage_Get(r.Packet)
|
er := rfc2869.EAPMessage_Get(r.Packet)
|
||||||
ep, err := eap.Decode(er)
|
ep, err := eap.Decode(r.pi, er)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rs.log.WithError(err).Warning("failed to parse EAP packet")
|
rs.log.WithError(err).Warning("failed to parse EAP packet")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ep.Handle(r.pi, w, r.Request)
|
ep.HandleRadiusPacket(w, r.Request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pi *ProviderInstance) GetEAPState(key string) *eap.State {
|
func (pi *ProviderInstance) GetEAPState(key string) *eap.State {
|
||||||
@ -128,22 +130,30 @@ func (pi *ProviderInstance) SetEAPState(key string, state *eap.State) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (pi *ProviderInstance) GetEAPSettings() eap.Settings {
|
func (pi *ProviderInstance) GetEAPSettings() eap.Settings {
|
||||||
|
protocols := []eap.ProtocolConstructor{
|
||||||
|
identity.Protocol,
|
||||||
|
legacy_nak.Protocol,
|
||||||
|
}
|
||||||
|
|
||||||
certId := pi.certId
|
certId := pi.certId
|
||||||
if certId == "" {
|
if certId == "" {
|
||||||
return eap.Settings{
|
return eap.Settings{
|
||||||
ProtocolsToOffer: []protocol.Type{},
|
Protocols: protocols,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cert := pi.s.cryptoStore.Get(certId)
|
cert := pi.s.cryptoStore.Get(certId)
|
||||||
if cert == nil {
|
if cert == nil {
|
||||||
return eap.Settings{
|
return eap.Settings{
|
||||||
ProtocolsToOffer: []protocol.Type{},
|
Protocols: protocols,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return eap.Settings{
|
return eap.Settings{
|
||||||
ProtocolsToOffer: []protocol.Type{tls.TypeTLS},
|
Protocols: append(protocols, tls.Protocol),
|
||||||
|
ProtocolPriority: []protocol.Type{
|
||||||
|
tls.TypeTLS,
|
||||||
|
},
|
||||||
ProtocolSettings: map[protocol.Type]interface{}{
|
ProtocolSettings: map[protocol.Type]interface{}{
|
||||||
tls.TypeTLS: tls.Settings{
|
tls.TypeTLS: tls.Settings{
|
||||||
Config: &ttls.Config{
|
Config: &ttls.Config{
|
||||||
|
|||||||
Reference in New Issue
Block a user