we're getting somewhere

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer
2025-05-15 01:44:05 +02:00
parent df21e678d6
commit ae59a3e576
3 changed files with 25 additions and 35 deletions

View File

@ -12,7 +12,8 @@ type TLSConnection struct {
reader *bytes.Buffer
writer *bytes.Buffer
bufferIncomingBytesCount uint32
expectedWriterByteCount int
writtenByteCount int
}
func NewTLSConnection(initialData []byte) *TLSConnection {
@ -35,37 +36,19 @@ func (conn TLSConnection) OutboundData() []byte {
}
}
func (conn TLSConnection) UpdateData(data []byte) {
func (conn *TLSConnection) UpdateData(data []byte) {
conn.reader.Write(data)
if conn.bufferIncomingBytesCount > 0 && conn.reader.Len() == int(conn.bufferIncomingBytesCount) {
conn.bufferIncomingBytesCount = 0
}
log.Debugf("TLS(buffer): Appending new data %d (total %d, expecting %d)", len(data), conn.reader.Len(), conn.bufferIncomingBytesCount)
conn.writtenByteCount += len(data)
log.Debugf("TLS(buffer): Appending new data %d (total %d, expecting %d)", len(data), conn.writtenByteCount, conn.expectedWriterByteCount)
}
// func (conn TLSConnection) Reset() {
// log.Debug("TLS(buffer): reset")
// conn.reader.Reset()
// conn.writer.Reset()
// }
func (conn TLSConnection) NeedsMoreData() bool {
if conn.bufferIncomingBytesCount > 0 {
return conn.reader.Len() < int(conn.bufferIncomingBytesCount)
if conn.expectedWriterByteCount > 0 {
return conn.reader.Len() < int(conn.expectedWriterByteCount)
}
return false
}
// func (conn TLSConnection) WaitForAttemptedRead() int {
// for {
// // log.Debug("TLS(buffer): waiting for attempted read")
// if conn.missingBytes == 0 {
// continue
// }
// return conn.missingBytes
// }
// }
func (conn *TLSConnection) Read(p []byte) (int, error) {
for {
n, err := conn.reader.Read(p)
@ -74,11 +57,17 @@ func (conn *TLSConnection) Read(p []byte) (int, error) {
time.Sleep(500 * time.Millisecond)
continue
}
if conn.reader.Len() < int(conn.bufferIncomingBytesCount) {
log.Debugf("TLS(buffer): Attempted read %d while waiting for bytes %d, stalling...", len(p), conn.reader.Len()-int(conn.bufferIncomingBytesCount))
if conn.expectedWriterByteCount > 0 && conn.writtenByteCount < int(conn.expectedWriterByteCount) {
log.Debugf("TLS(buffer): Attempted read %d while waiting for bytes %d, stalling...", len(p), conn.expectedWriterByteCount-conn.reader.Len())
time.Sleep(500 * time.Millisecond)
continue
}
if conn.expectedWriterByteCount > 0 && conn.writtenByteCount == int(conn.expectedWriterByteCount) {
conn.expectedWriterByteCount = 0
}
if conn.reader.Len() == 0 {
conn.writtenByteCount = 0
}
log.Debugf("TLS(buffer): Read: %d from %d", len(p), n)
return n, err
}

View File

@ -7,5 +7,4 @@ const (
FlagMoreFragments Flag = 1 << 6
FlagTLSStart Flag = 1 << 5
FlagNone Flag = 0
FlagLengthMore Flag = 0xc0
)

View File

@ -105,16 +105,18 @@ func (p *Payload) Handle(stt any) (*Payload, *State) {
}()
} else if len(p.Data) > 0 {
log.Debug("TLS: Updating buffer with new TLS data from packet")
if p.Flags&FlagLengthMore != 0 && st.Conn.bufferIncomingBytesCount == 0 {
if p.Flags&FlagLengthIncluded != 0 && st.Conn.expectedWriterByteCount == 0 {
log.Debugf("TLS: Expecting %d total bytes, will buffer", p.Length)
st.Conn.bufferIncomingBytesCount = p.Length
st.Conn.expectedWriterByteCount = int(p.Length)
} else if p.Flags&FlagLengthIncluded != 0 {
log.Debug("TLS: No length included, not buffering")
st.Conn.expectedWriterByteCount = 0
}
st.Conn.UpdateData(p.Data)
return &Payload{
Flags: FlagNone,
Length: 0,
Data: []byte{},
}, st
if !st.Conn.NeedsMoreData() {
// Wait for outbound data to be available
st.Conn.OutboundData()
}
}
// If we need more data, send the client the go-ahead
if st.Conn.NeedsMoreData() {
@ -166,7 +168,7 @@ func (p *Payload) sendNextChunk(st *State) (*Payload, *State) {
// Last chunk, reset the connection buffers and pending payload size
defer func() {
log.Debug("TLS: Sent last chunk")
st.Conn.reader.Reset()
st.Conn.writer.Reset()
st.TotalPayloadSize = 0
}()
}