diff --git a/internal/outpost/radius/eap/context.go b/internal/outpost/radius/eap/context.go index e680b15cb0..e32f8f7f2d 100644 --- a/internal/outpost/radius/eap/context.go +++ b/internal/outpost/radius/eap/context.go @@ -9,6 +9,7 @@ import ( type context struct { req *radius.Request state interface{} + typeState map[protocol.Type]any log *log.Entry settings interface{} endStatus protocol.Status @@ -23,6 +24,10 @@ func (ctx context) ProtocolSettings() interface{} { return ctx.settings } +func (ctx *context) StateForProtocol(p protocol.Type) interface{} { + return ctx.typeState[p] +} + func (ctx *context) GetProtocolState() interface{} { return ctx.state } diff --git a/internal/outpost/radius/eap/handler.go b/internal/outpost/radius/eap/handler.go index e55d441256..8209c4a2da 100644 --- a/internal/outpost/radius/eap/handler.go +++ b/internal/outpost/radius/eap/handler.go @@ -9,6 +9,7 @@ import ( "github.com/gorilla/securecookie" log "github.com/sirupsen/logrus" "goauthentik.io/internal/outpost/radius/eap/protocol" + "goauthentik.io/internal/outpost/radius/eap/protocol/eap" "goauthentik.io/internal/outpost/radius/eap/protocol/legacy_nak" "layeh.com/radius" "layeh.com/radius/rfc2865" @@ -30,16 +31,20 @@ func (p *Packet) HandleRadiusPacket(w radius.ResponseWriter, r *radius.Request) } p.state = rst - rp, err := p.handleInner(r) + rep, err := p.handleInner(r) + rp := &Packet{ + eap: rep, + } + rres := r.Response(radius.CodeAccessReject) if err == nil { rres = p.endModifier(rres) - switch rp.code { - case CodeRequest: + switch rp.eap.Code { + case protocol.CodeRequest: rres.Code = radius.CodeAccessChallenge - case CodeFailure: + case protocol.CodeFailure: rres.Code = radius.CodeAccessReject - case CodeSuccess: + case protocol.CodeSuccess: rres.Code = radius.CodeAccessAccept } } else { @@ -54,7 +59,7 @@ func (p *Packet) HandleRadiusPacket(w radius.ResponseWriter, r *radius.Request) sendErrorResponse(w, r) return } - log.WithField("length", len(eapEncoded)).WithField("type", fmt.Sprintf("%T", rp.Payload)).Debug("EAP: encapsulated challenge") + log.WithField("length", len(eapEncoded)).WithField("type", fmt.Sprintf("%T", rp.eap.Payload)).Debug("EAP: encapsulated challenge") rfc2869.EAPMessage_Set(rres, eapEncoded) err = p.setMessageAuthenticator(rres) if err != nil { @@ -68,40 +73,50 @@ func (p *Packet) HandleRadiusPacket(w radius.ResponseWriter, r *radius.Request) } } -func (p *Packet) handleInner(r *radius.Request) (*Packet, error) { +func (p *Packet) handleInner(r *radius.Request) (*eap.Payload, error) { st := p.stm.GetEAPState(p.state) if st == nil { log.Debug("EAP: blank state") st = BlankState(p.stm.GetEAPSettings()) } + // FIXME: Statically call Handle of root EAP packet to make its data accessible + ectx := &context{ + state: st.TypeState[eap.TypeEAP], + log: log.WithField("type", fmt.Sprintf("%T", &eap.Payload{})), + } + p.eap.Handle(ectx) + st.TypeState[eap.TypeEAP] = ectx.GetProtocolState() + p.stm.SetEAPState(p.state, st) + nextChallengeToOffer, err := st.GetNextProtocol() if err != nil { - return &Packet{ - code: CodeFailure, - id: p.id, + return &eap.Payload{ + Code: protocol.CodeFailure, + ID: p.eap.ID, }, err } - next := func() (*Packet, error) { + next := func() (*eap.Payload, error) { st.ProtocolIndex += 1 p.stm.SetEAPState(p.state, st) return p.handleInner(r) } - if _, ok := p.Payload.(*legacy_nak.Payload); ok { + if _, ok := p.eap.Payload.(*legacy_nak.Payload); ok { log.Debug("EAP: received NAK, trying next protocol") - p.Payload = nil + p.eap.Payload = nil return next() } np, _ := emptyPayload(p.stm, nextChallengeToOffer) ctx := &context{ - req: r, - state: st.TypeState[np.Type()], - log: log.WithField("type", fmt.Sprintf("%T", np)), - settings: p.stm.GetEAPSettings().ProtocolSettings[np.Type()], + req: r, + state: st.TypeState[np.Type()], + typeState: st.TypeState, + log: log.WithField("type", fmt.Sprintf("%T", np)), + settings: p.stm.GetEAPSettings().ProtocolSettings[np.Type()], } if !np.Offerable() { ctx.log.Debug("EAP: protocol not offerable, skipping") @@ -119,11 +134,11 @@ func (p *Packet) handleInner(r *radius.Request) (*Packet, error) { switch ctx.endStatus { case protocol.StatusSuccess: - res.code = CodeSuccess - res.id -= 1 + res.Code = protocol.CodeSuccess + res.ID -= 1 case protocol.StatusError: - res.code = CodeFailure - res.id -= 1 + res.Code = protocol.CodeFailure + res.ID -= 1 case protocol.StatusNextProtocol: ctx.log.Debug("EAP: Protocol ended, starting next protocol") return next() @@ -132,18 +147,18 @@ func (p *Packet) handleInner(r *radius.Request) (*Packet, error) { return res, nil } -func (p *Packet) GetChallengeForType(ctx *context, np protocol.Payload) *Packet { - res := &Packet{ - code: CodeRequest, - id: p.id + 1, - msgType: np.Type(), +func (p *Packet) GetChallengeForType(ctx *context, np protocol.Payload) *eap.Payload { + res := &eap.Payload{ + Code: protocol.CodeRequest, + ID: p.eap.ID + 1, + MsgType: np.Type(), } var payload any if ctx.IsProtocolStart() { - p.Payload = np - p.Payload.Decode(p.rawPayload) + p.eap.Payload = np + p.eap.Payload.Decode(p.eap.RawPayload) } - payload = p.Payload.Handle(ctx) + payload = p.eap.Payload.Handle(ctx) if payload != nil { res.Payload = payload.(protocol.Payload) } diff --git a/internal/outpost/radius/eap/packet.go b/internal/outpost/radius/eap/packet.go index 46e6b7bf4c..27297cb159 100644 --- a/internal/outpost/radius/eap/packet.go +++ b/internal/outpost/radius/eap/packet.go @@ -1,40 +1,20 @@ package eap import ( - "encoding/binary" - "errors" "fmt" - log "github.com/sirupsen/logrus" - "goauthentik.io/internal/outpost/radius/eap/debug" "goauthentik.io/internal/outpost/radius/eap/protocol" + "goauthentik.io/internal/outpost/radius/eap/protocol/eap" "layeh.com/radius" ) -type Code uint8 - -const ( - CodeRequest Code = 1 - CodeResponse Code = 2 - CodeSuccess Code = 3 - CodeFailure Code = 4 -) - type Packet struct { - code Code - id uint8 - length uint16 - msgType protocol.Type - rawPayload []byte - Payload protocol.Payload - + eap *eap.Payload stm StateManager state string endModifier func(p *radius.Packet) *radius.Packet } -type PayloadWriter struct{} - func emptyPayload(stm StateManager, t protocol.Type) (protocol.Payload, error) { for _, cons := range stm.GetEAPSettings().Protocols { if np := cons(); np.Type() == t { @@ -46,28 +26,24 @@ func emptyPayload(stm StateManager, t protocol.Type) (protocol.Payload, error) { func Decode(stm StateManager, raw []byte) (*Packet, error) { packet := &Packet{ + eap: &eap.Payload{}, stm: stm, endModifier: func(p *radius.Packet) *radius.Packet { return p }, } - packet.code = Code(raw[0]) - packet.id = raw[1] - packet.length = binary.BigEndian.Uint16(raw[2:]) - if packet.length != uint16(len(raw)) { - return nil, errors.New("mismatched packet length") - } - if len(raw) > 4 && (packet.code == CodeRequest || packet.code == CodeResponse) { - packet.msgType = protocol.Type(raw[4]) - } - p, err := emptyPayload(stm, packet.msgType) + // FIXME: We're decoding twice here, first to get the msg type, then come back to assign the payload type + // then re-parse to parse the payload correctly + err := packet.eap.Decode(raw) if err != nil { return nil, err } - packet.Payload = p - packet.rawPayload = raw[5:] - log.WithField("raw", debug.FormatBytes(raw)).WithField("payload", fmt.Sprintf("%T", packet.Payload)).Debug("EAP: decode raw") - err = packet.Payload.Decode(raw[5:]) + p, err := emptyPayload(stm, packet.eap.MsgType) + if err != nil { + return nil, err + } + packet.eap.Payload = p + err = packet.eap.Decode(raw) if err != nil { return nil, err } @@ -75,20 +51,5 @@ func Decode(stm StateManager, raw []byte) (*Packet, error) { } func (p *Packet) Encode() ([]byte, error) { - buff := make([]byte, 4) - buff[0] = uint8(p.code) - buff[1] = uint8(p.id) - - if p.Payload != nil { - payloadBuffer, err := p.Payload.Encode() - if err != nil { - return buff, err - } - if p.code == CodeRequest || p.code == CodeResponse { - buff = append(buff, uint8(p.msgType)) - } - buff = append(buff, payloadBuffer...) - } - binary.BigEndian.PutUint16(buff[2:], uint16(len(buff))) - return buff, nil + return p.eap.Encode() } diff --git a/internal/outpost/radius/eap/protocol/context.go b/internal/outpost/radius/eap/protocol/context.go index f08b18df4f..58791ec7ea 100644 --- a/internal/outpost/radius/eap/protocol/context.go +++ b/internal/outpost/radius/eap/protocol/context.go @@ -18,6 +18,8 @@ type Context interface { Packet() *radius.Request ProtocolSettings() interface{} + + StateForProtocol(p Type) interface{} GetProtocolState() interface{} SetProtocolState(interface{}) diff --git a/internal/outpost/radius/eap/protocol/eap/payload.go b/internal/outpost/radius/eap/protocol/eap/payload.go new file mode 100644 index 0000000000..d179a1aa81 --- /dev/null +++ b/internal/outpost/radius/eap/protocol/eap/payload.go @@ -0,0 +1,83 @@ +package eap + +import ( + "encoding/binary" + "errors" + "fmt" + + log "github.com/sirupsen/logrus" + "goauthentik.io/internal/outpost/radius/eap/debug" + "goauthentik.io/internal/outpost/radius/eap/protocol" +) + +const TypeEAP protocol.Type = 0 + +func Protocol() protocol.Payload { + return &Payload{} +} + +type Payload struct { + Code protocol.Code + ID uint8 + Length uint16 + MsgType protocol.Type + Payload protocol.Payload + RawPayload []byte +} + +func (ip *Payload) Type() protocol.Type { + return TypeEAP +} + +func (ip *Payload) Offerable() bool { + return false +} + +func (packet *Payload) Decode(raw []byte) error { + packet.Code = protocol.Code(raw[0]) + packet.ID = raw[1] + packet.Length = binary.BigEndian.Uint16(raw[2:]) + if packet.Length != uint16(len(raw)) { + return errors.New("mismatched packet length") + } + if len(raw) > 4 && (packet.Code == protocol.CodeRequest || packet.Code == protocol.CodeResponse) { + packet.MsgType = protocol.Type(raw[4]) + } + packet.RawPayload = raw[5:] + if packet.Payload == nil { + return nil + } + log.WithField("raw", debug.FormatBytes(raw)).WithField("payload", fmt.Sprintf("%T", packet.Payload)).Debug("EAP: decode raw") + err := packet.Payload.Decode(raw[5:]) + if err != nil { + return err + } + return nil +} + +func (p *Payload) Encode() ([]byte, error) { + buff := make([]byte, 4) + buff[0] = uint8(p.Code) + buff[1] = uint8(p.ID) + + if p.Payload != nil { + payloadBuffer, err := p.Payload.Encode() + if err != nil { + return buff, err + } + if p.Code == protocol.CodeRequest || p.Code == protocol.CodeResponse { + buff = append(buff, uint8(p.MsgType)) + } + buff = append(buff, payloadBuffer...) + } + binary.BigEndian.PutUint16(buff[2:], uint16(len(buff))) + return buff, nil +} + +func (ip *Payload) Handle(ctx protocol.Context) protocol.Payload { + ctx.Log().Debug("EAP: Handle") + ctx.SetProtocolState(&State{ + PacketID: ip.ID, + }) + return nil +} diff --git a/internal/outpost/radius/eap/protocol/eap/state.go b/internal/outpost/radius/eap/protocol/eap/state.go new file mode 100644 index 0000000000..a2ea134ca6 --- /dev/null +++ b/internal/outpost/radius/eap/protocol/eap/state.go @@ -0,0 +1,5 @@ +package eap + +type State struct { + PacketID uint8 +} diff --git a/internal/outpost/radius/eap/protocol/packet.go b/internal/outpost/radius/eap/protocol/packet.go index 8b698d768e..6da6c88c93 100644 --- a/internal/outpost/radius/eap/protocol/packet.go +++ b/internal/outpost/radius/eap/protocol/packet.go @@ -9,3 +9,12 @@ type Payload interface { } type Type uint8 + +type Code uint8 + +const ( + CodeRequest Code = 1 + CodeResponse Code = 2 + CodeSuccess Code = 3 + CodeFailure Code = 4 +) diff --git a/internal/outpost/radius/eap/protocol/peap/payload.go b/internal/outpost/radius/eap/protocol/peap/payload.go index 6347dcaea0..eb3b960c60 100644 --- a/internal/outpost/radius/eap/protocol/peap/payload.go +++ b/internal/outpost/radius/eap/protocol/peap/payload.go @@ -4,6 +4,8 @@ import ( log "github.com/sirupsen/logrus" "goauthentik.io/internal/outpost/radius/eap/debug" "goauthentik.io/internal/outpost/radius/eap/protocol" + "goauthentik.io/internal/outpost/radius/eap/protocol/eap" + "goauthentik.io/internal/outpost/radius/eap/protocol/identity" "goauthentik.io/internal/outpost/radius/eap/protocol/tls" ) @@ -11,11 +13,14 @@ const TypePEAP protocol.Type = 25 func Protocol() protocol.Payload { return &tls.Payload{ - Inner: &Payload{}, + Inner: &Payload{ + Inner: &eap.Payload{}, + }, } } type Payload struct { + Inner protocol.Payload } func (p *Payload) Type() protocol.Type { @@ -33,7 +38,16 @@ func (p *Payload) Encode() ([]byte, error) { } func (p *Payload) Handle(ctx protocol.Context) protocol.Payload { - log.Debug("PEAP: Handle") + eapState := ctx.StateForProtocol(eap.TypeEAP).(*eap.State) + if !ctx.IsProtocolStart() { + ctx.Log().Debug("PEAP: Protocol start") + return &eap.Payload{ + Code: protocol.CodeRequest, + ID: eapState.PacketID, + MsgType: identity.TypeIdentity, + Payload: &identity.Payload{}, + } + } return &Payload{} } diff --git a/internal/outpost/radius/eap/protocol/tls/inner.go b/internal/outpost/radius/eap/protocol/tls/inner.go index 361458035d..b9d4bcead6 100644 --- a/internal/outpost/radius/eap/protocol/tls/inner.go +++ b/internal/outpost/radius/eap/protocol/tls/inner.go @@ -4,18 +4,19 @@ import ( "goauthentik.io/internal/outpost/radius/eap/protocol" ) -func (p *Payload) innerHandler(ctx protocol.Context) *Payload { +func (p *Payload) innerHandler(ctx protocol.Context) { // p.st.TLS.read // d, _ := io.ReadAll(p.st.TLS) err := p.Inner.Decode([]byte{}) if err != nil { ctx.Log().WithError(err).Warning("TLS: failed to decode inner protocol") ctx.EndInnerProtocol(protocol.StatusError, nil) - return nil + return } pl := p.Inner.Handle(ctx) enc, err := pl.Encode() - return &Payload{ - Data: enc, - } + p.st.TLS.Write(enc) + // return &Payload{ + // Data: enc, + // } } diff --git a/internal/outpost/radius/eap/protocol/tls/payload.go b/internal/outpost/radius/eap/protocol/tls/payload.go index 107f05ff3b..b26e02ea81 100644 --- a/internal/outpost/radius/eap/protocol/tls/payload.go +++ b/internal/outpost/radius/eap/protocol/tls/payload.go @@ -124,7 +124,8 @@ func (p *Payload) Handle(ctx protocol.Context) protocol.Payload { defer p.st.ContextCancel() if p.Inner != nil { ctx.Log().Debug("TLS: Handshake is done, delegating to inner protocol") - return p.innerHandler(ctx) + p.innerHandler(ctx) + return p.startChunkedTransfer(p.st.Conn.OutboundData()) } // If we don't have a final status from the handshake finished function, stall for time pst, _ := retry.DoWithData(