ok this works kinda
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
@ -7,24 +7,31 @@ import (
|
||||
)
|
||||
|
||||
type context struct {
|
||||
state interface{}
|
||||
log *log.Entry
|
||||
state interface{}
|
||||
log *log.Entry
|
||||
settings interface{}
|
||||
endStatus protocol.Status
|
||||
endModifier func(p *radius.Packet) *radius.Packet
|
||||
}
|
||||
|
||||
func (ctx context) ProtocolSettings() interface{} {
|
||||
return nil
|
||||
return ctx.settings
|
||||
}
|
||||
|
||||
func (ctx context) GetProtocolState(def func(protocol.Context) interface{}) interface{} {
|
||||
func (ctx *context) GetProtocolState(def func(protocol.Context) interface{}) interface{} {
|
||||
if ctx.state == nil {
|
||||
ctx.state = def(ctx)
|
||||
}
|
||||
return ctx.state
|
||||
}
|
||||
|
||||
func (ctx context) SetProtocolState(st interface{}) {
|
||||
func (ctx *context) SetProtocolState(st interface{}) {
|
||||
ctx.state = st
|
||||
}
|
||||
|
||||
func (ctx context) EndInnerProtocol(func(p *radius.Packet) *radius.Packet) {
|
||||
|
||||
func (ctx *context) EndInnerProtocol(st protocol.Status, mf func(p *radius.Packet) *radius.Packet) {
|
||||
ctx.endStatus = st
|
||||
ctx.endModifier = mf
|
||||
}
|
||||
|
||||
func (ctx context) Log() *log.Entry {
|
||||
|
@ -29,9 +29,10 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac
|
||||
}
|
||||
nextChallengeToOffer := st.ChallengesToOffer[0]
|
||||
|
||||
ctx := context{
|
||||
state: st.TypeState[nextChallengeToOffer],
|
||||
log: log.WithField("type", nextChallengeToOffer),
|
||||
ctx := &context{
|
||||
state: st.TypeState[nextChallengeToOffer],
|
||||
log: log.WithField("type", nextChallengeToOffer),
|
||||
settings: stm.GetEAPSettings().ProtocolSettings[nextChallengeToOffer],
|
||||
}
|
||||
|
||||
res := p.GetChallengeForType(ctx, nextChallengeToOffer)
|
||||
@ -39,11 +40,24 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac
|
||||
stm.SetEAPState(rst, st)
|
||||
|
||||
rres := r.Response(radius.CodeAccessChallenge)
|
||||
if p, ok := res.Payload.(protocol.EmptyPayload); ok {
|
||||
// TODO: This is a bit hacky here
|
||||
switch ctx.endStatus {
|
||||
case protocol.StatusSuccess:
|
||||
res.code = CodeSuccess
|
||||
res.id -= 1
|
||||
rres = p.ModifyPacket(rres)
|
||||
rres = ctx.endModifier(rres)
|
||||
st.ChallengesToOffer = st.ChallengesToOffer[1:]
|
||||
if len(st.ChallengesToOffer) < 1 {
|
||||
rres.Code = radius.CodeAccessAccept
|
||||
}
|
||||
case protocol.StatusError:
|
||||
res.code = CodeFailure
|
||||
res.id -= 1
|
||||
st.ChallengesToOffer = st.ChallengesToOffer[1:]
|
||||
rres = ctx.endModifier(rres)
|
||||
if len(st.ChallengesToOffer) < 1 {
|
||||
rres.Code = radius.CodeAccessReject
|
||||
}
|
||||
case protocol.StatusUnknown:
|
||||
}
|
||||
rfc2865.State_SetString(rres, rst)
|
||||
eapEncoded, err := res.Encode()
|
||||
@ -59,7 +73,7 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Packet) GetChallengeForType(ctx context, t Type) *Packet {
|
||||
func (p *Packet) GetChallengeForType(ctx *context, t protocol.Type) *Packet {
|
||||
res := &Packet{
|
||||
code: CodeRequest,
|
||||
id: p.id + 1,
|
||||
@ -67,7 +81,7 @@ func (p *Packet) GetChallengeForType(ctx context, t Type) *Packet {
|
||||
}
|
||||
var payload any
|
||||
switch t {
|
||||
case TypeTLS:
|
||||
case tls.TypeTLS:
|
||||
// TODO: rewrite this
|
||||
if _, ok := p.Payload.(*tls.Payload); !ok {
|
||||
p.Payload = &tls.Payload{}
|
||||
@ -76,8 +90,9 @@ func (p *Packet) GetChallengeForType(ctx context, t Type) *Packet {
|
||||
// this
|
||||
payload = p.Payload.(*tls.Payload).Handle(ctx)
|
||||
}
|
||||
// st.TypeState[t] = tst
|
||||
res.Payload = payload.(protocol.Payload)
|
||||
if payload != nil {
|
||||
res.Payload = payload.(protocol.Payload)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
|
@ -16,32 +16,25 @@ const (
|
||||
CodeRequest Code = 1
|
||||
CodeResponse Code = 2
|
||||
CodeSuccess Code = 3
|
||||
)
|
||||
|
||||
type Type uint8
|
||||
|
||||
const (
|
||||
TypeIdentity Type = 1
|
||||
TypeMD5Challenge Type = 4
|
||||
TypeTLS Type = 13
|
||||
CodeFailure Code = 4
|
||||
)
|
||||
|
||||
type Packet struct {
|
||||
code Code
|
||||
id uint8
|
||||
length uint16
|
||||
msgType Type
|
||||
msgType protocol.Type
|
||||
rawPayload []byte
|
||||
Payload protocol.Payload
|
||||
}
|
||||
|
||||
type PayloadWriter struct{}
|
||||
|
||||
func emptyPayload(t Type) protocol.Payload {
|
||||
func emptyPayload(t protocol.Type) protocol.Payload {
|
||||
switch t {
|
||||
case TypeIdentity:
|
||||
case protocol.TypeIdentity:
|
||||
return &IdentityPayload{}
|
||||
case TypeTLS:
|
||||
case tls.TypeTLS:
|
||||
return &tls.Payload{}
|
||||
}
|
||||
return nil
|
||||
@ -56,7 +49,7 @@ func Decode(raw []byte) (*Packet, error) {
|
||||
return nil, errors.New("mismatched packet length")
|
||||
}
|
||||
if len(raw) > 4 && (packet.code == CodeRequest || packet.code == CodeResponse) {
|
||||
packet.msgType = Type(raw[4])
|
||||
packet.msgType = protocol.Type(raw[4])
|
||||
}
|
||||
packet.Payload = emptyPayload(packet.msgType)
|
||||
packet.rawPayload = raw[5:]
|
||||
@ -73,14 +66,16 @@ func (p *Packet) Encode() ([]byte, error) {
|
||||
buff[0] = uint8(p.code)
|
||||
buff[1] = uint8(p.id)
|
||||
|
||||
payloadBuffer, err := p.Payload.Encode()
|
||||
if err != nil {
|
||||
return buff, err
|
||||
if p.Payload != nil {
|
||||
payloadBuffer, err := p.Payload.Encode()
|
||||
if err != nil {
|
||||
return buff, err
|
||||
}
|
||||
if p.code == CodeRequest || p.code == CodeResponse {
|
||||
buff = append(buff, uint8(p.msgType))
|
||||
}
|
||||
buff = append(buff, payloadBuffer...)
|
||||
}
|
||||
if p.code == CodeRequest || p.code == CodeResponse {
|
||||
buff = append(buff, uint8(p.msgType))
|
||||
}
|
||||
buff = append(buff, payloadBuffer...)
|
||||
binary.BigEndian.PutUint16(buff[2:], uint16(len(buff)))
|
||||
return buff, nil
|
||||
}
|
||||
|
@ -5,6 +5,14 @@ import (
|
||||
"layeh.com/radius"
|
||||
)
|
||||
|
||||
type Status int
|
||||
|
||||
const (
|
||||
StatusUnknown Status = iota
|
||||
StatusSuccess
|
||||
StatusError
|
||||
)
|
||||
|
||||
type Context interface {
|
||||
// GlobalState()
|
||||
|
||||
@ -12,7 +20,7 @@ type Context interface {
|
||||
GetProtocolState(def func(Context) interface{}) interface{}
|
||||
SetProtocolState(interface{})
|
||||
|
||||
EndInnerProtocol(func(p *radius.Packet) *radius.Packet)
|
||||
EndInnerProtocol(Status, func(p *radius.Packet) *radius.Packet)
|
||||
|
||||
Log() *log.Entry
|
||||
}
|
||||
|
@ -4,3 +4,10 @@ type Payload interface {
|
||||
Decode(raw []byte) error
|
||||
Encode() ([]byte, error)
|
||||
}
|
||||
|
||||
type Type uint8
|
||||
|
||||
const (
|
||||
TypeIdentity Type = 1
|
||||
TypeMD5Challenge Type = 4
|
||||
)
|
||||
|
@ -1,10 +1,14 @@
|
||||
package eap
|
||||
|
||||
import "slices"
|
||||
import (
|
||||
"slices"
|
||||
|
||||
"goauthentik.io/internal/outpost/radius/eap/protocol"
|
||||
)
|
||||
|
||||
type Settings struct {
|
||||
ProtocolsToOffer []Type
|
||||
ProtocolSettings map[Type]interface{}
|
||||
ProtocolsToOffer []protocol.Type
|
||||
ProtocolSettings map[protocol.Type]interface{}
|
||||
}
|
||||
|
||||
type StateManager interface {
|
||||
@ -14,13 +18,13 @@ type StateManager interface {
|
||||
}
|
||||
|
||||
type State struct {
|
||||
ChallengesToOffer []Type
|
||||
TypeState map[Type]any
|
||||
ChallengesToOffer []protocol.Type
|
||||
TypeState map[protocol.Type]any
|
||||
}
|
||||
|
||||
func BlankState(settings Settings) *State {
|
||||
return &State{
|
||||
ChallengesToOffer: slices.Clone(settings.ProtocolsToOffer),
|
||||
TypeState: map[Type]any{},
|
||||
TypeState: map[protocol.Type]any{},
|
||||
}
|
||||
}
|
||||
|
@ -18,6 +18,8 @@ import (
|
||||
const maxChunkSize = 1000
|
||||
const staleConnectionTimeout = 10
|
||||
|
||||
const TypeTLS protocol.Type = 13
|
||||
|
||||
type Payload struct {
|
||||
Flags Flag
|
||||
Length uint32
|
||||
@ -102,8 +104,7 @@ func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
|
||||
}
|
||||
if p.st.Conn.writer.Len() == 0 && p.st.HandshakeDone {
|
||||
defer p.st.ContextCancel()
|
||||
ctx.EndInnerProtocol(func(r *radius.Packet) *radius.Packet {
|
||||
r.Code = radius.CodeAccessAccept
|
||||
ctx.EndInnerProtocol(protocol.StatusSuccess, func(r *radius.Packet) *radius.Packet {
|
||||
microsoft.MSMPPERecvKey_Set(r, p.st.MPPEKey[:32])
|
||||
microsoft.MSMPPESendKey_Set(r, p.st.MPPEKey[64:64+32])
|
||||
return r
|
||||
@ -128,7 +129,9 @@ func (p *Payload) tlsInit(ctx protocol.Context) {
|
||||
err := p.st.TLS.HandshakeContext(p.st.Context)
|
||||
if err != nil {
|
||||
log.WithError(err).Debug("TLS: Handshake error")
|
||||
// TODO: Send a NAK to the client
|
||||
ctx.EndInnerProtocol(protocol.StatusError, func(p *radius.Packet) *radius.Packet {
|
||||
return p
|
||||
})
|
||||
return
|
||||
}
|
||||
log.Debug("TLS: handshake done")
|
||||
@ -150,7 +153,7 @@ func (p *Payload) tlsHandshakeFinished() {
|
||||
case tls.VersionTLS13:
|
||||
log.Debugf("TLS: Version %d (1.3)", cs.Version)
|
||||
label = "EXPORTER_EAP_TLS_Key_Material"
|
||||
context = []byte{13}
|
||||
context = []byte{byte(TypeTLS)}
|
||||
}
|
||||
ksm, err := cs.ExportKeyingMaterial(label, context, 64+64)
|
||||
log.Debugf("TLS: ksm % x %v", ksm, err)
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/internal/outpost/flow"
|
||||
"goauthentik.io/internal/outpost/radius/eap"
|
||||
"goauthentik.io/internal/outpost/radius/eap/protocol"
|
||||
"goauthentik.io/internal/outpost/radius/eap/tls"
|
||||
"goauthentik.io/internal/outpost/radius/metrics"
|
||||
"layeh.com/radius"
|
||||
@ -134,9 +135,9 @@ func (pi *ProviderInstance) GetEAPSettings() eap.Settings {
|
||||
}
|
||||
|
||||
return eap.Settings{
|
||||
ProtocolsToOffer: []eap.Type{eap.TypeTLS},
|
||||
ProtocolSettings: map[eap.Type]interface{}{
|
||||
eap.TypeTLS: tls.Settings{
|
||||
ProtocolsToOffer: []protocol.Type{tls.TypeTLS},
|
||||
ProtocolSettings: map[protocol.Type]interface{}{
|
||||
tls.TypeTLS: tls.Settings{
|
||||
Config: &ttls.Config{
|
||||
Certificates: []ttls.Certificate{cert},
|
||||
ClientAuth: ttls.RequireAnyClientCert,
|
||||
|
Reference in New Issue
Block a user