@ -8,7 +8,7 @@ import (
|
||||
|
||||
type context struct {
|
||||
req *radius.Request
|
||||
state interface{}
|
||||
rootPayload protocol.Payload
|
||||
typeState map[protocol.Type]any
|
||||
log *log.Entry
|
||||
settings interface{}
|
||||
@ -16,21 +16,17 @@ type context struct {
|
||||
endModifier func(p *radius.Packet) *radius.Packet
|
||||
}
|
||||
|
||||
func (ctx *context) Packet() *radius.Request { return ctx.req }
|
||||
func (ctx *context) ProtocolSettings() interface{} { return ctx.settings }
|
||||
func (ctx *context) StateForProtocol(p protocol.Type) interface{} { return ctx.typeState[p] }
|
||||
func (ctx *context) GetProtocolState() interface{} { return ctx.state }
|
||||
func (ctx *context) SetProtocolState(st interface{}) { ctx.state = st }
|
||||
func (ctx *context) IsProtocolStart() bool { return ctx.state == nil }
|
||||
func (ctx *context) Log() *log.Entry { return ctx.log }
|
||||
func (ctx *context) RootPayload() protocol.Payload { return ctx.rootPayload }
|
||||
func (ctx *context) Packet() *radius.Request { return ctx.req }
|
||||
func (ctx *context) ProtocolSettings() interface{} { return ctx.settings }
|
||||
func (ctx *context) GetProtocolState(p protocol.Type) interface{} { return ctx.typeState[p] }
|
||||
func (ctx *context) SetProtocolState(p protocol.Type, st interface{}) { ctx.typeState[p] = st }
|
||||
func (ctx *context) IsProtocolStart(p protocol.Type) bool { return ctx.typeState[p] == nil }
|
||||
func (ctx *context) Log() *log.Entry { return ctx.log }
|
||||
|
||||
func (ctx *context) ForInnerProtocol(p protocol.Type) protocol.Context {
|
||||
log.Debug("foo")
|
||||
log.Debugf("%+v", ctx.typeState[protocol.Type(13)])
|
||||
log.Debugf("%+v", ctx.typeState[protocol.Type(25)])
|
||||
return &context{
|
||||
req: ctx.req,
|
||||
state: ctx.StateForProtocol(p),
|
||||
typeState: ctx.typeState,
|
||||
log: ctx.log,
|
||||
settings: ctx.settings,
|
||||
|
@ -80,15 +80,6 @@ func (p *Packet) handleInner(r *radius.Request) (*eap.Payload, error) {
|
||||
st = BlankState(p.stm.GetEAPSettings())
|
||||
}
|
||||
|
||||
// FIXME: Statically call Handle of root EAP packet to make its data accessible
|
||||
ectx := &context{
|
||||
state: st.TypeState[eap.TypeEAP],
|
||||
log: log.WithField("type", fmt.Sprintf("%T", &eap.Payload{})),
|
||||
}
|
||||
p.eap.Handle(ectx)
|
||||
st.TypeState[eap.TypeEAP] = ectx.GetProtocolState()
|
||||
p.stm.SetEAPState(p.state, st)
|
||||
|
||||
nextChallengeToOffer, err := st.GetNextProtocol()
|
||||
if err != nil {
|
||||
return &eap.Payload{
|
||||
@ -97,9 +88,9 @@ func (p *Packet) handleInner(r *radius.Request) (*eap.Payload, error) {
|
||||
}, err
|
||||
}
|
||||
|
||||
next := func(oldProtocol protocol.Type) (*eap.Payload, error) {
|
||||
next := func() (*eap.Payload, error) {
|
||||
st.ProtocolIndex += 1
|
||||
delete(st.TypeState, oldProtocol)
|
||||
st.TypeState = map[protocol.Type]any{}
|
||||
p.stm.SetEAPState(p.state, st)
|
||||
return p.handleInner(r)
|
||||
}
|
||||
@ -107,28 +98,25 @@ func (p *Packet) handleInner(r *radius.Request) (*eap.Payload, error) {
|
||||
if _, ok := p.eap.Payload.(*legacy_nak.Payload); ok {
|
||||
log.Debug("EAP: received NAK, trying next protocol")
|
||||
p.eap.Payload = nil
|
||||
log.Debug(st.ProtocolPriority[st.ProtocolIndex])
|
||||
return next(st.ProtocolPriority[st.ProtocolIndex])
|
||||
return next()
|
||||
}
|
||||
|
||||
np, t, _ := emptyPayload(p.stm, nextChallengeToOffer)
|
||||
|
||||
ctx := &context{
|
||||
req: r,
|
||||
// Always write to the state of the outer protocol
|
||||
state: st.TypeState[np.Type()],
|
||||
typeState: st.TypeState,
|
||||
log: log.WithField("type", fmt.Sprintf("%T", np)).WithField("code", t),
|
||||
settings: p.stm.GetEAPSettings().ProtocolSettings[t],
|
||||
req: r,
|
||||
rootPayload: p.eap,
|
||||
typeState: st.TypeState,
|
||||
log: log.WithField("type", fmt.Sprintf("%T", np)).WithField("code", t),
|
||||
settings: p.stm.GetEAPSettings().ProtocolSettings[t],
|
||||
}
|
||||
if !np.Offerable() {
|
||||
ctx.log.Debug("EAP: protocol not offerable, skipping")
|
||||
return next(np.Type())
|
||||
return next()
|
||||
}
|
||||
ctx.log.Debug("EAP: Passing to protocol")
|
||||
|
||||
res := p.GetChallengeForType(ctx, np, t)
|
||||
st.TypeState[t] = ctx.GetProtocolState()
|
||||
p.stm.SetEAPState(p.state, st)
|
||||
|
||||
if ctx.endModifier != nil {
|
||||
@ -144,7 +132,7 @@ func (p *Packet) handleInner(r *radius.Request) (*eap.Payload, error) {
|
||||
res.ID -= 1
|
||||
case protocol.StatusNextProtocol:
|
||||
ctx.log.Debug("EAP: Protocol ended, starting next protocol")
|
||||
return next(np.Type())
|
||||
return next()
|
||||
case protocol.StatusUnknown:
|
||||
}
|
||||
return res, nil
|
||||
@ -157,7 +145,7 @@ func (p *Packet) GetChallengeForType(ctx *context, np protocol.Payload, t protoc
|
||||
MsgType: t,
|
||||
}
|
||||
var payload any
|
||||
if ctx.IsProtocolStart() {
|
||||
if ctx.IsProtocolStart(t) {
|
||||
p.eap.Payload = np
|
||||
p.eap.Payload.Decode(p.eap.RawPayload)
|
||||
}
|
||||
|
@ -16,16 +16,14 @@ const (
|
||||
|
||||
type Context interface {
|
||||
Packet() *radius.Request
|
||||
RootPayload() Payload
|
||||
|
||||
ProtocolSettings() interface{}
|
||||
|
||||
ForInnerProtocol(p Type) Context
|
||||
GetProtocolState(p Type) interface{}
|
||||
SetProtocolState(p Type, s interface{})
|
||||
IsProtocolStart(p Type) bool
|
||||
|
||||
StateForProtocol(p Type) interface{}
|
||||
GetProtocolState() interface{}
|
||||
SetProtocolState(interface{})
|
||||
|
||||
IsProtocolStart() bool
|
||||
EndInnerProtocol(Status, func(p *radius.Packet) *radius.Packet)
|
||||
|
||||
Log() *log.Entry
|
||||
|
@ -76,8 +76,5 @@ func (p *Payload) Encode() ([]byte, error) {
|
||||
|
||||
func (ip *Payload) Handle(ctx protocol.Context) protocol.Payload {
|
||||
ctx.Log().Debug("EAP: Handle")
|
||||
ctx.SetProtocolState(&State{
|
||||
PacketID: ip.ID,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
@ -26,7 +26,7 @@ func (ip *Payload) Encode() ([]byte, error) {
|
||||
}
|
||||
|
||||
func (ip *Payload) Handle(ctx protocol.Context) protocol.Payload {
|
||||
if ctx.IsProtocolStart() {
|
||||
if ctx.IsProtocolStart(TypeIdentity) {
|
||||
ctx.EndInnerProtocol(protocol.StatusNextProtocol, nil)
|
||||
}
|
||||
return nil
|
||||
|
@ -26,7 +26,7 @@ func (ln *Payload) Encode() ([]byte, error) {
|
||||
}
|
||||
|
||||
func (ln *Payload) Handle(ctx protocol.Context) protocol.Payload {
|
||||
if ctx.IsProtocolStart() {
|
||||
if ctx.IsProtocolStart(TypeLegacyNAK) {
|
||||
ctx.EndInnerProtocol(protocol.StatusError, nil)
|
||||
}
|
||||
return nil
|
||||
|
@ -47,22 +47,22 @@ func (p *Payload) Encode() ([]byte, error) {
|
||||
|
||||
func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
|
||||
defer func() {
|
||||
ctx.SetProtocolState(p.st)
|
||||
ctx.SetProtocolState(TypePEAP, p.st)
|
||||
}()
|
||||
|
||||
eapState := ctx.StateForProtocol(eap.TypeEAP).(*eap.State)
|
||||
rootEap := ctx.RootPayload().(*eap.Payload)
|
||||
|
||||
if ctx.IsProtocolStart() {
|
||||
if ctx.IsProtocolStart(TypePEAP) {
|
||||
ctx.Log().Debug("PEAP: Protocol start")
|
||||
p.st = &State{}
|
||||
return &eap.Payload{
|
||||
Code: protocol.CodeRequest,
|
||||
ID: eapState.PacketID + 1,
|
||||
ID: rootEap.ID + 1,
|
||||
MsgType: identity.TypeIdentity,
|
||||
Payload: &identity.Payload{},
|
||||
}
|
||||
}
|
||||
p.st = ctx.GetProtocolState().(*State)
|
||||
p.st = ctx.GetProtocolState(TypePEAP).(*State)
|
||||
|
||||
ep := &eap.Payload{}
|
||||
err := ep.Decode(p.raw)
|
||||
|
@ -13,7 +13,7 @@ func (p *Payload) innerHandler(ctx protocol.Context) {
|
||||
ctx.EndInnerProtocol(protocol.StatusError, nil)
|
||||
return
|
||||
}
|
||||
pl := p.Inner.Handle(ctx.ForInnerProtocol(p.Inner.Type()))
|
||||
pl := p.Inner.Handle(ctx)
|
||||
enc, err := pl.Encode()
|
||||
if err != nil {
|
||||
ctx.Log().WithError(err).Warning("failed to encode inner protocol")
|
||||
|
@ -87,15 +87,15 @@ func (p *Payload) Encode() ([]byte, error) {
|
||||
|
||||
func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
|
||||
defer func() {
|
||||
ctx.SetProtocolState(p.st)
|
||||
ctx.SetProtocolState(TypeTLS, p.st)
|
||||
}()
|
||||
if ctx.IsProtocolStart() {
|
||||
if ctx.IsProtocolStart(TypeTLS) {
|
||||
p.st = NewState(ctx).(*State)
|
||||
return &Payload{
|
||||
Flags: FlagTLSStart,
|
||||
}
|
||||
}
|
||||
p.st = ctx.GetProtocolState().(*State)
|
||||
p.st = ctx.GetProtocolState(TypeTLS).(*State)
|
||||
|
||||
if p.st.TLS == nil {
|
||||
p.tlsInit(ctx)
|
||||
|
Reference in New Issue
Block a user