try to make this work
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
		| @ -16,28 +16,27 @@ type context struct { | |||||||
| 	endModifier func(p *radius.Packet) *radius.Packet | 	endModifier func(p *radius.Packet) *radius.Packet | ||||||
| } | } | ||||||
|  |  | ||||||
| func (ctx context) Packet() *radius.Request { | func (ctx *context) Packet() *radius.Request                      { return ctx.req } | ||||||
| 	return ctx.req | func (ctx *context) ProtocolSettings() interface{}                { return ctx.settings } | ||||||
| } | func (ctx *context) StateForProtocol(p protocol.Type) interface{} { return ctx.typeState[p] } | ||||||
|  | func (ctx *context) GetProtocolState() interface{}                { return ctx.state } | ||||||
|  | func (ctx *context) SetProtocolState(st interface{})              { ctx.state = st } | ||||||
|  | func (ctx *context) IsProtocolStart() bool                        { return ctx.state == nil } | ||||||
|  | func (ctx *context) Log() *log.Entry                              { return ctx.log } | ||||||
|  |  | ||||||
| func (ctx context) ProtocolSettings() interface{} { | func (ctx *context) ForInnerProtocol(p protocol.Type) protocol.Context { | ||||||
| 	return ctx.settings | 	log.Debug("foo") | ||||||
| } | 	log.Debugf("%+v", ctx.typeState[protocol.Type(13)]) | ||||||
|  | 	log.Debugf("%+v", ctx.typeState[protocol.Type(25)]) | ||||||
| func (ctx *context) StateForProtocol(p protocol.Type) interface{} { | 	return &context{ | ||||||
| 	return ctx.typeState[p] | 		req:         ctx.req, | ||||||
| } | 		state:       ctx.StateForProtocol(p), | ||||||
|  | 		typeState:   ctx.typeState, | ||||||
| func (ctx *context) GetProtocolState() interface{} { | 		log:         ctx.log, | ||||||
| 	return ctx.state | 		settings:    ctx.settings, | ||||||
| } | 		endStatus:   ctx.endStatus, | ||||||
|  | 		endModifier: ctx.endModifier, | ||||||
| func (ctx *context) SetProtocolState(st interface{}) { | 	} | ||||||
| 	ctx.state = st |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (ctx *context) IsProtocolStart() bool { |  | ||||||
| 	return ctx.state == nil |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (ctx *context) EndInnerProtocol(st protocol.Status, mf func(p *radius.Packet) *radius.Packet) { | func (ctx *context) EndInnerProtocol(st protocol.Status, mf func(p *radius.Packet) *radius.Packet) { | ||||||
| @ -52,7 +51,3 @@ func (ctx *context) EndInnerProtocol(st protocol.Status, mf func(p *radius.Packe | |||||||
| 	} | 	} | ||||||
| 	ctx.endModifier = mf | 	ctx.endModifier = mf | ||||||
| } | } | ||||||
|  |  | ||||||
| func (ctx context) Log() *log.Entry { |  | ||||||
| 	return ctx.log |  | ||||||
| } |  | ||||||
|  | |||||||
| @ -97,8 +97,9 @@ func (p *Packet) handleInner(r *radius.Request) (*eap.Payload, error) { | |||||||
| 		}, err | 		}, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	next := func() (*eap.Payload, error) { | 	next := func(oldProtocol protocol.Type) (*eap.Payload, error) { | ||||||
| 		st.ProtocolIndex += 1 | 		st.ProtocolIndex += 1 | ||||||
|  | 		delete(st.TypeState, oldProtocol) | ||||||
| 		p.stm.SetEAPState(p.state, st) | 		p.stm.SetEAPState(p.state, st) | ||||||
| 		return p.handleInner(r) | 		return p.handleInner(r) | ||||||
| 	} | 	} | ||||||
| @ -106,26 +107,28 @@ func (p *Packet) handleInner(r *radius.Request) (*eap.Payload, error) { | |||||||
| 	if _, ok := p.eap.Payload.(*legacy_nak.Payload); ok { | 	if _, ok := p.eap.Payload.(*legacy_nak.Payload); ok { | ||||||
| 		log.Debug("EAP: received NAK, trying next protocol") | 		log.Debug("EAP: received NAK, trying next protocol") | ||||||
| 		p.eap.Payload = nil | 		p.eap.Payload = nil | ||||||
| 		return next() | 		log.Debug(st.ProtocolPriority[st.ProtocolIndex]) | ||||||
|  | 		return next(st.ProtocolPriority[st.ProtocolIndex]) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	np, _ := emptyPayload(p.stm, nextChallengeToOffer) | 	np, t, _ := emptyPayload(p.stm, nextChallengeToOffer) | ||||||
|  |  | ||||||
| 	ctx := &context{ | 	ctx := &context{ | ||||||
| 		req: r, | 		req: r, | ||||||
|  | 		// Always write to the state of the outer protocol | ||||||
| 		state:     st.TypeState[np.Type()], | 		state:     st.TypeState[np.Type()], | ||||||
| 		typeState: st.TypeState, | 		typeState: st.TypeState, | ||||||
| 		log:       log.WithField("type", fmt.Sprintf("%T", np)), | 		log:       log.WithField("type", fmt.Sprintf("%T", np)).WithField("code", t), | ||||||
| 		settings:  p.stm.GetEAPSettings().ProtocolSettings[np.Type()], | 		settings:  p.stm.GetEAPSettings().ProtocolSettings[t], | ||||||
| 	} | 	} | ||||||
| 	if !np.Offerable() { | 	if !np.Offerable() { | ||||||
| 		ctx.log.Debug("EAP: protocol not offerable, skipping") | 		ctx.log.Debug("EAP: protocol not offerable, skipping") | ||||||
| 		return next() | 		return next(np.Type()) | ||||||
| 	} | 	} | ||||||
| 	ctx.log.Debug("EAP: Passing to protocol") | 	ctx.log.Debug("EAP: Passing to protocol") | ||||||
|  |  | ||||||
| 	res := p.GetChallengeForType(ctx, np) | 	res := p.GetChallengeForType(ctx, np, t) | ||||||
| 	st.TypeState[np.Type()] = ctx.GetProtocolState() | 	st.TypeState[t] = ctx.GetProtocolState() | ||||||
| 	p.stm.SetEAPState(p.state, st) | 	p.stm.SetEAPState(p.state, st) | ||||||
|  |  | ||||||
| 	if ctx.endModifier != nil { | 	if ctx.endModifier != nil { | ||||||
| @ -141,17 +144,17 @@ func (p *Packet) handleInner(r *radius.Request) (*eap.Payload, error) { | |||||||
| 		res.ID -= 1 | 		res.ID -= 1 | ||||||
| 	case protocol.StatusNextProtocol: | 	case protocol.StatusNextProtocol: | ||||||
| 		ctx.log.Debug("EAP: Protocol ended, starting next protocol") | 		ctx.log.Debug("EAP: Protocol ended, starting next protocol") | ||||||
| 		return next() | 		return next(np.Type()) | ||||||
| 	case protocol.StatusUnknown: | 	case protocol.StatusUnknown: | ||||||
| 	} | 	} | ||||||
| 	return res, nil | 	return res, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (p *Packet) GetChallengeForType(ctx *context, np protocol.Payload) *eap.Payload { | func (p *Packet) GetChallengeForType(ctx *context, np protocol.Payload, t protocol.Type) *eap.Payload { | ||||||
| 	res := &eap.Payload{ | 	res := &eap.Payload{ | ||||||
| 		Code:    protocol.CodeRequest, | 		Code:    protocol.CodeRequest, | ||||||
| 		ID:      p.eap.ID + 1, | 		ID:      p.eap.ID + 1, | ||||||
| 		MsgType: np.Type(), | 		MsgType: t, | ||||||
| 	} | 	} | ||||||
| 	var payload any | 	var payload any | ||||||
| 	if ctx.IsProtocolStart() { | 	if ctx.IsProtocolStart() { | ||||||
|  | |||||||
| @ -15,13 +15,20 @@ type Packet struct { | |||||||
| 	endModifier func(p *radius.Packet) *radius.Packet | 	endModifier func(p *radius.Packet) *radius.Packet | ||||||
| } | } | ||||||
|  |  | ||||||
| func emptyPayload(stm StateManager, t protocol.Type) (protocol.Payload, error) { | func emptyPayload(stm StateManager, t protocol.Type) (protocol.Payload, protocol.Type, error) { | ||||||
| 	for _, cons := range stm.GetEAPSettings().Protocols { | 	for _, cons := range stm.GetEAPSettings().Protocols { | ||||||
| 		if np := cons(); np.Type() == t { | 		np := cons() | ||||||
| 			return np, nil | 		if np.Type() == t { | ||||||
|  | 			return np, np.Type(), nil | ||||||
|  | 		} | ||||||
|  | 		// If the protocol has an inner protocol, return the original type but the code for the inner protocol | ||||||
|  | 		if i, ok := np.(protocol.Inner); ok { | ||||||
|  | 			if ii := i.HasInner(); ii != nil { | ||||||
|  | 				return np, ii.Type(), nil | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	return nil, fmt.Errorf("unsupported EAP type %d", t) | 	} | ||||||
|  | 	return nil, protocol.Type(0), fmt.Errorf("unsupported EAP type %d", t) | ||||||
| } | } | ||||||
|  |  | ||||||
| func Decode(stm StateManager, raw []byte) (*Packet, error) { | func Decode(stm StateManager, raw []byte) (*Packet, error) { | ||||||
| @ -38,7 +45,7 @@ func Decode(stm StateManager, raw []byte) (*Packet, error) { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	p, err := emptyPayload(stm, packet.eap.MsgType) | 	p, _, err := emptyPayload(stm, packet.eap.MsgType) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -19,6 +19,8 @@ type Context interface { | |||||||
|  |  | ||||||
| 	ProtocolSettings() interface{} | 	ProtocolSettings() interface{} | ||||||
|  |  | ||||||
|  | 	ForInnerProtocol(p Type) Context | ||||||
|  |  | ||||||
| 	StateForProtocol(p Type) interface{} | 	StateForProtocol(p Type) interface{} | ||||||
| 	GetProtocolState() interface{} | 	GetProtocolState() interface{} | ||||||
| 	SetProtocolState(interface{}) | 	SetProtocolState(interface{}) | ||||||
|  | |||||||
| @ -47,7 +47,7 @@ func (packet *Payload) Decode(raw []byte) error { | |||||||
| 	if packet.Payload == nil { | 	if packet.Payload == nil { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 	log.WithField("raw", debug.FormatBytes(raw)).WithField("payload", fmt.Sprintf("%T", packet.Payload)).Debug("EAP: decode raw") | 	log.WithField("raw", debug.FormatBytes(raw)).WithField("payload", fmt.Sprintf("%T", packet.Payload)).Trace("EAP: decode raw") | ||||||
| 	err := packet.Payload.Decode(raw[5:]) | 	err := packet.Payload.Decode(raw[5:]) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
|  | |||||||
| @ -8,6 +8,10 @@ type Payload interface { | |||||||
| 	Offerable() bool | 	Offerable() bool | ||||||
| } | } | ||||||
|  |  | ||||||
|  | type Inner interface { | ||||||
|  | 	HasInner() Payload | ||||||
|  | } | ||||||
|  |  | ||||||
| type Type uint8 | type Type uint8 | ||||||
|  |  | ||||||
| type Code uint8 | type Code uint8 | ||||||
|  | |||||||
| @ -21,33 +21,55 @@ func Protocol() protocol.Payload { | |||||||
|  |  | ||||||
| type Payload struct { | type Payload struct { | ||||||
| 	Inner protocol.Payload | 	Inner protocol.Payload | ||||||
|  |  | ||||||
|  | 	eap *eap.Payload | ||||||
|  | 	st  *State | ||||||
|  | 	raw []byte | ||||||
| } | } | ||||||
|  |  | ||||||
| func (p *Payload) Type() protocol.Type { | func (p *Payload) Type() protocol.Type { | ||||||
| 	return TypePEAP | 	return TypePEAP | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (p *Payload) HasInner() protocol.Payload { | ||||||
|  | 	return p.Inner | ||||||
|  | } | ||||||
|  |  | ||||||
| func (p *Payload) Decode(raw []byte) error { | func (p *Payload) Decode(raw []byte) error { | ||||||
| 	log.WithField("raw", debug.FormatBytes(raw)).Debug("PEAP: Decode") | 	log.WithField("raw", debug.FormatBytes(raw)).Debug("PEAP: Decode") | ||||||
|  | 	p.raw = raw | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (p *Payload) Encode() ([]byte, error) { | func (p *Payload) Encode() ([]byte, error) { | ||||||
| 	log.Debug("PEAP: Encode") | 	return p.eap.Encode() | ||||||
| 	return []byte{}, nil |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (p *Payload) Handle(ctx protocol.Context) protocol.Payload { | func (p *Payload) Handle(ctx protocol.Context) protocol.Payload { | ||||||
|  | 	defer func() { | ||||||
|  | 		ctx.SetProtocolState(p.st) | ||||||
|  | 	}() | ||||||
|  |  | ||||||
| 	eapState := ctx.StateForProtocol(eap.TypeEAP).(*eap.State) | 	eapState := ctx.StateForProtocol(eap.TypeEAP).(*eap.State) | ||||||
| 	if !ctx.IsProtocolStart() { |  | ||||||
|  | 	if ctx.IsProtocolStart() { | ||||||
| 		ctx.Log().Debug("PEAP: Protocol start") | 		ctx.Log().Debug("PEAP: Protocol start") | ||||||
|  | 		p.st = &State{} | ||||||
| 		return &eap.Payload{ | 		return &eap.Payload{ | ||||||
| 			Code:    protocol.CodeRequest, | 			Code:    protocol.CodeRequest, | ||||||
| 			ID:      eapState.PacketID, | 			ID:      eapState.PacketID + 1, | ||||||
| 			MsgType: identity.TypeIdentity, | 			MsgType: identity.TypeIdentity, | ||||||
| 			Payload: &identity.Payload{}, | 			Payload: &identity.Payload{}, | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | 	p.st = ctx.GetProtocolState().(*State) | ||||||
|  |  | ||||||
|  | 	ep := &eap.Payload{} | ||||||
|  | 	err := ep.Decode(p.raw) | ||||||
|  | 	if err != nil { | ||||||
|  | 		ctx.Log().WithError(err).Warning("PEAP: failed to decode inner EAP") | ||||||
|  | 		return &Payload{} | ||||||
|  | 	} | ||||||
| 	return &Payload{} | 	return &Payload{} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
										4
									
								
								internal/outpost/radius/eap/protocol/peap/state.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								internal/outpost/radius/eap/protocol/peap/state.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,4 @@ | |||||||
|  | package peap | ||||||
|  |  | ||||||
|  | type State struct { | ||||||
|  | } | ||||||
| @ -13,9 +13,16 @@ func (p *Payload) innerHandler(ctx protocol.Context) { | |||||||
| 		ctx.EndInnerProtocol(protocol.StatusError, nil) | 		ctx.EndInnerProtocol(protocol.StatusError, nil) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	pl := p.Inner.Handle(ctx) | 	pl := p.Inner.Handle(ctx.ForInnerProtocol(p.Inner.Type())) | ||||||
| 	enc, err := pl.Encode() | 	enc, err := pl.Encode() | ||||||
| 	p.st.TLS.Write(enc) | 	if err != nil { | ||||||
|  | 		ctx.Log().WithError(err).Warning("failed to encode inner protocol") | ||||||
|  | 	} | ||||||
|  | 	// p.st.Conn.expectedWriterByteCount = len(enc) | ||||||
|  | 	_, err = p.st.TLS.Write(enc) | ||||||
|  | 	if err != nil { | ||||||
|  | 		ctx.Log().WithError(err).Warning("failed to write to TLS") | ||||||
|  | 	} | ||||||
| 	// return &Payload{ | 	// return &Payload{ | ||||||
| 	// 	Data: enc, | 	// 	Data: enc, | ||||||
| 	// } | 	// } | ||||||
|  | |||||||
| @ -36,12 +36,16 @@ type Payload struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (p *Payload) Type() protocol.Type { | func (p *Payload) Type() protocol.Type { | ||||||
| 	if p.Inner != nil { | 	// if p.inner != nil { | ||||||
| 		return p.Inner.Type() | 	// 	return p.inner.Type() | ||||||
| 	} | 	// } | ||||||
| 	return TypeTLS | 	return TypeTLS | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (p *Payload) HasInner() protocol.Payload { | ||||||
|  | 	return p.Inner | ||||||
|  | } | ||||||
|  |  | ||||||
| func (p *Payload) Offerable() bool { | func (p *Payload) Offerable() bool { | ||||||
| 	return true | 	return true | ||||||
| } | } | ||||||
| @ -58,7 +62,7 @@ func (p *Payload) Decode(raw []byte) error { | |||||||
| 	} else { | 	} else { | ||||||
| 		p.Data = raw[0:] | 		p.Data = raw[0:] | ||||||
| 	} | 	} | ||||||
| 	log.WithField("raw", debug.FormatBytes(p.Data)).WithField("size", len(p.Data)).WithField("flags", p.Flags).Debug("TLS: decode raw") | 	log.WithField("raw", debug.FormatBytes(p.Data)).WithField("size", len(p.Data)).WithField("flags", p.Flags).Trace("TLS: decode raw") | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	 Jens Langhammer
					Jens Langhammer