From ac887840891a819416a4b20c2f7d46c828b09fbb Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Fri, 16 May 2025 14:23:11 +0200 Subject: [PATCH] maybe? Signed-off-by: Jens Langhammer --- internal/outpost/radius/eap/context.go | 29 ++++ internal/outpost/radius/eap/handler.go | 12 +- .../outpost/radius/eap/protocol/context.go | 18 +++ internal/outpost/radius/eap/state.go | 6 +- internal/outpost/radius/eap/tls/payload.go | 148 ++++++++---------- internal/outpost/radius/eap/tls/settings.go | 7 + internal/outpost/radius/eap/tls/state.go | 3 +- .../outpost/radius/handle_access_request.go | 21 ++- 8 files changed, 151 insertions(+), 93 deletions(-) create mode 100644 internal/outpost/radius/eap/context.go create mode 100644 internal/outpost/radius/eap/protocol/context.go create mode 100644 internal/outpost/radius/eap/tls/settings.go diff --git a/internal/outpost/radius/eap/context.go b/internal/outpost/radius/eap/context.go new file mode 100644 index 0000000000..2a9291758b --- /dev/null +++ b/internal/outpost/radius/eap/context.go @@ -0,0 +1,29 @@ +package eap + +import ( + "github.com/sirupsen/logrus" + "layeh.com/radius" +) + +type context[TState any, TSettings any] struct { +} + +func (ctx context[TState, TSettings]) ProtocolSettings() TSettings { + return 0 +} + +func (ctx context[TState, TSettings]) GetProtocolState(def func(context[TState, TSettings]) TState) TState { + return nil +} + +func (ctx context[TState, TSettings]) SetProtocolState(TState) { + +} + +func (ctx context[TState, TSettings]) EndInnerProtocol(func(p *radius.Packet) *radius.Packet) { + +} + +func (ctx context[TState, TSettings]) Log() *logrus.Entry { + return nil +} diff --git a/internal/outpost/radius/eap/handler.go b/internal/outpost/radius/eap/handler.go index 0457d0ff80..9f2212ad68 100644 --- a/internal/outpost/radius/eap/handler.go +++ b/internal/outpost/radius/eap/handler.go @@ -28,7 +28,10 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac panic("No more challenges") } nextChallengeToOffer := st.ChallengesToOffer[0] - res, newState := p.GetChallengeForType(st, nextChallengeToOffer) + + ctx := context{} + + res, newState := p.GetChallengeForType(ctx, nextChallengeToOffer) stm.SetEAPState(rst, newState) rres := r.Response(radius.CodeAccessChallenge) @@ -52,21 +55,22 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac } } -func (p *Packet) GetChallengeForType(st *State, t Type) (*Packet, *State) { +func (p *Packet) GetChallengeForType(ctx context[any, any], t Type) *Packet { res := &Packet{ code: CodeRequest, id: p.id + 1, msgType: t, } var payload any - var tst any switch t { case TypeTLS: + // TODO: rewrite this if _, ok := p.Payload.(*tls.Payload); !ok { p.Payload = &tls.Payload{} p.Payload.Decode(p.rawPayload) } - payload, tst = p.Payload.(*tls.Payload).Handle(st.TypeState[t]) + // this + payload = p.Payload.(*tls.Payload).Handle(ctx) } st.TypeState[t] = tst res.Payload = payload.(protocol.Payload) diff --git a/internal/outpost/radius/eap/protocol/context.go b/internal/outpost/radius/eap/protocol/context.go new file mode 100644 index 0000000000..4e2446fcab --- /dev/null +++ b/internal/outpost/radius/eap/protocol/context.go @@ -0,0 +1,18 @@ +package protocol + +import ( + log "github.com/sirupsen/logrus" + "layeh.com/radius" +) + +type Context[TState any, TSettings any] interface { + // GlobalState() + + ProtocolSettings() TSettings + GetProtocolState(def func(Context[TState, TSettings]) TState) TState + SetProtocolState(TState) + + EndInnerProtocol(func(p *radius.Packet) *radius.Packet) + + Log() *log.Entry +} diff --git a/internal/outpost/radius/eap/state.go b/internal/outpost/radius/eap/state.go index 6c1b65dad6..cf0d59fa1d 100644 --- a/internal/outpost/radius/eap/state.go +++ b/internal/outpost/radius/eap/state.go @@ -3,8 +3,8 @@ package eap import "slices" type Settings struct { - ChallengesToOffer []Type - ChallengeSettings map[Type]interface{} + ProtocolsToOffer []Type + ProtocolSettings map[Type]interface{} } type StateManager interface { @@ -20,7 +20,7 @@ type State struct { func BlankState(settings Settings) *State { return &State{ - ChallengesToOffer: slices.Clone(settings.ChallengesToOffer), + ChallengesToOffer: slices.Clone(settings.ProtocolsToOffer), TypeState: map[Type]any{}, } } diff --git a/internal/outpost/radius/eap/tls/payload.go b/internal/outpost/radius/eap/tls/payload.go index 97b9c8ec16..47ad9c9cb5 100644 --- a/internal/outpost/radius/eap/tls/payload.go +++ b/internal/outpost/radius/eap/tls/payload.go @@ -12,31 +12,18 @@ import ( "goauthentik.io/internal/outpost/radius/eap/debug" "goauthentik.io/internal/outpost/radius/eap/protocol" "layeh.com/radius" - "layeh.com/radius/rfc2865" "layeh.com/radius/vendors/microsoft" ) const maxChunkSize = 1000 const staleConnectionTimeout = 10 -var certs = []tls.Certificate{} - -func init() { - // Testing - cert, err := tls.LoadX509KeyPair( - "../t/ca/out/cert_jens-mbp.lab.beryju.org.pem", - "../t/ca/out/cert_jens-mbp.lab.beryju.org.key", - ) - if err != nil { - panic(err) - } - certs = append(certs, cert) -} - type Payload struct { Flags Flag Length uint32 Data []byte + + st *State } func (p *Payload) Decode(raw []byte) error { @@ -74,92 +61,85 @@ func (p *Payload) Encode() ([]byte, error) { return buff, nil } -func (p *Payload) Handle(stt any) (protocol.Payload, *State) { - if stt == nil { - log.Debug("TLS: new state") - stt = NewState() - } - st := stt.(*State) - if !st.HasStarted { +type tctx = protocol.Context[*State, Settings] + +func (p *Payload) Handle(ctx tctx) protocol.Payload { + p.st = ctx.GetProtocolState(NewState) + defer ctx.SetProtocolState(p.st) + if !p.st.HasStarted { log.Debug("TLS: handshake starting") - st.HasStarted = true + p.st.HasStarted = true return &Payload{ Flags: FlagTLSStart, - }, st + } } - if st.TLS == nil { - st = p.tlsInit(st) + if p.st.TLS == nil { + p.tlsInit(ctx) } else if len(p.Data) > 0 { log.Debug("TLS: Updating buffer with new TLS data from packet") - if p.Flags&FlagLengthIncluded != 0 && st.Conn.expectedWriterByteCount == 0 { + if p.Flags&FlagLengthIncluded != 0 && p.st.Conn.expectedWriterByteCount == 0 { log.Debugf("TLS: Expecting %d total bytes, will buffer", p.Length) - st.Conn.expectedWriterByteCount = int(p.Length) + p.st.Conn.expectedWriterByteCount = int(p.Length) } else if p.Flags&FlagLengthIncluded != 0 { log.Debug("TLS: No length included, not buffering") - st.Conn.expectedWriterByteCount = 0 + p.st.Conn.expectedWriterByteCount = 0 } - st.Conn.UpdateData(p.Data) - if !st.Conn.NeedsMoreData() { + p.st.Conn.UpdateData(p.Data) + if !p.st.Conn.NeedsMoreData() { // Wait for outbound data to be available - st.Conn.OutboundData() + p.st.Conn.OutboundData() } } // If we need more data, send the client the go-ahead - if st.Conn.NeedsMoreData() { + if p.st.Conn.NeedsMoreData() { return &Payload{ Flags: FlagNone, Length: 0, Data: []byte{}, - }, st + } } - if st.HasMore() { - return p.sendNextChunk(st) + if p.st.HasMore() { + return p.sendNextChunk() } - if st.Conn.writer.Len() == 0 && st.HandshakeDone { - defer st.ContextCancel() - return protocol.EmptyPayload{ - ModifyPacket: func(p *radius.Packet) *radius.Packet { - p.Code = radius.CodeAccessAccept - microsoft.MSMPPERecvKey_Set(p, st.MPPEKey[:32]) - microsoft.MSMPPESendKey_Set(p, st.MPPEKey[64:64+32]) - rfc2865.UserName_SetString(p, "foo") - rfc2865.FramedMTU_Set(p, rfc2865.FramedMTU(1400)) - return p - }, - }, st + if p.st.Conn.writer.Len() == 0 && p.st.HandshakeDone { + defer p.st.ContextCancel() + ctx.EndInnerProtocol(func(r *radius.Packet) *radius.Packet { + r.Code = radius.CodeAccessAccept + microsoft.MSMPPERecvKey_Set(r, p.st.MPPEKey[:32]) + microsoft.MSMPPESendKey_Set(r, p.st.MPPEKey[64:64+32]) + return r + }) + return nil } - return p.startChunkedTransfer(st.Conn.OutboundData(), st) + return p.startChunkedTransfer(p.st.Conn.OutboundData()) } -func (p *Payload) tlsInit(st *State) *State { +func (p *Payload) tlsInit(ctx tctx) { log.Debug("TLS: no TLS connection in state yet, starting connection") - st.Context, st.ContextCancel = context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second) - st.Conn = NewBuffConn(p.Data, st.Context) - st.TLS = tls.Server(st.Conn, &tls.Config{ - GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { - log.Debugf("TLS: ClientHello: %+v\n", ch) - st.ClientHello = ch - return nil, nil - }, - ClientAuth: tls.RequireAnyClientCert, - Certificates: certs, - }) + p.st.Context, p.st.ContextCancel = context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second) + p.st.Conn = NewBuffConn(p.Data, p.st.Context) + cfg := ctx.ProtocolSettings().Config.Clone() + cfg.GetConfigForClient = func(chi *tls.ClientHelloInfo) (*tls.Config, error) { + log.Debugf("TLS: ClientHello: %+v\n", chi) + p.st.ClientHello = chi + return nil, nil + } + p.st.TLS = tls.Server(p.st.Conn, cfg) go func() { - err := st.TLS.HandshakeContext(st.Context) + err := p.st.TLS.HandshakeContext(p.st.Context) if err != nil { log.WithError(err).Debug("TLS: Handshake error") // TODO: Send a NAK to the client return } log.Debug("TLS: handshake done") - p.tlsHandshakeFinished(st) + p.tlsHandshakeFinished() }() - return st } -func (p *Payload) tlsHandshakeFinished(st *State) { - cs := st.TLS.ConnectionState() +func (p *Payload) tlsHandshakeFinished() { + cs := p.st.TLS.ConnectionState() label := "client EAP encryption" var context []byte switch cs.Version { @@ -176,46 +156,46 @@ func (p *Payload) tlsHandshakeFinished(st *State) { } ksm, err := cs.ExportKeyingMaterial(label, context, 64+64) log.Debugf("TLS: ksm % x %v", ksm, err) - st.MPPEKey = ksm - st.HandshakeDone = true + p.st.MPPEKey = ksm + p.st.HandshakeDone = true } -func (p *Payload) startChunkedTransfer(data []byte, st *State) (*Payload, *State) { +func (p *Payload) startChunkedTransfer(data []byte) *Payload { if len(data) > maxChunkSize { log.WithField("length", len(data)).Debug("TLS: Data needs to be chunked") - st.RemainingChunks = append(st.RemainingChunks, slices.Collect(slices.Chunk(data, maxChunkSize))...) - st.TotalPayloadSize = len(data) - return p.sendNextChunk(st) + p.st.RemainingChunks = append(p.st.RemainingChunks, slices.Collect(slices.Chunk(data, maxChunkSize))...) + p.st.TotalPayloadSize = len(data) + return p.sendNextChunk() } log.WithField("length", len(data)).Debug("TLS: Sending data un-chunked") - st.Conn.writer.Reset() + p.st.Conn.writer.Reset() return &Payload{ Flags: FlagLengthIncluded, Length: uint32(len(data)), Data: data, - }, st + } } -func (p *Payload) sendNextChunk(st *State) (*Payload, *State) { - nextChunk := st.RemainingChunks[0] +func (p *Payload) sendNextChunk() *Payload { + nextChunk := p.st.RemainingChunks[0] log.WithField("raw", debug.FormatBytes(nextChunk)).Debug("TLS: Sending next chunk") - st.RemainingChunks = st.RemainingChunks[1:] + p.st.RemainingChunks = p.st.RemainingChunks[1:] flags := FlagLengthIncluded - if st.HasMore() { - log.WithField("chunks", len(st.RemainingChunks)).Debug("TLS: More chunks left") + if p.st.HasMore() { + log.WithField("chunks", len(p.st.RemainingChunks)).Debug("TLS: More chunks left") flags += FlagMoreFragments } else { // Last chunk, reset the connection buffers and pending payload size defer func() { log.Debug("TLS: Sent last chunk") - st.Conn.writer.Reset() - st.TotalPayloadSize = 0 + p.st.Conn.writer.Reset() + p.st.TotalPayloadSize = 0 }() } - log.WithField("length", st.TotalPayloadSize).Debug("TLS: Total payload size") + log.WithField("length", p.st.TotalPayloadSize).Debug("TLS: Total payload size") return &Payload{ Flags: flags, - Length: uint32(st.TotalPayloadSize), + Length: uint32(p.st.TotalPayloadSize), Data: nextChunk, - }, st + } } diff --git a/internal/outpost/radius/eap/tls/settings.go b/internal/outpost/radius/eap/tls/settings.go new file mode 100644 index 0000000000..8d7b608434 --- /dev/null +++ b/internal/outpost/radius/eap/tls/settings.go @@ -0,0 +1,7 @@ +package tls + +import "crypto/tls" + +type Settings struct { + Config *tls.Config +} diff --git a/internal/outpost/radius/eap/tls/state.go b/internal/outpost/radius/eap/tls/state.go index adf23e918b..ff1ad37446 100644 --- a/internal/outpost/radius/eap/tls/state.go +++ b/internal/outpost/radius/eap/tls/state.go @@ -18,7 +18,8 @@ type State struct { ContextCancel context.CancelFunc } -func NewState() *State { +func NewState(c tctx) *State { + c.Log().Debug("TLS: new state") return &State{ RemainingChunks: make([][]byte, 0), } diff --git a/internal/outpost/radius/handle_access_request.go b/internal/outpost/radius/handle_access_request.go index df3b7a1009..67631b56fb 100644 --- a/internal/outpost/radius/handle_access_request.go +++ b/internal/outpost/radius/handle_access_request.go @@ -1,12 +1,14 @@ package radius import ( + ttls "crypto/tls" "encoding/base64" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" "goauthentik.io/internal/outpost/flow" "goauthentik.io/internal/outpost/radius/eap" + "goauthentik.io/internal/outpost/radius/eap/tls" "goauthentik.io/internal/outpost/radius/metrics" "layeh.com/radius" "layeh.com/radius/rfc2865" @@ -122,7 +124,24 @@ func (pi *ProviderInstance) SetEAPState(key string, state *eap.State) { } func (pi *ProviderInstance) GetEAPSettings() eap.Settings { + // Testing + cert, err := ttls.LoadX509KeyPair( + "../t/ca/out/cert_jens-mbp.lab.beryju.org.pem", + "../t/ca/out/cert_jens-mbp.lab.beryju.org.key", + ) + if err != nil { + panic(err) + } + return eap.Settings{ - ChallengesToOffer: []eap.Type{eap.TypeTLS}, + ProtocolsToOffer: []eap.Type{eap.TypeTLS}, + ProtocolSettings: map[eap.Type]interface{}{ + eap.TypeTLS: tls.Settings{ + Config: &ttls.Config{ + Certificates: []ttls.Certificate{cert}, + ClientAuth: ttls.RequireAnyClientCert, + }, + }, + }, } }