Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer
2025-05-15 13:00:45 +02:00
parent fc5c0e2789
commit 9cee59537c
2 changed files with 27 additions and 20 deletions

View File

@ -2,6 +2,7 @@ package tls
import ( import (
"bytes" "bytes"
"context"
"net" "net"
"time" "time"
@ -12,20 +13,24 @@ type TLSConnection struct {
reader *bytes.Buffer reader *bytes.Buffer
writer *bytes.Buffer writer *bytes.Buffer
ctx context.Context
expectedWriterByteCount int expectedWriterByteCount int
writtenByteCount int writtenByteCount int
} }
func NewTLSConnection(initialData []byte) *TLSConnection { func NewTLSConnection(initialData []byte, ctx context.Context) *TLSConnection {
c := &TLSConnection{ c := &TLSConnection{
reader: bytes.NewBuffer(initialData), reader: bytes.NewBuffer(initialData),
writer: bytes.NewBuffer([]byte{}), writer: bytes.NewBuffer([]byte{}),
ctx: ctx,
} }
return c return c
} }
func (conn TLSConnection) OutboundData() []byte { func (conn TLSConnection) OutboundData() []byte {
for { for {
// TODO cancel with conn.ctx
b := conn.writer.Bytes() b := conn.writer.Bytes()
if len(b) < 1 { if len(b) < 1 {
log.Debug("TLS(buffer): Attempted retrieve from empty buffer, stalling...") log.Debug("TLS(buffer): Attempted retrieve from empty buffer, stalling...")
@ -51,6 +56,7 @@ func (conn TLSConnection) NeedsMoreData() bool {
func (conn *TLSConnection) Read(p []byte) (int, error) { func (conn *TLSConnection) Read(p []byte) (int, error) {
for { for {
// TODO cancel with conn.ctx
n, err := conn.reader.Read(p) n, err := conn.reader.Read(p)
if n == 0 { if n == 0 {
log.Debugf("TLS(buffer): Attempted read %d from empty buffer, stalling...", len(p)) log.Debugf("TLS(buffer): Attempted read %d from empty buffer, stalling...", len(p))

View File

@ -16,6 +16,21 @@ import (
) )
const maxChunkSize = 1000 const maxChunkSize = 1000
const staleConnectionTimeout = 10
var certs = []tls.Certificate{}
func init() {
// Testing
cert, err := tls.LoadX509KeyPair(
"../t/ca/out/cert_jens-mbp.lab.beryju.org.pem",
"../t/ca/out/cert_jens-mbp.lab.beryju.org.key",
)
if err != nil {
panic(err)
}
certs = append(certs, cert)
}
type Payload struct { type Payload struct {
Flags Flag Flags Flag
@ -58,20 +73,6 @@ func (p *Payload) Encode() ([]byte, error) {
return buff, nil return buff, nil
} }
var certs = []tls.Certificate{}
func init() {
// Testing
cert, err := tls.LoadX509KeyPair(
"../t/ca/out/cert_jens-mbp.lab.beryju.org.pem",
"../t/ca/out/cert_jens-mbp.lab.beryju.org.key",
)
if err != nil {
panic(err)
}
certs = append(certs, cert)
}
func (p *Payload) Handle(stt any) (protocol.Payload, *State) { func (p *Payload) Handle(stt any) (protocol.Payload, *State) {
if stt == nil { if stt == nil {
log.Debug("TLS: new state") log.Debug("TLS: new state")
@ -88,7 +89,9 @@ func (p *Payload) Handle(stt any) (protocol.Payload, *State) {
if st.TLS == nil { if st.TLS == nil {
log.Debug("TLS: no TLS connection in state yet, starting connection") log.Debug("TLS: no TLS connection in state yet, starting connection")
st.Conn = NewTLSConnection(p.Data) ctx, cancel := context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second)
st.Context = ctx
st.Conn = NewTLSConnection(p.Data, st.Context)
st.TLS = tls.Server(st.Conn, &tls.Config{ st.TLS = tls.Server(st.Conn, &tls.Config{
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
log.Debugf("TLS: ClientHello: %+v\n", ch) log.Debugf("TLS: ClientHello: %+v\n", ch)
@ -98,8 +101,6 @@ func (p *Payload) Handle(stt any) (protocol.Payload, *State) {
ClientAuth: tls.RequireAnyClientCert, ClientAuth: tls.RequireAnyClientCert,
Certificates: certs, Certificates: certs,
}) })
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
st.Context = ctx
go func() { go func() {
defer cancel() defer cancel()
err := st.TLS.HandshakeContext(st.Context) err := st.TLS.HandshakeContext(st.Context)
@ -171,11 +172,11 @@ func (p *Payload) handshakeFinished(st *State) {
} }
func (p *Payload) startChunkedTransfer(data []byte, st *State) (*Payload, *State) { func (p *Payload) startChunkedTransfer(data []byte, st *State) (*Payload, *State) {
flags := FlagNone flags := FlagLengthIncluded
var dataToSend []byte var dataToSend []byte
if len(data) > maxChunkSize { if len(data) > maxChunkSize {
log.WithField("length", len(data)).Debug("TLS: Data needs to be chunked") log.WithField("length", len(data)).Debug("TLS: Data needs to be chunked")
flags += FlagMoreFragments + FlagLengthIncluded flags += FlagMoreFragments
// Chunk data into correct chunks and add them to the list // Chunk data into correct chunks and add them to the list
st.RemainingChunks = append(st.RemainingChunks, slices.Collect(slices.Chunk(data, maxChunkSize))...) st.RemainingChunks = append(st.RemainingChunks, slices.Collect(slices.Chunk(data, maxChunkSize))...)
dataToSend = st.RemainingChunks[0] dataToSend = st.RemainingChunks[0]