@ -6,6 +6,7 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
|
||||||
"github.com/gorilla/securecookie"
|
"github.com/gorilla/securecookie"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"goauthentik.io/internal/outpost/radius/eap/tls"
|
"goauthentik.io/internal/outpost/radius/eap/tls"
|
||||||
"layeh.com/radius"
|
"layeh.com/radius"
|
||||||
"layeh.com/radius/rfc2865"
|
"layeh.com/radius/rfc2865"
|
||||||
@ -28,6 +29,7 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac
|
|||||||
res, newState := p.GetChallengeForType(st, nextChallengeToOffer)
|
res, newState := p.GetChallengeForType(st, nextChallengeToOffer)
|
||||||
stm.SetEAPState(rst, newState)
|
stm.SetEAPState(rst, newState)
|
||||||
|
|
||||||
|
log.Debug("EAP: encapsulating challenge")
|
||||||
rres := r.Response(radius.CodeAccessChallenge)
|
rres := r.Response(radius.CodeAccessChallenge)
|
||||||
rfc2865.State_SetString(rres, rst)
|
rfc2865.State_SetString(rres, rst)
|
||||||
eapEncoded, err := res.Encode()
|
eapEncoded, err := res.Encode()
|
||||||
@ -36,7 +38,6 @@ func (p *Packet) Handle(stm StateManager, w radius.ResponseWriter, r *radius.Pac
|
|||||||
}
|
}
|
||||||
rfc2869.EAPMessage_Set(rres, eapEncoded)
|
rfc2869.EAPMessage_Set(rres, eapEncoded)
|
||||||
p.setMessageAuthenticator(rres)
|
p.setMessageAuthenticator(rres)
|
||||||
// debug.DebugPacket(rres)
|
|
||||||
err = w.Write(rres)
|
err = w.Write(rres)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"goauthentik.io/internal/outpost/radius/eap/debug"
|
"goauthentik.io/internal/outpost/radius/eap/debug"
|
||||||
"goauthentik.io/internal/outpost/radius/eap/tls"
|
"goauthentik.io/internal/outpost/radius/eap/tls"
|
||||||
)
|
)
|
||||||
@ -63,7 +63,7 @@ func Decode(raw []byte) (*Packet, error) {
|
|||||||
}
|
}
|
||||||
packet.Payload = emptyPayload(packet.msgType)
|
packet.Payload = emptyPayload(packet.msgType)
|
||||||
packet.rawPayload = raw[5:]
|
packet.rawPayload = raw[5:]
|
||||||
logrus.WithField("raw", debug.FormatBytes(raw[5:])).Debug("EAP decode raw")
|
log.WithField("raw", debug.FormatBytes(raw)).Debug("EAP: decode raw")
|
||||||
err := packet.Payload.Decode(raw[5:])
|
err := packet.Payload.Decode(raw[5:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@ -4,6 +4,8 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TLSConnection struct {
|
type TLSConnection struct {
|
||||||
@ -19,17 +21,38 @@ func NewTLSConnection(initialData []byte) TLSConnection {
|
|||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn TLSConnection) Read(p []byte) (int, error) { return conn.reader.Read(p) }
|
func (conn TLSConnection) TLSData() []byte {
|
||||||
|
return conn.writer.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn TLSConnection) UpdateData(data []byte) {
|
||||||
|
conn.reader.Reset()
|
||||||
|
conn.reader.Write(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----
|
||||||
|
|
||||||
|
func (conn TLSConnection) Read(p []byte) (int, error) {
|
||||||
|
log.Debugf("TLS(buffer): Read: %d from %d", len(p), conn.reader.Len())
|
||||||
|
for {
|
||||||
|
n, err := conn.reader.Read(p)
|
||||||
|
if n == 0 {
|
||||||
|
log.Debug("TLS(buffer): Attempted read from empty buffer, stalling...")
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (conn TLSConnection) Write(p []byte) (int, error) {
|
func (conn TLSConnection) Write(p []byte) (int, error) {
|
||||||
|
log.Debugf("TLS(buffer): Write: %d", len(p))
|
||||||
return conn.writer.Write(p)
|
return conn.writer.Write(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn TLSConnection) Close() error { return nil }
|
func (conn TLSConnection) Close() error { return nil }
|
||||||
func (conn TLSConnection) LocalAddr() net.Addr { return nil }
|
func (conn TLSConnection) LocalAddr() net.Addr { return nil }
|
||||||
func (conn TLSConnection) RemoteAddr() net.Addr { return nil }
|
func (conn TLSConnection) RemoteAddr() net.Addr { return nil }
|
||||||
func (conn TLSConnection) SetDeadline(t time.Time) error { return nil }
|
func (conn TLSConnection) SetDeadline(t time.Time) error { return nil }
|
||||||
func (conn TLSConnection) SetReadDeadline(t time.Time) error { return nil }
|
func (conn TLSConnection) SetReadDeadline(t time.Time) error { return nil }
|
||||||
func (conn TLSConnection) SetWriteDeadline(t time.Time) error { return nil }
|
func (conn TLSConnection) SetWriteDeadline(t time.Time) error { return nil }
|
||||||
|
|
||||||
func (conn TLSConnection) TLSData() []byte {
|
|
||||||
return conn.writer.Bytes()
|
|
||||||
}
|
|
||||||
|
|||||||
@ -1,10 +1,12 @@
|
|||||||
package tls
|
package tls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"slices"
|
"slices"
|
||||||
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"goauthentik.io/internal/outpost/radius/eap/debug"
|
"goauthentik.io/internal/outpost/radius/eap/debug"
|
||||||
@ -27,6 +29,7 @@ func (p *Payload) Decode(raw []byte) error {
|
|||||||
} else {
|
} else {
|
||||||
p.Data = raw[1:]
|
p.Data = raw[1:]
|
||||||
}
|
}
|
||||||
|
log.WithField("raw", debug.FormatBytes(p.Data)).WithField("flags", p.Flags).Debug("TLS: decode raw")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -68,69 +71,74 @@ func (p *Payload) Handle(stt any) (*Payload, State) {
|
|||||||
stt = NewState()
|
stt = NewState()
|
||||||
}
|
}
|
||||||
st := stt.(State)
|
st := stt.(State)
|
||||||
log.WithField("flags", p.Flags).Debug("Got TLS Packet")
|
|
||||||
if !st.HasStarted {
|
if !st.HasStarted {
|
||||||
st.HasStarted = true
|
st.HasStarted = true
|
||||||
return &Payload{
|
return &Payload{
|
||||||
Flags: FlagTLSStart,
|
Flags: FlagTLSStart,
|
||||||
}, st
|
}, st
|
||||||
}
|
}
|
||||||
if st.HasMore() {
|
|
||||||
return p.sendNextChunk(st)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.WithField("raw", debug.FormatBytes(p.Data)).Debug("TLS: Decode raw")
|
|
||||||
|
|
||||||
tc := NewTLSConnection(p.Data)
|
|
||||||
if st.TLS == nil {
|
if st.TLS == nil {
|
||||||
log.Debug("no TLS connection in state yet, starting connection")
|
log.Debug("TLS: no TLS connection in state yet, starting connection")
|
||||||
st.TLS = tls.Server(tc, &tls.Config{
|
st.Conn = NewTLSConnection(p.Data)
|
||||||
|
st.TLS = tls.Server(st.Conn, &tls.Config{
|
||||||
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
log.Debugf("%+v\n", argHello)
|
log.Debugf("TLS: ClientHello: %+v\n", argHello)
|
||||||
return nil, nil
|
return nil, nil
|
||||||
},
|
},
|
||||||
ClientAuth: tls.RequireAnyClientCert,
|
ClientAuth: tls.RequireAnyClientCert,
|
||||||
Certificates: certs,
|
Certificates: certs,
|
||||||
})
|
})
|
||||||
err := st.TLS.Handshake()
|
st.Context, _ = context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
log.WithError(err).Debug("TLS: Handshake error")
|
go func() {
|
||||||
|
err := st.TLS.HandshakeContext(st.Context)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Debug("TLS: Handshake error")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
} else if len(p.Data) > 0 {
|
||||||
|
log.Debug("TLS: Updating buffer with new TLS data from packet")
|
||||||
|
st.Conn.UpdateData(p.Data)
|
||||||
}
|
}
|
||||||
return p.sendDataChunked(tc.TLSData(), st)
|
if st.HasMore() {
|
||||||
|
return p.sendNextChunk(st)
|
||||||
|
}
|
||||||
|
return p.startChunkedTransfer(st.Conn.TLSData(), st)
|
||||||
}
|
}
|
||||||
|
|
||||||
const maxChunkSize = 1000
|
const maxChunkSize = 1000
|
||||||
|
|
||||||
func (p *Payload) sendDataChunked(data []byte, st State) (*Payload, State) {
|
func (p *Payload) startChunkedTransfer(data []byte, st State) (*Payload, State) {
|
||||||
flags := FlagLengthIncluded
|
flags := FlagLengthIncluded
|
||||||
var dataToSend []byte
|
var dataToSend []byte
|
||||||
if len(data) > maxChunkSize {
|
if len(data) > maxChunkSize {
|
||||||
log.WithField("length", len(data)).Debug("Data needs to be chunked")
|
log.WithField("length", len(data)).Debug("TLS: Data needs to be chunked")
|
||||||
flags += FlagMoreFragments
|
flags += FlagMoreFragments
|
||||||
dataToSend = data[:maxChunkSize]
|
dataToSend = data[:maxChunkSize]
|
||||||
remainingData := data[maxChunkSize:]
|
remainingData := data[maxChunkSize:]
|
||||||
// Chunk remaining data into correct chunks and add them to the list
|
// Chunk remaining data into correct chunks and add them to the list
|
||||||
st.RemainingChunks = append(st.RemainingChunks, slices.Collect(slices.Chunk(remainingData, maxChunkSize))...)
|
st.RemainingChunks = append(st.RemainingChunks, slices.Collect(slices.Chunk(remainingData, maxChunkSize))...)
|
||||||
st.TotalPayloadSize = len(st.RemainingChunks) * maxChunkSize
|
st.TotalPayloadSize = len(data)
|
||||||
} else {
|
} else {
|
||||||
dataToSend = data
|
dataToSend = data
|
||||||
}
|
}
|
||||||
return &Payload{
|
return &Payload{
|
||||||
Flags: flags,
|
Flags: flags,
|
||||||
Length: uint32(len(data) + 5),
|
Length: uint32(st.TotalPayloadSize),
|
||||||
Data: dataToSend,
|
Data: dataToSend,
|
||||||
}, st
|
}, st
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Payload) sendNextChunk(st State) (*Payload, State) {
|
func (p *Payload) sendNextChunk(st State) (*Payload, State) {
|
||||||
log.Debug("Sending next chunk")
|
log.Debug("TLS: Sending next chunk")
|
||||||
nextChunk := st.RemainingChunks[0]
|
nextChunk := st.RemainingChunks[0]
|
||||||
st.RemainingChunks = st.RemainingChunks[1:]
|
st.RemainingChunks = st.RemainingChunks[1:]
|
||||||
flags := FlagLengthIncluded
|
flags := FlagLengthIncluded
|
||||||
if st.HasMore() {
|
if st.HasMore() {
|
||||||
log.WithField("chunks", len(st.RemainingChunks)).Debug("More chunks left")
|
log.WithField("chunks", len(st.RemainingChunks)).Debug("TLS: More chunks left")
|
||||||
flags += FlagMoreFragments
|
flags += FlagMoreFragments
|
||||||
}
|
}
|
||||||
log.WithField("length", st.TotalPayloadSize).Debug("Total payload size")
|
log.WithField("length", st.TotalPayloadSize).Debug("TLS: Total payload size")
|
||||||
return &Payload{
|
return &Payload{
|
||||||
Flags: flags,
|
Flags: flags,
|
||||||
Length: uint32(st.TotalPayloadSize),
|
Length: uint32(st.TotalPayloadSize),
|
||||||
|
|||||||
@ -1,12 +1,17 @@
|
|||||||
package tls
|
package tls
|
||||||
|
|
||||||
import "crypto/tls"
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
)
|
||||||
|
|
||||||
type State struct {
|
type State struct {
|
||||||
HasStarted bool
|
HasStarted bool
|
||||||
RemainingChunks [][]byte
|
RemainingChunks [][]byte
|
||||||
TotalPayloadSize int
|
TotalPayloadSize int
|
||||||
TLS *tls.Conn
|
TLS *tls.Conn
|
||||||
|
Conn TLSConnection
|
||||||
|
Context context.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewState() State {
|
func NewState() State {
|
||||||
|
|||||||
Reference in New Issue
Block a user