ok this works kinda
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
@ -7,24 +7,31 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type context struct {
|
type context struct {
|
||||||
state interface{}
|
state interface{}
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
|
settings interface{}
|
||||||
|
endStatus protocol.Status
|
||||||
|
endModifier func(p *radius.Packet) *radius.Packet
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ctx context) ProtocolSettings() interface{} {
|
func (ctx context) ProtocolSettings() interface{} {
|
||||||
return nil
|
return ctx.settings
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ctx context) GetProtocolState(def func(protocol.Context) interface{}) interface{} {
|
func (ctx *context) GetProtocolState(def func(protocol.Context) interface{}) interface{} {
|
||||||
|
if ctx.state == nil {
|
||||||
|
ctx.state = def(ctx)
|
||||||
|
}
|
||||||
return ctx.state
|
return ctx.state
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ctx context) SetProtocolState(st interface{}) {
|
func (ctx *context) SetProtocolState(st interface{}) {
|
||||||
ctx.state = st
|
ctx.state = st
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ctx context) EndInnerProtocol(func(p *radius.Packet) *radius.Packet) {
|
func (ctx *context) EndInnerProtocol(st protocol.Status, mf func(p *radius.Packet) *radius.Packet) {
|
||||||
|
ctx.endStatus = st
|
||||||
|
ctx.endModifier = mf
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ctx context) Log() *log.Entry {
|
func (ctx context) Log() *log.Entry {
|
||||||
|
@ -29,9 +29,10 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac
|
|||||||
}
|
}
|
||||||
nextChallengeToOffer := st.ChallengesToOffer[0]
|
nextChallengeToOffer := st.ChallengesToOffer[0]
|
||||||
|
|
||||||
ctx := context{
|
ctx := &context{
|
||||||
state: st.TypeState[nextChallengeToOffer],
|
state: st.TypeState[nextChallengeToOffer],
|
||||||
log: log.WithField("type", nextChallengeToOffer),
|
log: log.WithField("type", nextChallengeToOffer),
|
||||||
|
settings: stm.GetEAPSettings().ProtocolSettings[nextChallengeToOffer],
|
||||||
}
|
}
|
||||||
|
|
||||||
res := p.GetChallengeForType(ctx, nextChallengeToOffer)
|
res := p.GetChallengeForType(ctx, nextChallengeToOffer)
|
||||||
@ -39,11 +40,24 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac
|
|||||||
stm.SetEAPState(rst, st)
|
stm.SetEAPState(rst, st)
|
||||||
|
|
||||||
rres := r.Response(radius.CodeAccessChallenge)
|
rres := r.Response(radius.CodeAccessChallenge)
|
||||||
if p, ok := res.Payload.(protocol.EmptyPayload); ok {
|
switch ctx.endStatus {
|
||||||
// TODO: This is a bit hacky here
|
case protocol.StatusSuccess:
|
||||||
res.code = CodeSuccess
|
res.code = CodeSuccess
|
||||||
res.id -= 1
|
res.id -= 1
|
||||||
rres = p.ModifyPacket(rres)
|
rres = ctx.endModifier(rres)
|
||||||
|
st.ChallengesToOffer = st.ChallengesToOffer[1:]
|
||||||
|
if len(st.ChallengesToOffer) < 1 {
|
||||||
|
rres.Code = radius.CodeAccessAccept
|
||||||
|
}
|
||||||
|
case protocol.StatusError:
|
||||||
|
res.code = CodeFailure
|
||||||
|
res.id -= 1
|
||||||
|
st.ChallengesToOffer = st.ChallengesToOffer[1:]
|
||||||
|
rres = ctx.endModifier(rres)
|
||||||
|
if len(st.ChallengesToOffer) < 1 {
|
||||||
|
rres.Code = radius.CodeAccessReject
|
||||||
|
}
|
||||||
|
case protocol.StatusUnknown:
|
||||||
}
|
}
|
||||||
rfc2865.State_SetString(rres, rst)
|
rfc2865.State_SetString(rres, rst)
|
||||||
eapEncoded, err := res.Encode()
|
eapEncoded, err := res.Encode()
|
||||||
@ -59,7 +73,7 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Packet) GetChallengeForType(ctx context, t Type) *Packet {
|
func (p *Packet) GetChallengeForType(ctx *context, t protocol.Type) *Packet {
|
||||||
res := &Packet{
|
res := &Packet{
|
||||||
code: CodeRequest,
|
code: CodeRequest,
|
||||||
id: p.id + 1,
|
id: p.id + 1,
|
||||||
@ -67,7 +81,7 @@ func (p *Packet) GetChallengeForType(ctx context, t Type) *Packet {
|
|||||||
}
|
}
|
||||||
var payload any
|
var payload any
|
||||||
switch t {
|
switch t {
|
||||||
case TypeTLS:
|
case tls.TypeTLS:
|
||||||
// TODO: rewrite this
|
// TODO: rewrite this
|
||||||
if _, ok := p.Payload.(*tls.Payload); !ok {
|
if _, ok := p.Payload.(*tls.Payload); !ok {
|
||||||
p.Payload = &tls.Payload{}
|
p.Payload = &tls.Payload{}
|
||||||
@ -76,8 +90,9 @@ func (p *Packet) GetChallengeForType(ctx context, t Type) *Packet {
|
|||||||
// this
|
// this
|
||||||
payload = p.Payload.(*tls.Payload).Handle(ctx)
|
payload = p.Payload.(*tls.Payload).Handle(ctx)
|
||||||
}
|
}
|
||||||
// st.TypeState[t] = tst
|
if payload != nil {
|
||||||
res.Payload = payload.(protocol.Payload)
|
res.Payload = payload.(protocol.Payload)
|
||||||
|
}
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -16,32 +16,25 @@ const (
|
|||||||
CodeRequest Code = 1
|
CodeRequest Code = 1
|
||||||
CodeResponse Code = 2
|
CodeResponse Code = 2
|
||||||
CodeSuccess Code = 3
|
CodeSuccess Code = 3
|
||||||
)
|
CodeFailure Code = 4
|
||||||
|
|
||||||
type Type uint8
|
|
||||||
|
|
||||||
const (
|
|
||||||
TypeIdentity Type = 1
|
|
||||||
TypeMD5Challenge Type = 4
|
|
||||||
TypeTLS Type = 13
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Packet struct {
|
type Packet struct {
|
||||||
code Code
|
code Code
|
||||||
id uint8
|
id uint8
|
||||||
length uint16
|
length uint16
|
||||||
msgType Type
|
msgType protocol.Type
|
||||||
rawPayload []byte
|
rawPayload []byte
|
||||||
Payload protocol.Payload
|
Payload protocol.Payload
|
||||||
}
|
}
|
||||||
|
|
||||||
type PayloadWriter struct{}
|
type PayloadWriter struct{}
|
||||||
|
|
||||||
func emptyPayload(t Type) protocol.Payload {
|
func emptyPayload(t protocol.Type) protocol.Payload {
|
||||||
switch t {
|
switch t {
|
||||||
case TypeIdentity:
|
case protocol.TypeIdentity:
|
||||||
return &IdentityPayload{}
|
return &IdentityPayload{}
|
||||||
case TypeTLS:
|
case tls.TypeTLS:
|
||||||
return &tls.Payload{}
|
return &tls.Payload{}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -56,7 +49,7 @@ func Decode(raw []byte) (*Packet, error) {
|
|||||||
return nil, errors.New("mismatched packet length")
|
return nil, errors.New("mismatched packet length")
|
||||||
}
|
}
|
||||||
if len(raw) > 4 && (packet.code == CodeRequest || packet.code == CodeResponse) {
|
if len(raw) > 4 && (packet.code == CodeRequest || packet.code == CodeResponse) {
|
||||||
packet.msgType = Type(raw[4])
|
packet.msgType = protocol.Type(raw[4])
|
||||||
}
|
}
|
||||||
packet.Payload = emptyPayload(packet.msgType)
|
packet.Payload = emptyPayload(packet.msgType)
|
||||||
packet.rawPayload = raw[5:]
|
packet.rawPayload = raw[5:]
|
||||||
@ -73,14 +66,16 @@ func (p *Packet) Encode() ([]byte, error) {
|
|||||||
buff[0] = uint8(p.code)
|
buff[0] = uint8(p.code)
|
||||||
buff[1] = uint8(p.id)
|
buff[1] = uint8(p.id)
|
||||||
|
|
||||||
payloadBuffer, err := p.Payload.Encode()
|
if p.Payload != nil {
|
||||||
if err != nil {
|
payloadBuffer, err := p.Payload.Encode()
|
||||||
return buff, err
|
if err != nil {
|
||||||
|
return buff, err
|
||||||
|
}
|
||||||
|
if p.code == CodeRequest || p.code == CodeResponse {
|
||||||
|
buff = append(buff, uint8(p.msgType))
|
||||||
|
}
|
||||||
|
buff = append(buff, payloadBuffer...)
|
||||||
}
|
}
|
||||||
if p.code == CodeRequest || p.code == CodeResponse {
|
|
||||||
buff = append(buff, uint8(p.msgType))
|
|
||||||
}
|
|
||||||
buff = append(buff, payloadBuffer...)
|
|
||||||
binary.BigEndian.PutUint16(buff[2:], uint16(len(buff)))
|
binary.BigEndian.PutUint16(buff[2:], uint16(len(buff)))
|
||||||
return buff, nil
|
return buff, nil
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,14 @@ import (
|
|||||||
"layeh.com/radius"
|
"layeh.com/radius"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Status int
|
||||||
|
|
||||||
|
const (
|
||||||
|
StatusUnknown Status = iota
|
||||||
|
StatusSuccess
|
||||||
|
StatusError
|
||||||
|
)
|
||||||
|
|
||||||
type Context interface {
|
type Context interface {
|
||||||
// GlobalState()
|
// GlobalState()
|
||||||
|
|
||||||
@ -12,7 +20,7 @@ type Context interface {
|
|||||||
GetProtocolState(def func(Context) interface{}) interface{}
|
GetProtocolState(def func(Context) interface{}) interface{}
|
||||||
SetProtocolState(interface{})
|
SetProtocolState(interface{})
|
||||||
|
|
||||||
EndInnerProtocol(func(p *radius.Packet) *radius.Packet)
|
EndInnerProtocol(Status, func(p *radius.Packet) *radius.Packet)
|
||||||
|
|
||||||
Log() *log.Entry
|
Log() *log.Entry
|
||||||
}
|
}
|
||||||
|
@ -4,3 +4,10 @@ type Payload interface {
|
|||||||
Decode(raw []byte) error
|
Decode(raw []byte) error
|
||||||
Encode() ([]byte, error)
|
Encode() ([]byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Type uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
TypeIdentity Type = 1
|
||||||
|
TypeMD5Challenge Type = 4
|
||||||
|
)
|
||||||
|
@ -1,10 +1,14 @@
|
|||||||
package eap
|
package eap
|
||||||
|
|
||||||
import "slices"
|
import (
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"goauthentik.io/internal/outpost/radius/eap/protocol"
|
||||||
|
)
|
||||||
|
|
||||||
type Settings struct {
|
type Settings struct {
|
||||||
ProtocolsToOffer []Type
|
ProtocolsToOffer []protocol.Type
|
||||||
ProtocolSettings map[Type]interface{}
|
ProtocolSettings map[protocol.Type]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
type StateManager interface {
|
type StateManager interface {
|
||||||
@ -14,13 +18,13 @@ type StateManager interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type State struct {
|
type State struct {
|
||||||
ChallengesToOffer []Type
|
ChallengesToOffer []protocol.Type
|
||||||
TypeState map[Type]any
|
TypeState map[protocol.Type]any
|
||||||
}
|
}
|
||||||
|
|
||||||
func BlankState(settings Settings) *State {
|
func BlankState(settings Settings) *State {
|
||||||
return &State{
|
return &State{
|
||||||
ChallengesToOffer: slices.Clone(settings.ProtocolsToOffer),
|
ChallengesToOffer: slices.Clone(settings.ProtocolsToOffer),
|
||||||
TypeState: map[Type]any{},
|
TypeState: map[protocol.Type]any{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -18,6 +18,8 @@ import (
|
|||||||
const maxChunkSize = 1000
|
const maxChunkSize = 1000
|
||||||
const staleConnectionTimeout = 10
|
const staleConnectionTimeout = 10
|
||||||
|
|
||||||
|
const TypeTLS protocol.Type = 13
|
||||||
|
|
||||||
type Payload struct {
|
type Payload struct {
|
||||||
Flags Flag
|
Flags Flag
|
||||||
Length uint32
|
Length uint32
|
||||||
@ -102,8 +104,7 @@ func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
|
|||||||
}
|
}
|
||||||
if p.st.Conn.writer.Len() == 0 && p.st.HandshakeDone {
|
if p.st.Conn.writer.Len() == 0 && p.st.HandshakeDone {
|
||||||
defer p.st.ContextCancel()
|
defer p.st.ContextCancel()
|
||||||
ctx.EndInnerProtocol(func(r *radius.Packet) *radius.Packet {
|
ctx.EndInnerProtocol(protocol.StatusSuccess, func(r *radius.Packet) *radius.Packet {
|
||||||
r.Code = radius.CodeAccessAccept
|
|
||||||
microsoft.MSMPPERecvKey_Set(r, p.st.MPPEKey[:32])
|
microsoft.MSMPPERecvKey_Set(r, p.st.MPPEKey[:32])
|
||||||
microsoft.MSMPPESendKey_Set(r, p.st.MPPEKey[64:64+32])
|
microsoft.MSMPPESendKey_Set(r, p.st.MPPEKey[64:64+32])
|
||||||
return r
|
return r
|
||||||
@ -128,7 +129,9 @@ func (p *Payload) tlsInit(ctx protocol.Context) {
|
|||||||
err := p.st.TLS.HandshakeContext(p.st.Context)
|
err := p.st.TLS.HandshakeContext(p.st.Context)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Debug("TLS: Handshake error")
|
log.WithError(err).Debug("TLS: Handshake error")
|
||||||
// TODO: Send a NAK to the client
|
ctx.EndInnerProtocol(protocol.StatusError, func(p *radius.Packet) *radius.Packet {
|
||||||
|
return p
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Debug("TLS: handshake done")
|
log.Debug("TLS: handshake done")
|
||||||
@ -150,7 +153,7 @@ func (p *Payload) tlsHandshakeFinished() {
|
|||||||
case tls.VersionTLS13:
|
case tls.VersionTLS13:
|
||||||
log.Debugf("TLS: Version %d (1.3)", cs.Version)
|
log.Debugf("TLS: Version %d (1.3)", cs.Version)
|
||||||
label = "EXPORTER_EAP_TLS_Key_Material"
|
label = "EXPORTER_EAP_TLS_Key_Material"
|
||||||
context = []byte{13}
|
context = []byte{byte(TypeTLS)}
|
||||||
}
|
}
|
||||||
ksm, err := cs.ExportKeyingMaterial(label, context, 64+64)
|
ksm, err := cs.ExportKeyingMaterial(label, context, 64+64)
|
||||||
log.Debugf("TLS: ksm % x %v", ksm, err)
|
log.Debugf("TLS: ksm % x %v", ksm, err)
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"goauthentik.io/internal/outpost/flow"
|
"goauthentik.io/internal/outpost/flow"
|
||||||
"goauthentik.io/internal/outpost/radius/eap"
|
"goauthentik.io/internal/outpost/radius/eap"
|
||||||
|
"goauthentik.io/internal/outpost/radius/eap/protocol"
|
||||||
"goauthentik.io/internal/outpost/radius/eap/tls"
|
"goauthentik.io/internal/outpost/radius/eap/tls"
|
||||||
"goauthentik.io/internal/outpost/radius/metrics"
|
"goauthentik.io/internal/outpost/radius/metrics"
|
||||||
"layeh.com/radius"
|
"layeh.com/radius"
|
||||||
@ -134,9 +135,9 @@ func (pi *ProviderInstance) GetEAPSettings() eap.Settings {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return eap.Settings{
|
return eap.Settings{
|
||||||
ProtocolsToOffer: []eap.Type{eap.TypeTLS},
|
ProtocolsToOffer: []protocol.Type{tls.TypeTLS},
|
||||||
ProtocolSettings: map[eap.Type]interface{}{
|
ProtocolSettings: map[protocol.Type]interface{}{
|
||||||
eap.TypeTLS: tls.Settings{
|
tls.TypeTLS: tls.Settings{
|
||||||
Config: &ttls.Config{
|
Config: &ttls.Config{
|
||||||
Certificates: []ttls.Certificate{cert},
|
Certificates: []ttls.Certificate{cert},
|
||||||
ClientAuth: ttls.RequireAnyClientCert,
|
ClientAuth: ttls.RequireAnyClientCert,
|
||||||
|
Reference in New Issue
Block a user