try to make this work
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
@ -16,28 +16,27 @@ type context struct {
|
|||||||
endModifier func(p *radius.Packet) *radius.Packet
|
endModifier func(p *radius.Packet) *radius.Packet
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ctx context) Packet() *radius.Request {
|
func (ctx *context) Packet() *radius.Request { return ctx.req }
|
||||||
return ctx.req
|
func (ctx *context) ProtocolSettings() interface{} { return ctx.settings }
|
||||||
}
|
func (ctx *context) StateForProtocol(p protocol.Type) interface{} { return ctx.typeState[p] }
|
||||||
|
func (ctx *context) GetProtocolState() interface{} { return ctx.state }
|
||||||
|
func (ctx *context) SetProtocolState(st interface{}) { ctx.state = st }
|
||||||
|
func (ctx *context) IsProtocolStart() bool { return ctx.state == nil }
|
||||||
|
func (ctx *context) Log() *log.Entry { return ctx.log }
|
||||||
|
|
||||||
func (ctx context) ProtocolSettings() interface{} {
|
func (ctx *context) ForInnerProtocol(p protocol.Type) protocol.Context {
|
||||||
return ctx.settings
|
log.Debug("foo")
|
||||||
}
|
log.Debugf("%+v", ctx.typeState[protocol.Type(13)])
|
||||||
|
log.Debugf("%+v", ctx.typeState[protocol.Type(25)])
|
||||||
func (ctx *context) StateForProtocol(p protocol.Type) interface{} {
|
return &context{
|
||||||
return ctx.typeState[p]
|
req: ctx.req,
|
||||||
}
|
state: ctx.StateForProtocol(p),
|
||||||
|
typeState: ctx.typeState,
|
||||||
func (ctx *context) GetProtocolState() interface{} {
|
log: ctx.log,
|
||||||
return ctx.state
|
settings: ctx.settings,
|
||||||
}
|
endStatus: ctx.endStatus,
|
||||||
|
endModifier: ctx.endModifier,
|
||||||
func (ctx *context) SetProtocolState(st interface{}) {
|
}
|
||||||
ctx.state = st
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ctx *context) IsProtocolStart() bool {
|
|
||||||
return ctx.state == nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ctx *context) EndInnerProtocol(st protocol.Status, mf func(p *radius.Packet) *radius.Packet) {
|
func (ctx *context) EndInnerProtocol(st protocol.Status, mf func(p *radius.Packet) *radius.Packet) {
|
||||||
@ -52,7 +51,3 @@ func (ctx *context) EndInnerProtocol(st protocol.Status, mf func(p *radius.Packe
|
|||||||
}
|
}
|
||||||
ctx.endModifier = mf
|
ctx.endModifier = mf
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ctx context) Log() *log.Entry {
|
|
||||||
return ctx.log
|
|
||||||
}
|
|
||||||
|
|||||||
@ -97,8 +97,9 @@ func (p *Packet) handleInner(r *radius.Request) (*eap.Payload, error) {
|
|||||||
}, err
|
}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
next := func() (*eap.Payload, error) {
|
next := func(oldProtocol protocol.Type) (*eap.Payload, error) {
|
||||||
st.ProtocolIndex += 1
|
st.ProtocolIndex += 1
|
||||||
|
delete(st.TypeState, oldProtocol)
|
||||||
p.stm.SetEAPState(p.state, st)
|
p.stm.SetEAPState(p.state, st)
|
||||||
return p.handleInner(r)
|
return p.handleInner(r)
|
||||||
}
|
}
|
||||||
@ -106,26 +107,28 @@ func (p *Packet) handleInner(r *radius.Request) (*eap.Payload, error) {
|
|||||||
if _, ok := p.eap.Payload.(*legacy_nak.Payload); ok {
|
if _, ok := p.eap.Payload.(*legacy_nak.Payload); ok {
|
||||||
log.Debug("EAP: received NAK, trying next protocol")
|
log.Debug("EAP: received NAK, trying next protocol")
|
||||||
p.eap.Payload = nil
|
p.eap.Payload = nil
|
||||||
return next()
|
log.Debug(st.ProtocolPriority[st.ProtocolIndex])
|
||||||
|
return next(st.ProtocolPriority[st.ProtocolIndex])
|
||||||
}
|
}
|
||||||
|
|
||||||
np, _ := emptyPayload(p.stm, nextChallengeToOffer)
|
np, t, _ := emptyPayload(p.stm, nextChallengeToOffer)
|
||||||
|
|
||||||
ctx := &context{
|
ctx := &context{
|
||||||
req: r,
|
req: r,
|
||||||
|
// Always write to the state of the outer protocol
|
||||||
state: st.TypeState[np.Type()],
|
state: st.TypeState[np.Type()],
|
||||||
typeState: st.TypeState,
|
typeState: st.TypeState,
|
||||||
log: log.WithField("type", fmt.Sprintf("%T", np)),
|
log: log.WithField("type", fmt.Sprintf("%T", np)).WithField("code", t),
|
||||||
settings: p.stm.GetEAPSettings().ProtocolSettings[np.Type()],
|
settings: p.stm.GetEAPSettings().ProtocolSettings[t],
|
||||||
}
|
}
|
||||||
if !np.Offerable() {
|
if !np.Offerable() {
|
||||||
ctx.log.Debug("EAP: protocol not offerable, skipping")
|
ctx.log.Debug("EAP: protocol not offerable, skipping")
|
||||||
return next()
|
return next(np.Type())
|
||||||
}
|
}
|
||||||
ctx.log.Debug("EAP: Passing to protocol")
|
ctx.log.Debug("EAP: Passing to protocol")
|
||||||
|
|
||||||
res := p.GetChallengeForType(ctx, np)
|
res := p.GetChallengeForType(ctx, np, t)
|
||||||
st.TypeState[np.Type()] = ctx.GetProtocolState()
|
st.TypeState[t] = ctx.GetProtocolState()
|
||||||
p.stm.SetEAPState(p.state, st)
|
p.stm.SetEAPState(p.state, st)
|
||||||
|
|
||||||
if ctx.endModifier != nil {
|
if ctx.endModifier != nil {
|
||||||
@ -141,17 +144,17 @@ func (p *Packet) handleInner(r *radius.Request) (*eap.Payload, error) {
|
|||||||
res.ID -= 1
|
res.ID -= 1
|
||||||
case protocol.StatusNextProtocol:
|
case protocol.StatusNextProtocol:
|
||||||
ctx.log.Debug("EAP: Protocol ended, starting next protocol")
|
ctx.log.Debug("EAP: Protocol ended, starting next protocol")
|
||||||
return next()
|
return next(np.Type())
|
||||||
case protocol.StatusUnknown:
|
case protocol.StatusUnknown:
|
||||||
}
|
}
|
||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Packet) GetChallengeForType(ctx *context, np protocol.Payload) *eap.Payload {
|
func (p *Packet) GetChallengeForType(ctx *context, np protocol.Payload, t protocol.Type) *eap.Payload {
|
||||||
res := &eap.Payload{
|
res := &eap.Payload{
|
||||||
Code: protocol.CodeRequest,
|
Code: protocol.CodeRequest,
|
||||||
ID: p.eap.ID + 1,
|
ID: p.eap.ID + 1,
|
||||||
MsgType: np.Type(),
|
MsgType: t,
|
||||||
}
|
}
|
||||||
var payload any
|
var payload any
|
||||||
if ctx.IsProtocolStart() {
|
if ctx.IsProtocolStart() {
|
||||||
|
|||||||
@ -15,13 +15,20 @@ type Packet struct {
|
|||||||
endModifier func(p *radius.Packet) *radius.Packet
|
endModifier func(p *radius.Packet) *radius.Packet
|
||||||
}
|
}
|
||||||
|
|
||||||
func emptyPayload(stm StateManager, t protocol.Type) (protocol.Payload, error) {
|
func emptyPayload(stm StateManager, t protocol.Type) (protocol.Payload, protocol.Type, error) {
|
||||||
for _, cons := range stm.GetEAPSettings().Protocols {
|
for _, cons := range stm.GetEAPSettings().Protocols {
|
||||||
if np := cons(); np.Type() == t {
|
np := cons()
|
||||||
return np, nil
|
if np.Type() == t {
|
||||||
|
return np, np.Type(), nil
|
||||||
|
}
|
||||||
|
// If the protocol has an inner protocol, return the original type but the code for the inner protocol
|
||||||
|
if i, ok := np.(protocol.Inner); ok {
|
||||||
|
if ii := i.HasInner(); ii != nil {
|
||||||
|
return np, ii.Type(), nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unsupported EAP type %d", t)
|
return nil, protocol.Type(0), fmt.Errorf("unsupported EAP type %d", t)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Decode(stm StateManager, raw []byte) (*Packet, error) {
|
func Decode(stm StateManager, raw []byte) (*Packet, error) {
|
||||||
@ -38,7 +45,7 @@ func Decode(stm StateManager, raw []byte) (*Packet, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
p, err := emptyPayload(stm, packet.eap.MsgType)
|
p, _, err := emptyPayload(stm, packet.eap.MsgType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@ -19,6 +19,8 @@ type Context interface {
|
|||||||
|
|
||||||
ProtocolSettings() interface{}
|
ProtocolSettings() interface{}
|
||||||
|
|
||||||
|
ForInnerProtocol(p Type) Context
|
||||||
|
|
||||||
StateForProtocol(p Type) interface{}
|
StateForProtocol(p Type) interface{}
|
||||||
GetProtocolState() interface{}
|
GetProtocolState() interface{}
|
||||||
SetProtocolState(interface{})
|
SetProtocolState(interface{})
|
||||||
|
|||||||
@ -47,7 +47,7 @@ func (packet *Payload) Decode(raw []byte) error {
|
|||||||
if packet.Payload == nil {
|
if packet.Payload == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
log.WithField("raw", debug.FormatBytes(raw)).WithField("payload", fmt.Sprintf("%T", packet.Payload)).Debug("EAP: decode raw")
|
log.WithField("raw", debug.FormatBytes(raw)).WithField("payload", fmt.Sprintf("%T", packet.Payload)).Trace("EAP: decode raw")
|
||||||
err := packet.Payload.Decode(raw[5:])
|
err := packet.Payload.Decode(raw[5:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@ -8,6 +8,10 @@ type Payload interface {
|
|||||||
Offerable() bool
|
Offerable() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Inner interface {
|
||||||
|
HasInner() Payload
|
||||||
|
}
|
||||||
|
|
||||||
type Type uint8
|
type Type uint8
|
||||||
|
|
||||||
type Code uint8
|
type Code uint8
|
||||||
|
|||||||
@ -21,33 +21,55 @@ func Protocol() protocol.Payload {
|
|||||||
|
|
||||||
type Payload struct {
|
type Payload struct {
|
||||||
Inner protocol.Payload
|
Inner protocol.Payload
|
||||||
|
|
||||||
|
eap *eap.Payload
|
||||||
|
st *State
|
||||||
|
raw []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Payload) Type() protocol.Type {
|
func (p *Payload) Type() protocol.Type {
|
||||||
return TypePEAP
|
return TypePEAP
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Payload) HasInner() protocol.Payload {
|
||||||
|
return p.Inner
|
||||||
|
}
|
||||||
|
|
||||||
func (p *Payload) Decode(raw []byte) error {
|
func (p *Payload) Decode(raw []byte) error {
|
||||||
log.WithField("raw", debug.FormatBytes(raw)).Debug("PEAP: Decode")
|
log.WithField("raw", debug.FormatBytes(raw)).Debug("PEAP: Decode")
|
||||||
|
p.raw = raw
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Payload) Encode() ([]byte, error) {
|
func (p *Payload) Encode() ([]byte, error) {
|
||||||
log.Debug("PEAP: Encode")
|
return p.eap.Encode()
|
||||||
return []byte{}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
|
func (p *Payload) Handle(ctx protocol.Context) protocol.Payload {
|
||||||
|
defer func() {
|
||||||
|
ctx.SetProtocolState(p.st)
|
||||||
|
}()
|
||||||
|
|
||||||
eapState := ctx.StateForProtocol(eap.TypeEAP).(*eap.State)
|
eapState := ctx.StateForProtocol(eap.TypeEAP).(*eap.State)
|
||||||
if !ctx.IsProtocolStart() {
|
|
||||||
|
if ctx.IsProtocolStart() {
|
||||||
ctx.Log().Debug("PEAP: Protocol start")
|
ctx.Log().Debug("PEAP: Protocol start")
|
||||||
|
p.st = &State{}
|
||||||
return &eap.Payload{
|
return &eap.Payload{
|
||||||
Code: protocol.CodeRequest,
|
Code: protocol.CodeRequest,
|
||||||
ID: eapState.PacketID,
|
ID: eapState.PacketID + 1,
|
||||||
MsgType: identity.TypeIdentity,
|
MsgType: identity.TypeIdentity,
|
||||||
Payload: &identity.Payload{},
|
Payload: &identity.Payload{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
p.st = ctx.GetProtocolState().(*State)
|
||||||
|
|
||||||
|
ep := &eap.Payload{}
|
||||||
|
err := ep.Decode(p.raw)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Log().WithError(err).Warning("PEAP: failed to decode inner EAP")
|
||||||
|
return &Payload{}
|
||||||
|
}
|
||||||
return &Payload{}
|
return &Payload{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
4
internal/outpost/radius/eap/protocol/peap/state.go
Normal file
4
internal/outpost/radius/eap/protocol/peap/state.go
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
package peap
|
||||||
|
|
||||||
|
type State struct {
|
||||||
|
}
|
||||||
@ -13,9 +13,16 @@ func (p *Payload) innerHandler(ctx protocol.Context) {
|
|||||||
ctx.EndInnerProtocol(protocol.StatusError, nil)
|
ctx.EndInnerProtocol(protocol.StatusError, nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
pl := p.Inner.Handle(ctx)
|
pl := p.Inner.Handle(ctx.ForInnerProtocol(p.Inner.Type()))
|
||||||
enc, err := pl.Encode()
|
enc, err := pl.Encode()
|
||||||
p.st.TLS.Write(enc)
|
if err != nil {
|
||||||
|
ctx.Log().WithError(err).Warning("failed to encode inner protocol")
|
||||||
|
}
|
||||||
|
// p.st.Conn.expectedWriterByteCount = len(enc)
|
||||||
|
_, err = p.st.TLS.Write(enc)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Log().WithError(err).Warning("failed to write to TLS")
|
||||||
|
}
|
||||||
// return &Payload{
|
// return &Payload{
|
||||||
// Data: enc,
|
// Data: enc,
|
||||||
// }
|
// }
|
||||||
|
|||||||
@ -36,12 +36,16 @@ type Payload struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *Payload) Type() protocol.Type {
|
func (p *Payload) Type() protocol.Type {
|
||||||
if p.Inner != nil {
|
// if p.inner != nil {
|
||||||
return p.Inner.Type()
|
// return p.inner.Type()
|
||||||
}
|
// }
|
||||||
return TypeTLS
|
return TypeTLS
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Payload) HasInner() protocol.Payload {
|
||||||
|
return p.Inner
|
||||||
|
}
|
||||||
|
|
||||||
func (p *Payload) Offerable() bool {
|
func (p *Payload) Offerable() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -58,7 +62,7 @@ func (p *Payload) Decode(raw []byte) error {
|
|||||||
} else {
|
} else {
|
||||||
p.Data = raw[0:]
|
p.Data = raw[0:]
|
||||||
}
|
}
|
||||||
log.WithField("raw", debug.FormatBytes(p.Data)).WithField("size", len(p.Data)).WithField("flags", p.Flags).Debug("TLS: decode raw")
|
log.WithField("raw", debug.FormatBytes(p.Data)).WithField("size", len(p.Data)).WithField("flags", p.Flags).Trace("TLS: decode raw")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user