From 82c177b7eb2770a6709e8fcba05b8c97dee17a6f Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Wed, 21 May 2025 02:00:12 +0200 Subject: [PATCH] try to make this work Signed-off-by: Jens Langhammer --- internal/outpost/radius/eap/context.go | 45 +++++++++---------- internal/outpost/radius/eap/handler.go | 27 ++++++----- internal/outpost/radius/eap/packet.go | 17 ++++--- .../outpost/radius/eap/protocol/context.go | 2 + .../radius/eap/protocol/eap/payload.go | 2 +- .../outpost/radius/eap/protocol/packet.go | 4 ++ .../radius/eap/protocol/peap/payload.go | 30 +++++++++++-- .../outpost/radius/eap/protocol/peap/state.go | 4 ++ .../outpost/radius/eap/protocol/tls/inner.go | 11 ++++- .../radius/eap/protocol/tls/payload.go | 12 +++-- 10 files changed, 101 insertions(+), 53 deletions(-) create mode 100644 internal/outpost/radius/eap/protocol/peap/state.go diff --git a/internal/outpost/radius/eap/context.go b/internal/outpost/radius/eap/context.go index e32f8f7f2d..0202ab38cd 100644 --- a/internal/outpost/radius/eap/context.go +++ b/internal/outpost/radius/eap/context.go @@ -16,28 +16,27 @@ type context struct { endModifier func(p *radius.Packet) *radius.Packet } -func (ctx context) Packet() *radius.Request { - return ctx.req -} +func (ctx *context) Packet() *radius.Request { return ctx.req } +func (ctx *context) ProtocolSettings() interface{} { return ctx.settings } +func (ctx *context) StateForProtocol(p protocol.Type) interface{} { return ctx.typeState[p] } +func (ctx *context) GetProtocolState() interface{} { return ctx.state } +func (ctx *context) SetProtocolState(st interface{}) { ctx.state = st } +func (ctx *context) IsProtocolStart() bool { return ctx.state == nil } +func (ctx *context) Log() *log.Entry { return ctx.log } -func (ctx context) ProtocolSettings() interface{} { - return ctx.settings -} - -func (ctx *context) StateForProtocol(p protocol.Type) interface{} { - return ctx.typeState[p] -} - -func (ctx *context) GetProtocolState() interface{} { - return ctx.state -} - -func (ctx *context) SetProtocolState(st interface{}) { - ctx.state = st -} - -func (ctx *context) IsProtocolStart() bool { - return ctx.state == nil +func (ctx *context) ForInnerProtocol(p protocol.Type) protocol.Context { + log.Debug("foo") + log.Debugf("%+v", ctx.typeState[protocol.Type(13)]) + log.Debugf("%+v", ctx.typeState[protocol.Type(25)]) + return &context{ + req: ctx.req, + state: ctx.StateForProtocol(p), + typeState: ctx.typeState, + log: ctx.log, + settings: ctx.settings, + endStatus: ctx.endStatus, + endModifier: ctx.endModifier, + } } func (ctx *context) EndInnerProtocol(st protocol.Status, mf func(p *radius.Packet) *radius.Packet) { @@ -52,7 +51,3 @@ func (ctx *context) EndInnerProtocol(st protocol.Status, mf func(p *radius.Packe } ctx.endModifier = mf } - -func (ctx context) Log() *log.Entry { - return ctx.log -} diff --git a/internal/outpost/radius/eap/handler.go b/internal/outpost/radius/eap/handler.go index 8209c4a2da..c0a9bcc041 100644 --- a/internal/outpost/radius/eap/handler.go +++ b/internal/outpost/radius/eap/handler.go @@ -97,8 +97,9 @@ func (p *Packet) handleInner(r *radius.Request) (*eap.Payload, error) { }, err } - next := func() (*eap.Payload, error) { + next := func(oldProtocol protocol.Type) (*eap.Payload, error) { st.ProtocolIndex += 1 + delete(st.TypeState, oldProtocol) p.stm.SetEAPState(p.state, st) return p.handleInner(r) } @@ -106,26 +107,28 @@ func (p *Packet) handleInner(r *radius.Request) (*eap.Payload, error) { if _, ok := p.eap.Payload.(*legacy_nak.Payload); ok { log.Debug("EAP: received NAK, trying next protocol") p.eap.Payload = nil - return next() + log.Debug(st.ProtocolPriority[st.ProtocolIndex]) + return next(st.ProtocolPriority[st.ProtocolIndex]) } - np, _ := emptyPayload(p.stm, nextChallengeToOffer) + np, t, _ := emptyPayload(p.stm, nextChallengeToOffer) ctx := &context{ - req: r, + req: r, + // Always write to the state of the outer protocol state: st.TypeState[np.Type()], typeState: st.TypeState, - log: log.WithField("type", fmt.Sprintf("%T", np)), - settings: p.stm.GetEAPSettings().ProtocolSettings[np.Type()], + log: log.WithField("type", fmt.Sprintf("%T", np)).WithField("code", t), + settings: p.stm.GetEAPSettings().ProtocolSettings[t], } if !np.Offerable() { ctx.log.Debug("EAP: protocol not offerable, skipping") - return next() + return next(np.Type()) } ctx.log.Debug("EAP: Passing to protocol") - res := p.GetChallengeForType(ctx, np) - st.TypeState[np.Type()] = ctx.GetProtocolState() + res := p.GetChallengeForType(ctx, np, t) + st.TypeState[t] = ctx.GetProtocolState() p.stm.SetEAPState(p.state, st) if ctx.endModifier != nil { @@ -141,17 +144,17 @@ func (p *Packet) handleInner(r *radius.Request) (*eap.Payload, error) { res.ID -= 1 case protocol.StatusNextProtocol: ctx.log.Debug("EAP: Protocol ended, starting next protocol") - return next() + return next(np.Type()) case protocol.StatusUnknown: } return res, nil } -func (p *Packet) GetChallengeForType(ctx *context, np protocol.Payload) *eap.Payload { +func (p *Packet) GetChallengeForType(ctx *context, np protocol.Payload, t protocol.Type) *eap.Payload { res := &eap.Payload{ Code: protocol.CodeRequest, ID: p.eap.ID + 1, - MsgType: np.Type(), + MsgType: t, } var payload any if ctx.IsProtocolStart() { diff --git a/internal/outpost/radius/eap/packet.go b/internal/outpost/radius/eap/packet.go index 27297cb159..7aec9abb0a 100644 --- a/internal/outpost/radius/eap/packet.go +++ b/internal/outpost/radius/eap/packet.go @@ -15,13 +15,20 @@ type Packet struct { endModifier func(p *radius.Packet) *radius.Packet } -func emptyPayload(stm StateManager, t protocol.Type) (protocol.Payload, error) { +func emptyPayload(stm StateManager, t protocol.Type) (protocol.Payload, protocol.Type, error) { for _, cons := range stm.GetEAPSettings().Protocols { - if np := cons(); np.Type() == t { - return np, nil + np := cons() + if np.Type() == t { + return np, np.Type(), nil + } + // If the protocol has an inner protocol, return the original type but the code for the inner protocol + if i, ok := np.(protocol.Inner); ok { + if ii := i.HasInner(); ii != nil { + return np, ii.Type(), nil + } } } - return nil, fmt.Errorf("unsupported EAP type %d", t) + return nil, protocol.Type(0), fmt.Errorf("unsupported EAP type %d", t) } func Decode(stm StateManager, raw []byte) (*Packet, error) { @@ -38,7 +45,7 @@ func Decode(stm StateManager, raw []byte) (*Packet, error) { if err != nil { return nil, err } - p, err := emptyPayload(stm, packet.eap.MsgType) + p, _, err := emptyPayload(stm, packet.eap.MsgType) if err != nil { return nil, err } diff --git a/internal/outpost/radius/eap/protocol/context.go b/internal/outpost/radius/eap/protocol/context.go index 58791ec7ea..5b905932b9 100644 --- a/internal/outpost/radius/eap/protocol/context.go +++ b/internal/outpost/radius/eap/protocol/context.go @@ -19,6 +19,8 @@ type Context interface { ProtocolSettings() interface{} + ForInnerProtocol(p Type) Context + StateForProtocol(p Type) interface{} GetProtocolState() interface{} SetProtocolState(interface{}) diff --git a/internal/outpost/radius/eap/protocol/eap/payload.go b/internal/outpost/radius/eap/protocol/eap/payload.go index d179a1aa81..9a9434e344 100644 --- a/internal/outpost/radius/eap/protocol/eap/payload.go +++ b/internal/outpost/radius/eap/protocol/eap/payload.go @@ -47,7 +47,7 @@ func (packet *Payload) Decode(raw []byte) error { if packet.Payload == nil { return nil } - log.WithField("raw", debug.FormatBytes(raw)).WithField("payload", fmt.Sprintf("%T", packet.Payload)).Debug("EAP: decode raw") + log.WithField("raw", debug.FormatBytes(raw)).WithField("payload", fmt.Sprintf("%T", packet.Payload)).Trace("EAP: decode raw") err := packet.Payload.Decode(raw[5:]) if err != nil { return err diff --git a/internal/outpost/radius/eap/protocol/packet.go b/internal/outpost/radius/eap/protocol/packet.go index 6da6c88c93..62d30219e7 100644 --- a/internal/outpost/radius/eap/protocol/packet.go +++ b/internal/outpost/radius/eap/protocol/packet.go @@ -8,6 +8,10 @@ type Payload interface { Offerable() bool } +type Inner interface { + HasInner() Payload +} + type Type uint8 type Code uint8 diff --git a/internal/outpost/radius/eap/protocol/peap/payload.go b/internal/outpost/radius/eap/protocol/peap/payload.go index eb3b960c60..8def344020 100644 --- a/internal/outpost/radius/eap/protocol/peap/payload.go +++ b/internal/outpost/radius/eap/protocol/peap/payload.go @@ -21,33 +21,55 @@ func Protocol() protocol.Payload { type Payload struct { Inner protocol.Payload + + eap *eap.Payload + st *State + raw []byte } func (p *Payload) Type() protocol.Type { return TypePEAP } +func (p *Payload) HasInner() protocol.Payload { + return p.Inner +} + func (p *Payload) Decode(raw []byte) error { log.WithField("raw", debug.FormatBytes(raw)).Debug("PEAP: Decode") + p.raw = raw return nil } func (p *Payload) Encode() ([]byte, error) { - log.Debug("PEAP: Encode") - return []byte{}, nil + return p.eap.Encode() } func (p *Payload) Handle(ctx protocol.Context) protocol.Payload { + defer func() { + ctx.SetProtocolState(p.st) + }() + eapState := ctx.StateForProtocol(eap.TypeEAP).(*eap.State) - if !ctx.IsProtocolStart() { + + if ctx.IsProtocolStart() { ctx.Log().Debug("PEAP: Protocol start") + p.st = &State{} return &eap.Payload{ Code: protocol.CodeRequest, - ID: eapState.PacketID, + ID: eapState.PacketID + 1, MsgType: identity.TypeIdentity, Payload: &identity.Payload{}, } } + p.st = ctx.GetProtocolState().(*State) + + ep := &eap.Payload{} + err := ep.Decode(p.raw) + if err != nil { + ctx.Log().WithError(err).Warning("PEAP: failed to decode inner EAP") + return &Payload{} + } return &Payload{} } diff --git a/internal/outpost/radius/eap/protocol/peap/state.go b/internal/outpost/radius/eap/protocol/peap/state.go new file mode 100644 index 0000000000..bb7c874633 --- /dev/null +++ b/internal/outpost/radius/eap/protocol/peap/state.go @@ -0,0 +1,4 @@ +package peap + +type State struct { +} diff --git a/internal/outpost/radius/eap/protocol/tls/inner.go b/internal/outpost/radius/eap/protocol/tls/inner.go index b9d4bcead6..5dde05b1e5 100644 --- a/internal/outpost/radius/eap/protocol/tls/inner.go +++ b/internal/outpost/radius/eap/protocol/tls/inner.go @@ -13,9 +13,16 @@ func (p *Payload) innerHandler(ctx protocol.Context) { ctx.EndInnerProtocol(protocol.StatusError, nil) return } - pl := p.Inner.Handle(ctx) + pl := p.Inner.Handle(ctx.ForInnerProtocol(p.Inner.Type())) enc, err := pl.Encode() - p.st.TLS.Write(enc) + if err != nil { + ctx.Log().WithError(err).Warning("failed to encode inner protocol") + } + // p.st.Conn.expectedWriterByteCount = len(enc) + _, err = p.st.TLS.Write(enc) + if err != nil { + ctx.Log().WithError(err).Warning("failed to write to TLS") + } // return &Payload{ // Data: enc, // } diff --git a/internal/outpost/radius/eap/protocol/tls/payload.go b/internal/outpost/radius/eap/protocol/tls/payload.go index ddc6b5747a..ab6f3119b6 100644 --- a/internal/outpost/radius/eap/protocol/tls/payload.go +++ b/internal/outpost/radius/eap/protocol/tls/payload.go @@ -36,12 +36,16 @@ type Payload struct { } func (p *Payload) Type() protocol.Type { - if p.Inner != nil { - return p.Inner.Type() - } + // if p.inner != nil { + // return p.inner.Type() + // } return TypeTLS } +func (p *Payload) HasInner() protocol.Payload { + return p.Inner +} + func (p *Payload) Offerable() bool { return true } @@ -58,7 +62,7 @@ func (p *Payload) Decode(raw []byte) error { } else { p.Data = raw[0:] } - log.WithField("raw", debug.FormatBytes(p.Data)).WithField("size", len(p.Data)).WithField("flags", p.Flags).Debug("TLS: decode raw") + log.WithField("raw", debug.FormatBytes(p.Data)).WithField("size", len(p.Data)).WithField("flags", p.Flags).Trace("TLS: decode raw") return nil }