29
internal/outpost/radius/eap/context.go
Normal file
29
internal/outpost/radius/eap/context.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
package eap
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"layeh.com/radius"
|
||||||
|
)
|
||||||
|
|
||||||
|
type context[TState any, TSettings any] struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ctx context[TState, TSettings]) ProtocolSettings() TSettings {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ctx context[TState, TSettings]) GetProtocolState(def func(context[TState, TSettings]) TState) TState {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ctx context[TState, TSettings]) SetProtocolState(TState) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ctx context[TState, TSettings]) EndInnerProtocol(func(p *radius.Packet) *radius.Packet) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ctx context[TState, TSettings]) Log() *logrus.Entry {
|
||||||
|
return nil
|
||||||
|
}
|
@ -28,7 +28,10 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac
|
|||||||
panic("No more challenges")
|
panic("No more challenges")
|
||||||
}
|
}
|
||||||
nextChallengeToOffer := st.ChallengesToOffer[0]
|
nextChallengeToOffer := st.ChallengesToOffer[0]
|
||||||
res, newState := p.GetChallengeForType(st, nextChallengeToOffer)
|
|
||||||
|
ctx := context{}
|
||||||
|
|
||||||
|
res, newState := p.GetChallengeForType(ctx, nextChallengeToOffer)
|
||||||
stm.SetEAPState(rst, newState)
|
stm.SetEAPState(rst, newState)
|
||||||
|
|
||||||
rres := r.Response(radius.CodeAccessChallenge)
|
rres := r.Response(radius.CodeAccessChallenge)
|
||||||
@ -52,21 +55,22 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Packet) GetChallengeForType(st *State, t Type) (*Packet, *State) {
|
func (p *Packet) GetChallengeForType(ctx context[any, any], t Type) *Packet {
|
||||||
res := &Packet{
|
res := &Packet{
|
||||||
code: CodeRequest,
|
code: CodeRequest,
|
||||||
id: p.id + 1,
|
id: p.id + 1,
|
||||||
msgType: t,
|
msgType: t,
|
||||||
}
|
}
|
||||||
var payload any
|
var payload any
|
||||||
var tst any
|
|
||||||
switch t {
|
switch t {
|
||||||
case TypeTLS:
|
case TypeTLS:
|
||||||
|
// TODO: rewrite this
|
||||||
if _, ok := p.Payload.(*tls.Payload); !ok {
|
if _, ok := p.Payload.(*tls.Payload); !ok {
|
||||||
p.Payload = &tls.Payload{}
|
p.Payload = &tls.Payload{}
|
||||||
p.Payload.Decode(p.rawPayload)
|
p.Payload.Decode(p.rawPayload)
|
||||||
}
|
}
|
||||||
payload, tst = p.Payload.(*tls.Payload).Handle(st.TypeState[t])
|
// this
|
||||||
|
payload = p.Payload.(*tls.Payload).Handle(ctx)
|
||||||
}
|
}
|
||||||
st.TypeState[t] = tst
|
st.TypeState[t] = tst
|
||||||
res.Payload = payload.(protocol.Payload)
|
res.Payload = payload.(protocol.Payload)
|
||||||
|
18
internal/outpost/radius/eap/protocol/context.go
Normal file
18
internal/outpost/radius/eap/protocol/context.go
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
package protocol
|
||||||
|
|
||||||
|
import (
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"layeh.com/radius"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Context[TState any, TSettings any] interface {
|
||||||
|
// GlobalState()
|
||||||
|
|
||||||
|
ProtocolSettings() TSettings
|
||||||
|
GetProtocolState(def func(Context[TState, TSettings]) TState) TState
|
||||||
|
SetProtocolState(TState)
|
||||||
|
|
||||||
|
EndInnerProtocol(func(p *radius.Packet) *radius.Packet)
|
||||||
|
|
||||||
|
Log() *log.Entry
|
||||||
|
}
|
@ -3,8 +3,8 @@ package eap
|
|||||||
import "slices"
|
import "slices"
|
||||||
|
|
||||||
type Settings struct {
|
type Settings struct {
|
||||||
ChallengesToOffer []Type
|
ProtocolsToOffer []Type
|
||||||
ChallengeSettings map[Type]interface{}
|
ProtocolSettings map[Type]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
type StateManager interface {
|
type StateManager interface {
|
||||||
@ -20,7 +20,7 @@ type State struct {
|
|||||||
|
|
||||||
func BlankState(settings Settings) *State {
|
func BlankState(settings Settings) *State {
|
||||||
return &State{
|
return &State{
|
||||||
ChallengesToOffer: slices.Clone(settings.ChallengesToOffer),
|
ChallengesToOffer: slices.Clone(settings.ProtocolsToOffer),
|
||||||
TypeState: map[Type]any{},
|
TypeState: map[Type]any{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -12,31 +12,18 @@ import (
|
|||||||
"goauthentik.io/internal/outpost/radius/eap/debug"
|
"goauthentik.io/internal/outpost/radius/eap/debug"
|
||||||
"goauthentik.io/internal/outpost/radius/eap/protocol"
|
"goauthentik.io/internal/outpost/radius/eap/protocol"
|
||||||
"layeh.com/radius"
|
"layeh.com/radius"
|
||||||
"layeh.com/radius/rfc2865"
|
|
||||||
"layeh.com/radius/vendors/microsoft"
|
"layeh.com/radius/vendors/microsoft"
|
||||||
)
|
)
|
||||||
|
|
||||||
const maxChunkSize = 1000
|
const maxChunkSize = 1000
|
||||||
const staleConnectionTimeout = 10
|
const staleConnectionTimeout = 10
|
||||||
|
|
||||||
var certs = []tls.Certificate{}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
// Testing
|
|
||||||
cert, err := tls.LoadX509KeyPair(
|
|
||||||
"../t/ca/out/cert_jens-mbp.lab.beryju.org.pem",
|
|
||||||
"../t/ca/out/cert_jens-mbp.lab.beryju.org.key",
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
certs = append(certs, cert)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Payload struct {
|
type Payload struct {
|
||||||
Flags Flag
|
Flags Flag
|
||||||
Length uint32
|
Length uint32
|
||||||
Data []byte
|
Data []byte
|
||||||
|
|
||||||
|
st *State
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Payload) Decode(raw []byte) error {
|
func (p *Payload) Decode(raw []byte) error {
|
||||||
@ -74,92 +61,85 @@ func (p *Payload) Encode() ([]byte, error) {
|
|||||||
return buff, nil
|
return buff, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Payload) Handle(stt any) (protocol.Payload, *State) {
|
type tctx = protocol.Context[*State, Settings]
|
||||||
if stt == nil {
|
|
||||||
log.Debug("TLS: new state")
|
func (p *Payload) Handle(ctx tctx) protocol.Payload {
|
||||||
stt = NewState()
|
p.st = ctx.GetProtocolState(NewState)
|
||||||
}
|
defer ctx.SetProtocolState(p.st)
|
||||||
st := stt.(*State)
|
if !p.st.HasStarted {
|
||||||
if !st.HasStarted {
|
|
||||||
log.Debug("TLS: handshake starting")
|
log.Debug("TLS: handshake starting")
|
||||||
st.HasStarted = true
|
p.st.HasStarted = true
|
||||||
return &Payload{
|
return &Payload{
|
||||||
Flags: FlagTLSStart,
|
Flags: FlagTLSStart,
|
||||||
}, st
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if st.TLS == nil {
|
if p.st.TLS == nil {
|
||||||
st = p.tlsInit(st)
|
p.tlsInit(ctx)
|
||||||
} else if len(p.Data) > 0 {
|
} else if len(p.Data) > 0 {
|
||||||
log.Debug("TLS: Updating buffer with new TLS data from packet")
|
log.Debug("TLS: Updating buffer with new TLS data from packet")
|
||||||
if p.Flags&FlagLengthIncluded != 0 && st.Conn.expectedWriterByteCount == 0 {
|
if p.Flags&FlagLengthIncluded != 0 && p.st.Conn.expectedWriterByteCount == 0 {
|
||||||
log.Debugf("TLS: Expecting %d total bytes, will buffer", p.Length)
|
log.Debugf("TLS: Expecting %d total bytes, will buffer", p.Length)
|
||||||
st.Conn.expectedWriterByteCount = int(p.Length)
|
p.st.Conn.expectedWriterByteCount = int(p.Length)
|
||||||
} else if p.Flags&FlagLengthIncluded != 0 {
|
} else if p.Flags&FlagLengthIncluded != 0 {
|
||||||
log.Debug("TLS: No length included, not buffering")
|
log.Debug("TLS: No length included, not buffering")
|
||||||
st.Conn.expectedWriterByteCount = 0
|
p.st.Conn.expectedWriterByteCount = 0
|
||||||
}
|
}
|
||||||
st.Conn.UpdateData(p.Data)
|
p.st.Conn.UpdateData(p.Data)
|
||||||
if !st.Conn.NeedsMoreData() {
|
if !p.st.Conn.NeedsMoreData() {
|
||||||
// Wait for outbound data to be available
|
// Wait for outbound data to be available
|
||||||
st.Conn.OutboundData()
|
p.st.Conn.OutboundData()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// If we need more data, send the client the go-ahead
|
// If we need more data, send the client the go-ahead
|
||||||
if st.Conn.NeedsMoreData() {
|
if p.st.Conn.NeedsMoreData() {
|
||||||
return &Payload{
|
return &Payload{
|
||||||
Flags: FlagNone,
|
Flags: FlagNone,
|
||||||
Length: 0,
|
Length: 0,
|
||||||
Data: []byte{},
|
Data: []byte{},
|
||||||
}, st
|
}
|
||||||
}
|
}
|
||||||
if st.HasMore() {
|
if p.st.HasMore() {
|
||||||
return p.sendNextChunk(st)
|
return p.sendNextChunk()
|
||||||
}
|
}
|
||||||
if st.Conn.writer.Len() == 0 && st.HandshakeDone {
|
if p.st.Conn.writer.Len() == 0 && p.st.HandshakeDone {
|
||||||
defer st.ContextCancel()
|
defer p.st.ContextCancel()
|
||||||
return protocol.EmptyPayload{
|
ctx.EndInnerProtocol(func(r *radius.Packet) *radius.Packet {
|
||||||
ModifyPacket: func(p *radius.Packet) *radius.Packet {
|
r.Code = radius.CodeAccessAccept
|
||||||
p.Code = radius.CodeAccessAccept
|
microsoft.MSMPPERecvKey_Set(r, p.st.MPPEKey[:32])
|
||||||
microsoft.MSMPPERecvKey_Set(p, st.MPPEKey[:32])
|
microsoft.MSMPPESendKey_Set(r, p.st.MPPEKey[64:64+32])
|
||||||
microsoft.MSMPPESendKey_Set(p, st.MPPEKey[64:64+32])
|
return r
|
||||||
rfc2865.UserName_SetString(p, "foo")
|
})
|
||||||
rfc2865.FramedMTU_Set(p, rfc2865.FramedMTU(1400))
|
return nil
|
||||||
return p
|
|
||||||
},
|
|
||||||
}, st
|
|
||||||
}
|
}
|
||||||
return p.startChunkedTransfer(st.Conn.OutboundData(), st)
|
return p.startChunkedTransfer(p.st.Conn.OutboundData())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Payload) tlsInit(st *State) *State {
|
func (p *Payload) tlsInit(ctx tctx) {
|
||||||
log.Debug("TLS: no TLS connection in state yet, starting connection")
|
log.Debug("TLS: no TLS connection in state yet, starting connection")
|
||||||
st.Context, st.ContextCancel = context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second)
|
p.st.Context, p.st.ContextCancel = context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second)
|
||||||
st.Conn = NewBuffConn(p.Data, st.Context)
|
p.st.Conn = NewBuffConn(p.Data, p.st.Context)
|
||||||
st.TLS = tls.Server(st.Conn, &tls.Config{
|
cfg := ctx.ProtocolSettings().Config.Clone()
|
||||||
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
|
cfg.GetConfigForClient = func(chi *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
log.Debugf("TLS: ClientHello: %+v\n", ch)
|
log.Debugf("TLS: ClientHello: %+v\n", chi)
|
||||||
st.ClientHello = ch
|
p.st.ClientHello = chi
|
||||||
return nil, nil
|
return nil, nil
|
||||||
},
|
}
|
||||||
ClientAuth: tls.RequireAnyClientCert,
|
p.st.TLS = tls.Server(p.st.Conn, cfg)
|
||||||
Certificates: certs,
|
|
||||||
})
|
|
||||||
go func() {
|
go func() {
|
||||||
err := st.TLS.HandshakeContext(st.Context)
|
err := p.st.TLS.HandshakeContext(p.st.Context)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Debug("TLS: Handshake error")
|
log.WithError(err).Debug("TLS: Handshake error")
|
||||||
// TODO: Send a NAK to the client
|
// TODO: Send a NAK to the client
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Debug("TLS: handshake done")
|
log.Debug("TLS: handshake done")
|
||||||
p.tlsHandshakeFinished(st)
|
p.tlsHandshakeFinished()
|
||||||
}()
|
}()
|
||||||
return st
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Payload) tlsHandshakeFinished(st *State) {
|
func (p *Payload) tlsHandshakeFinished() {
|
||||||
cs := st.TLS.ConnectionState()
|
cs := p.st.TLS.ConnectionState()
|
||||||
label := "client EAP encryption"
|
label := "client EAP encryption"
|
||||||
var context []byte
|
var context []byte
|
||||||
switch cs.Version {
|
switch cs.Version {
|
||||||
@ -176,46 +156,46 @@ func (p *Payload) tlsHandshakeFinished(st *State) {
|
|||||||
}
|
}
|
||||||
ksm, err := cs.ExportKeyingMaterial(label, context, 64+64)
|
ksm, err := cs.ExportKeyingMaterial(label, context, 64+64)
|
||||||
log.Debugf("TLS: ksm % x %v", ksm, err)
|
log.Debugf("TLS: ksm % x %v", ksm, err)
|
||||||
st.MPPEKey = ksm
|
p.st.MPPEKey = ksm
|
||||||
st.HandshakeDone = true
|
p.st.HandshakeDone = true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Payload) startChunkedTransfer(data []byte, st *State) (*Payload, *State) {
|
func (p *Payload) startChunkedTransfer(data []byte) *Payload {
|
||||||
if len(data) > maxChunkSize {
|
if len(data) > maxChunkSize {
|
||||||
log.WithField("length", len(data)).Debug("TLS: Data needs to be chunked")
|
log.WithField("length", len(data)).Debug("TLS: Data needs to be chunked")
|
||||||
st.RemainingChunks = append(st.RemainingChunks, slices.Collect(slices.Chunk(data, maxChunkSize))...)
|
p.st.RemainingChunks = append(p.st.RemainingChunks, slices.Collect(slices.Chunk(data, maxChunkSize))...)
|
||||||
st.TotalPayloadSize = len(data)
|
p.st.TotalPayloadSize = len(data)
|
||||||
return p.sendNextChunk(st)
|
return p.sendNextChunk()
|
||||||
}
|
}
|
||||||
log.WithField("length", len(data)).Debug("TLS: Sending data un-chunked")
|
log.WithField("length", len(data)).Debug("TLS: Sending data un-chunked")
|
||||||
st.Conn.writer.Reset()
|
p.st.Conn.writer.Reset()
|
||||||
return &Payload{
|
return &Payload{
|
||||||
Flags: FlagLengthIncluded,
|
Flags: FlagLengthIncluded,
|
||||||
Length: uint32(len(data)),
|
Length: uint32(len(data)),
|
||||||
Data: data,
|
Data: data,
|
||||||
}, st
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Payload) sendNextChunk(st *State) (*Payload, *State) {
|
func (p *Payload) sendNextChunk() *Payload {
|
||||||
nextChunk := st.RemainingChunks[0]
|
nextChunk := p.st.RemainingChunks[0]
|
||||||
log.WithField("raw", debug.FormatBytes(nextChunk)).Debug("TLS: Sending next chunk")
|
log.WithField("raw", debug.FormatBytes(nextChunk)).Debug("TLS: Sending next chunk")
|
||||||
st.RemainingChunks = st.RemainingChunks[1:]
|
p.st.RemainingChunks = p.st.RemainingChunks[1:]
|
||||||
flags := FlagLengthIncluded
|
flags := FlagLengthIncluded
|
||||||
if st.HasMore() {
|
if p.st.HasMore() {
|
||||||
log.WithField("chunks", len(st.RemainingChunks)).Debug("TLS: More chunks left")
|
log.WithField("chunks", len(p.st.RemainingChunks)).Debug("TLS: More chunks left")
|
||||||
flags += FlagMoreFragments
|
flags += FlagMoreFragments
|
||||||
} else {
|
} else {
|
||||||
// Last chunk, reset the connection buffers and pending payload size
|
// Last chunk, reset the connection buffers and pending payload size
|
||||||
defer func() {
|
defer func() {
|
||||||
log.Debug("TLS: Sent last chunk")
|
log.Debug("TLS: Sent last chunk")
|
||||||
st.Conn.writer.Reset()
|
p.st.Conn.writer.Reset()
|
||||||
st.TotalPayloadSize = 0
|
p.st.TotalPayloadSize = 0
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
log.WithField("length", st.TotalPayloadSize).Debug("TLS: Total payload size")
|
log.WithField("length", p.st.TotalPayloadSize).Debug("TLS: Total payload size")
|
||||||
return &Payload{
|
return &Payload{
|
||||||
Flags: flags,
|
Flags: flags,
|
||||||
Length: uint32(st.TotalPayloadSize),
|
Length: uint32(p.st.TotalPayloadSize),
|
||||||
Data: nextChunk,
|
Data: nextChunk,
|
||||||
}, st
|
}
|
||||||
}
|
}
|
||||||
|
7
internal/outpost/radius/eap/tls/settings.go
Normal file
7
internal/outpost/radius/eap/tls/settings.go
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
package tls
|
||||||
|
|
||||||
|
import "crypto/tls"
|
||||||
|
|
||||||
|
type Settings struct {
|
||||||
|
Config *tls.Config
|
||||||
|
}
|
@ -18,7 +18,8 @@ type State struct {
|
|||||||
ContextCancel context.CancelFunc
|
ContextCancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewState() *State {
|
func NewState(c tctx) *State {
|
||||||
|
c.Log().Debug("TLS: new state")
|
||||||
return &State{
|
return &State{
|
||||||
RemainingChunks: make([][]byte, 0),
|
RemainingChunks: make([][]byte, 0),
|
||||||
}
|
}
|
||||||
|
@ -1,12 +1,14 @@
|
|||||||
package radius
|
package radius
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
ttls "crypto/tls"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"goauthentik.io/internal/outpost/flow"
|
"goauthentik.io/internal/outpost/flow"
|
||||||
"goauthentik.io/internal/outpost/radius/eap"
|
"goauthentik.io/internal/outpost/radius/eap"
|
||||||
|
"goauthentik.io/internal/outpost/radius/eap/tls"
|
||||||
"goauthentik.io/internal/outpost/radius/metrics"
|
"goauthentik.io/internal/outpost/radius/metrics"
|
||||||
"layeh.com/radius"
|
"layeh.com/radius"
|
||||||
"layeh.com/radius/rfc2865"
|
"layeh.com/radius/rfc2865"
|
||||||
@ -122,7 +124,24 @@ func (pi *ProviderInstance) SetEAPState(key string, state *eap.State) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (pi *ProviderInstance) GetEAPSettings() eap.Settings {
|
func (pi *ProviderInstance) GetEAPSettings() eap.Settings {
|
||||||
|
// Testing
|
||||||
|
cert, err := ttls.LoadX509KeyPair(
|
||||||
|
"../t/ca/out/cert_jens-mbp.lab.beryju.org.pem",
|
||||||
|
"../t/ca/out/cert_jens-mbp.lab.beryju.org.key",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
return eap.Settings{
|
return eap.Settings{
|
||||||
ChallengesToOffer: []eap.Type{eap.TypeTLS},
|
ProtocolsToOffer: []eap.Type{eap.TypeTLS},
|
||||||
|
ProtocolSettings: map[eap.Type]interface{}{
|
||||||
|
eap.TypeTLS: tls.Settings{
|
||||||
|
Config: &ttls.Config{
|
||||||
|
Certificates: []ttls.Certificate{cert},
|
||||||
|
ClientAuth: ttls.RequireAnyClientCert,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user