Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer
2025-05-16 14:23:11 +02:00
parent 855afa7b9f
commit ac88784089
8 changed files with 151 additions and 93 deletions

View File

@ -0,0 +1,29 @@
package eap
import (
"github.com/sirupsen/logrus"
"layeh.com/radius"
)
type context[TState any, TSettings any] struct {
}
func (ctx context[TState, TSettings]) ProtocolSettings() TSettings {
return 0
}
func (ctx context[TState, TSettings]) GetProtocolState(def func(context[TState, TSettings]) TState) TState {
return nil
}
func (ctx context[TState, TSettings]) SetProtocolState(TState) {
}
func (ctx context[TState, TSettings]) EndInnerProtocol(func(p *radius.Packet) *radius.Packet) {
}
func (ctx context[TState, TSettings]) Log() *logrus.Entry {
return nil
}

View File

@ -28,7 +28,10 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac
panic("No more challenges") panic("No more challenges")
} }
nextChallengeToOffer := st.ChallengesToOffer[0] nextChallengeToOffer := st.ChallengesToOffer[0]
res, newState := p.GetChallengeForType(st, nextChallengeToOffer)
ctx := context{}
res, newState := p.GetChallengeForType(ctx, nextChallengeToOffer)
stm.SetEAPState(rst, newState) stm.SetEAPState(rst, newState)
rres := r.Response(radius.CodeAccessChallenge) rres := r.Response(radius.CodeAccessChallenge)
@ -52,21 +55,22 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac
} }
} }
func (p *Packet) GetChallengeForType(st *State, t Type) (*Packet, *State) { func (p *Packet) GetChallengeForType(ctx context[any, any], t Type) *Packet {
res := &Packet{ res := &Packet{
code: CodeRequest, code: CodeRequest,
id: p.id + 1, id: p.id + 1,
msgType: t, msgType: t,
} }
var payload any var payload any
var tst any
switch t { switch t {
case TypeTLS: case TypeTLS:
// TODO: rewrite this
if _, ok := p.Payload.(*tls.Payload); !ok { if _, ok := p.Payload.(*tls.Payload); !ok {
p.Payload = &tls.Payload{} p.Payload = &tls.Payload{}
p.Payload.Decode(p.rawPayload) p.Payload.Decode(p.rawPayload)
} }
payload, tst = p.Payload.(*tls.Payload).Handle(st.TypeState[t]) // this
payload = p.Payload.(*tls.Payload).Handle(ctx)
} }
st.TypeState[t] = tst st.TypeState[t] = tst
res.Payload = payload.(protocol.Payload) res.Payload = payload.(protocol.Payload)

View File

@ -0,0 +1,18 @@
package protocol
import (
log "github.com/sirupsen/logrus"
"layeh.com/radius"
)
type Context[TState any, TSettings any] interface {
// GlobalState()
ProtocolSettings() TSettings
GetProtocolState(def func(Context[TState, TSettings]) TState) TState
SetProtocolState(TState)
EndInnerProtocol(func(p *radius.Packet) *radius.Packet)
Log() *log.Entry
}

View File

@ -3,8 +3,8 @@ package eap
import "slices" import "slices"
type Settings struct { type Settings struct {
ChallengesToOffer []Type ProtocolsToOffer []Type
ChallengeSettings map[Type]interface{} ProtocolSettings map[Type]interface{}
} }
type StateManager interface { type StateManager interface {
@ -20,7 +20,7 @@ type State struct {
func BlankState(settings Settings) *State { func BlankState(settings Settings) *State {
return &State{ return &State{
ChallengesToOffer: slices.Clone(settings.ChallengesToOffer), ChallengesToOffer: slices.Clone(settings.ProtocolsToOffer),
TypeState: map[Type]any{}, TypeState: map[Type]any{},
} }
} }

View File

@ -12,31 +12,18 @@ import (
"goauthentik.io/internal/outpost/radius/eap/debug" "goauthentik.io/internal/outpost/radius/eap/debug"
"goauthentik.io/internal/outpost/radius/eap/protocol" "goauthentik.io/internal/outpost/radius/eap/protocol"
"layeh.com/radius" "layeh.com/radius"
"layeh.com/radius/rfc2865"
"layeh.com/radius/vendors/microsoft" "layeh.com/radius/vendors/microsoft"
) )
const maxChunkSize = 1000 const maxChunkSize = 1000
const staleConnectionTimeout = 10 const staleConnectionTimeout = 10
var certs = []tls.Certificate{}
func init() {
// Testing
cert, err := tls.LoadX509KeyPair(
"../t/ca/out/cert_jens-mbp.lab.beryju.org.pem",
"../t/ca/out/cert_jens-mbp.lab.beryju.org.key",
)
if err != nil {
panic(err)
}
certs = append(certs, cert)
}
type Payload struct { type Payload struct {
Flags Flag Flags Flag
Length uint32 Length uint32
Data []byte Data []byte
st *State
} }
func (p *Payload) Decode(raw []byte) error { func (p *Payload) Decode(raw []byte) error {
@ -74,92 +61,85 @@ func (p *Payload) Encode() ([]byte, error) {
return buff, nil return buff, nil
} }
func (p *Payload) Handle(stt any) (protocol.Payload, *State) { type tctx = protocol.Context[*State, Settings]
if stt == nil {
log.Debug("TLS: new state") func (p *Payload) Handle(ctx tctx) protocol.Payload {
stt = NewState() p.st = ctx.GetProtocolState(NewState)
} defer ctx.SetProtocolState(p.st)
st := stt.(*State) if !p.st.HasStarted {
if !st.HasStarted {
log.Debug("TLS: handshake starting") log.Debug("TLS: handshake starting")
st.HasStarted = true p.st.HasStarted = true
return &Payload{ return &Payload{
Flags: FlagTLSStart, Flags: FlagTLSStart,
}, st }
} }
if st.TLS == nil { if p.st.TLS == nil {
st = p.tlsInit(st) p.tlsInit(ctx)
} else if len(p.Data) > 0 { } else if len(p.Data) > 0 {
log.Debug("TLS: Updating buffer with new TLS data from packet") log.Debug("TLS: Updating buffer with new TLS data from packet")
if p.Flags&FlagLengthIncluded != 0 && st.Conn.expectedWriterByteCount == 0 { if p.Flags&FlagLengthIncluded != 0 && p.st.Conn.expectedWriterByteCount == 0 {
log.Debugf("TLS: Expecting %d total bytes, will buffer", p.Length) log.Debugf("TLS: Expecting %d total bytes, will buffer", p.Length)
st.Conn.expectedWriterByteCount = int(p.Length) p.st.Conn.expectedWriterByteCount = int(p.Length)
} else if p.Flags&FlagLengthIncluded != 0 { } else if p.Flags&FlagLengthIncluded != 0 {
log.Debug("TLS: No length included, not buffering") log.Debug("TLS: No length included, not buffering")
st.Conn.expectedWriterByteCount = 0 p.st.Conn.expectedWriterByteCount = 0
} }
st.Conn.UpdateData(p.Data) p.st.Conn.UpdateData(p.Data)
if !st.Conn.NeedsMoreData() { if !p.st.Conn.NeedsMoreData() {
// Wait for outbound data to be available // Wait for outbound data to be available
st.Conn.OutboundData() p.st.Conn.OutboundData()
} }
} }
// If we need more data, send the client the go-ahead // If we need more data, send the client the go-ahead
if st.Conn.NeedsMoreData() { if p.st.Conn.NeedsMoreData() {
return &Payload{ return &Payload{
Flags: FlagNone, Flags: FlagNone,
Length: 0, Length: 0,
Data: []byte{}, Data: []byte{},
}, st }
} }
if st.HasMore() { if p.st.HasMore() {
return p.sendNextChunk(st) return p.sendNextChunk()
} }
if st.Conn.writer.Len() == 0 && st.HandshakeDone { if p.st.Conn.writer.Len() == 0 && p.st.HandshakeDone {
defer st.ContextCancel() defer p.st.ContextCancel()
return protocol.EmptyPayload{ ctx.EndInnerProtocol(func(r *radius.Packet) *radius.Packet {
ModifyPacket: func(p *radius.Packet) *radius.Packet { r.Code = radius.CodeAccessAccept
p.Code = radius.CodeAccessAccept microsoft.MSMPPERecvKey_Set(r, p.st.MPPEKey[:32])
microsoft.MSMPPERecvKey_Set(p, st.MPPEKey[:32]) microsoft.MSMPPESendKey_Set(r, p.st.MPPEKey[64:64+32])
microsoft.MSMPPESendKey_Set(p, st.MPPEKey[64:64+32]) return r
rfc2865.UserName_SetString(p, "foo") })
rfc2865.FramedMTU_Set(p, rfc2865.FramedMTU(1400)) return nil
return p
},
}, st
} }
return p.startChunkedTransfer(st.Conn.OutboundData(), st) return p.startChunkedTransfer(p.st.Conn.OutboundData())
} }
func (p *Payload) tlsInit(st *State) *State { func (p *Payload) tlsInit(ctx tctx) {
log.Debug("TLS: no TLS connection in state yet, starting connection") log.Debug("TLS: no TLS connection in state yet, starting connection")
st.Context, st.ContextCancel = context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second) p.st.Context, p.st.ContextCancel = context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second)
st.Conn = NewBuffConn(p.Data, st.Context) p.st.Conn = NewBuffConn(p.Data, p.st.Context)
st.TLS = tls.Server(st.Conn, &tls.Config{ cfg := ctx.ProtocolSettings().Config.Clone()
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { cfg.GetConfigForClient = func(chi *tls.ClientHelloInfo) (*tls.Config, error) {
log.Debugf("TLS: ClientHello: %+v\n", ch) log.Debugf("TLS: ClientHello: %+v\n", chi)
st.ClientHello = ch p.st.ClientHello = chi
return nil, nil return nil, nil
}, }
ClientAuth: tls.RequireAnyClientCert, p.st.TLS = tls.Server(p.st.Conn, cfg)
Certificates: certs,
})
go func() { go func() {
err := st.TLS.HandshakeContext(st.Context) err := p.st.TLS.HandshakeContext(p.st.Context)
if err != nil { if err != nil {
log.WithError(err).Debug("TLS: Handshake error") log.WithError(err).Debug("TLS: Handshake error")
// TODO: Send a NAK to the client // TODO: Send a NAK to the client
return return
} }
log.Debug("TLS: handshake done") log.Debug("TLS: handshake done")
p.tlsHandshakeFinished(st) p.tlsHandshakeFinished()
}() }()
return st
} }
func (p *Payload) tlsHandshakeFinished(st *State) { func (p *Payload) tlsHandshakeFinished() {
cs := st.TLS.ConnectionState() cs := p.st.TLS.ConnectionState()
label := "client EAP encryption" label := "client EAP encryption"
var context []byte var context []byte
switch cs.Version { switch cs.Version {
@ -176,46 +156,46 @@ func (p *Payload) tlsHandshakeFinished(st *State) {
} }
ksm, err := cs.ExportKeyingMaterial(label, context, 64+64) ksm, err := cs.ExportKeyingMaterial(label, context, 64+64)
log.Debugf("TLS: ksm % x %v", ksm, err) log.Debugf("TLS: ksm % x %v", ksm, err)
st.MPPEKey = ksm p.st.MPPEKey = ksm
st.HandshakeDone = true p.st.HandshakeDone = true
} }
func (p *Payload) startChunkedTransfer(data []byte, st *State) (*Payload, *State) { func (p *Payload) startChunkedTransfer(data []byte) *Payload {
if len(data) > maxChunkSize { if len(data) > maxChunkSize {
log.WithField("length", len(data)).Debug("TLS: Data needs to be chunked") log.WithField("length", len(data)).Debug("TLS: Data needs to be chunked")
st.RemainingChunks = append(st.RemainingChunks, slices.Collect(slices.Chunk(data, maxChunkSize))...) p.st.RemainingChunks = append(p.st.RemainingChunks, slices.Collect(slices.Chunk(data, maxChunkSize))...)
st.TotalPayloadSize = len(data) p.st.TotalPayloadSize = len(data)
return p.sendNextChunk(st) return p.sendNextChunk()
} }
log.WithField("length", len(data)).Debug("TLS: Sending data un-chunked") log.WithField("length", len(data)).Debug("TLS: Sending data un-chunked")
st.Conn.writer.Reset() p.st.Conn.writer.Reset()
return &Payload{ return &Payload{
Flags: FlagLengthIncluded, Flags: FlagLengthIncluded,
Length: uint32(len(data)), Length: uint32(len(data)),
Data: data, Data: data,
}, st }
} }
func (p *Payload) sendNextChunk(st *State) (*Payload, *State) { func (p *Payload) sendNextChunk() *Payload {
nextChunk := st.RemainingChunks[0] nextChunk := p.st.RemainingChunks[0]
log.WithField("raw", debug.FormatBytes(nextChunk)).Debug("TLS: Sending next chunk") log.WithField("raw", debug.FormatBytes(nextChunk)).Debug("TLS: Sending next chunk")
st.RemainingChunks = st.RemainingChunks[1:] p.st.RemainingChunks = p.st.RemainingChunks[1:]
flags := FlagLengthIncluded flags := FlagLengthIncluded
if st.HasMore() { if p.st.HasMore() {
log.WithField("chunks", len(st.RemainingChunks)).Debug("TLS: More chunks left") log.WithField("chunks", len(p.st.RemainingChunks)).Debug("TLS: More chunks left")
flags += FlagMoreFragments flags += FlagMoreFragments
} else { } else {
// Last chunk, reset the connection buffers and pending payload size // Last chunk, reset the connection buffers and pending payload size
defer func() { defer func() {
log.Debug("TLS: Sent last chunk") log.Debug("TLS: Sent last chunk")
st.Conn.writer.Reset() p.st.Conn.writer.Reset()
st.TotalPayloadSize = 0 p.st.TotalPayloadSize = 0
}() }()
} }
log.WithField("length", st.TotalPayloadSize).Debug("TLS: Total payload size") log.WithField("length", p.st.TotalPayloadSize).Debug("TLS: Total payload size")
return &Payload{ return &Payload{
Flags: flags, Flags: flags,
Length: uint32(st.TotalPayloadSize), Length: uint32(p.st.TotalPayloadSize),
Data: nextChunk, Data: nextChunk,
}, st }
} }

View File

@ -0,0 +1,7 @@
package tls
import "crypto/tls"
type Settings struct {
Config *tls.Config
}

View File

@ -18,7 +18,8 @@ type State struct {
ContextCancel context.CancelFunc ContextCancel context.CancelFunc
} }
func NewState() *State { func NewState(c tctx) *State {
c.Log().Debug("TLS: new state")
return &State{ return &State{
RemainingChunks: make([][]byte, 0), RemainingChunks: make([][]byte, 0),
} }

View File

@ -1,12 +1,14 @@
package radius package radius
import ( import (
ttls "crypto/tls"
"encoding/base64" "encoding/base64"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/flow" "goauthentik.io/internal/outpost/flow"
"goauthentik.io/internal/outpost/radius/eap" "goauthentik.io/internal/outpost/radius/eap"
"goauthentik.io/internal/outpost/radius/eap/tls"
"goauthentik.io/internal/outpost/radius/metrics" "goauthentik.io/internal/outpost/radius/metrics"
"layeh.com/radius" "layeh.com/radius"
"layeh.com/radius/rfc2865" "layeh.com/radius/rfc2865"
@ -122,7 +124,24 @@ func (pi *ProviderInstance) SetEAPState(key string, state *eap.State) {
} }
func (pi *ProviderInstance) GetEAPSettings() eap.Settings { func (pi *ProviderInstance) GetEAPSettings() eap.Settings {
// Testing
cert, err := ttls.LoadX509KeyPair(
"../t/ca/out/cert_jens-mbp.lab.beryju.org.pem",
"../t/ca/out/cert_jens-mbp.lab.beryju.org.key",
)
if err != nil {
panic(err)
}
return eap.Settings{ return eap.Settings{
ChallengesToOffer: []eap.Type{eap.TypeTLS}, ProtocolsToOffer: []eap.Type{eap.TypeTLS},
ProtocolSettings: map[eap.Type]interface{}{
eap.TypeTLS: tls.Settings{
Config: &ttls.Config{
Certificates: []ttls.Certificate{cert},
ClientAuth: ttls.RequireAnyClientCert,
},
},
},
} }
} }