Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer
2025-05-15 13:00:45 +02:00
parent fc5c0e2789
commit 9cee59537c
2 changed files with 27 additions and 20 deletions

View File

@ -2,6 +2,7 @@ package tls
import (
"bytes"
"context"
"net"
"time"
@ -12,20 +13,24 @@ type TLSConnection struct {
reader *bytes.Buffer
writer *bytes.Buffer
ctx context.Context
expectedWriterByteCount int
writtenByteCount int
}
func NewTLSConnection(initialData []byte) *TLSConnection {
func NewTLSConnection(initialData []byte, ctx context.Context) *TLSConnection {
c := &TLSConnection{
reader: bytes.NewBuffer(initialData),
writer: bytes.NewBuffer([]byte{}),
ctx: ctx,
}
return c
}
func (conn TLSConnection) OutboundData() []byte {
for {
// TODO cancel with conn.ctx
b := conn.writer.Bytes()
if len(b) < 1 {
log.Debug("TLS(buffer): Attempted retrieve from empty buffer, stalling...")
@ -51,6 +56,7 @@ func (conn TLSConnection) NeedsMoreData() bool {
func (conn *TLSConnection) Read(p []byte) (int, error) {
for {
// TODO cancel with conn.ctx
n, err := conn.reader.Read(p)
if n == 0 {
log.Debugf("TLS(buffer): Attempted read %d from empty buffer, stalling...", len(p))

View File

@ -16,6 +16,21 @@ import (
)
const maxChunkSize = 1000
const staleConnectionTimeout = 10
var certs = []tls.Certificate{}
func init() {
// Testing
cert, err := tls.LoadX509KeyPair(
"../t/ca/out/cert_jens-mbp.lab.beryju.org.pem",
"../t/ca/out/cert_jens-mbp.lab.beryju.org.key",
)
if err != nil {
panic(err)
}
certs = append(certs, cert)
}
type Payload struct {
Flags Flag
@ -58,20 +73,6 @@ func (p *Payload) Encode() ([]byte, error) {
return buff, nil
}
var certs = []tls.Certificate{}
func init() {
// Testing
cert, err := tls.LoadX509KeyPair(
"../t/ca/out/cert_jens-mbp.lab.beryju.org.pem",
"../t/ca/out/cert_jens-mbp.lab.beryju.org.key",
)
if err != nil {
panic(err)
}
certs = append(certs, cert)
}
func (p *Payload) Handle(stt any) (protocol.Payload, *State) {
if stt == nil {
log.Debug("TLS: new state")
@ -88,7 +89,9 @@ func (p *Payload) Handle(stt any) (protocol.Payload, *State) {
if st.TLS == nil {
log.Debug("TLS: no TLS connection in state yet, starting connection")
st.Conn = NewTLSConnection(p.Data)
ctx, cancel := context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second)
st.Context = ctx
st.Conn = NewTLSConnection(p.Data, st.Context)
st.TLS = tls.Server(st.Conn, &tls.Config{
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
log.Debugf("TLS: ClientHello: %+v\n", ch)
@ -98,8 +101,6 @@ func (p *Payload) Handle(stt any) (protocol.Payload, *State) {
ClientAuth: tls.RequireAnyClientCert,
Certificates: certs,
})
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
st.Context = ctx
go func() {
defer cancel()
err := st.TLS.HandshakeContext(st.Context)
@ -171,11 +172,11 @@ func (p *Payload) handshakeFinished(st *State) {
}
func (p *Payload) startChunkedTransfer(data []byte, st *State) (*Payload, *State) {
flags := FlagNone
flags := FlagLengthIncluded
var dataToSend []byte
if len(data) > maxChunkSize {
log.WithField("length", len(data)).Debug("TLS: Data needs to be chunked")
flags += FlagMoreFragments + FlagLengthIncluded
flags += FlagMoreFragments
// Chunk data into correct chunks and add them to the list
st.RemainingChunks = append(st.RemainingChunks, slices.Collect(slices.Chunk(data, maxChunkSize))...)
dataToSend = st.RemainingChunks[0]