we're getting somewhere
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
@ -12,7 +12,8 @@ type TLSConnection struct {
|
||||
reader *bytes.Buffer
|
||||
writer *bytes.Buffer
|
||||
|
||||
bufferIncomingBytesCount uint32
|
||||
expectedWriterByteCount int
|
||||
writtenByteCount int
|
||||
}
|
||||
|
||||
func NewTLSConnection(initialData []byte) *TLSConnection {
|
||||
@ -35,37 +36,19 @@ func (conn TLSConnection) OutboundData() []byte {
|
||||
}
|
||||
}
|
||||
|
||||
func (conn TLSConnection) UpdateData(data []byte) {
|
||||
func (conn *TLSConnection) UpdateData(data []byte) {
|
||||
conn.reader.Write(data)
|
||||
if conn.bufferIncomingBytesCount > 0 && conn.reader.Len() == int(conn.bufferIncomingBytesCount) {
|
||||
conn.bufferIncomingBytesCount = 0
|
||||
}
|
||||
log.Debugf("TLS(buffer): Appending new data %d (total %d, expecting %d)", len(data), conn.reader.Len(), conn.bufferIncomingBytesCount)
|
||||
conn.writtenByteCount += len(data)
|
||||
log.Debugf("TLS(buffer): Appending new data %d (total %d, expecting %d)", len(data), conn.writtenByteCount, conn.expectedWriterByteCount)
|
||||
}
|
||||
|
||||
// func (conn TLSConnection) Reset() {
|
||||
// log.Debug("TLS(buffer): reset")
|
||||
// conn.reader.Reset()
|
||||
// conn.writer.Reset()
|
||||
// }
|
||||
|
||||
func (conn TLSConnection) NeedsMoreData() bool {
|
||||
if conn.bufferIncomingBytesCount > 0 {
|
||||
return conn.reader.Len() < int(conn.bufferIncomingBytesCount)
|
||||
if conn.expectedWriterByteCount > 0 {
|
||||
return conn.reader.Len() < int(conn.expectedWriterByteCount)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// func (conn TLSConnection) WaitForAttemptedRead() int {
|
||||
// for {
|
||||
// // log.Debug("TLS(buffer): waiting for attempted read")
|
||||
// if conn.missingBytes == 0 {
|
||||
// continue
|
||||
// }
|
||||
// return conn.missingBytes
|
||||
// }
|
||||
// }
|
||||
|
||||
func (conn *TLSConnection) Read(p []byte) (int, error) {
|
||||
for {
|
||||
n, err := conn.reader.Read(p)
|
||||
@ -74,11 +57,17 @@ func (conn *TLSConnection) Read(p []byte) (int, error) {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
if conn.reader.Len() < int(conn.bufferIncomingBytesCount) {
|
||||
log.Debugf("TLS(buffer): Attempted read %d while waiting for bytes %d, stalling...", len(p), conn.reader.Len()-int(conn.bufferIncomingBytesCount))
|
||||
if conn.expectedWriterByteCount > 0 && conn.writtenByteCount < int(conn.expectedWriterByteCount) {
|
||||
log.Debugf("TLS(buffer): Attempted read %d while waiting for bytes %d, stalling...", len(p), conn.expectedWriterByteCount-conn.reader.Len())
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
if conn.expectedWriterByteCount > 0 && conn.writtenByteCount == int(conn.expectedWriterByteCount) {
|
||||
conn.expectedWriterByteCount = 0
|
||||
}
|
||||
if conn.reader.Len() == 0 {
|
||||
conn.writtenByteCount = 0
|
||||
}
|
||||
log.Debugf("TLS(buffer): Read: %d from %d", len(p), n)
|
||||
return n, err
|
||||
}
|
||||
|
@ -7,5 +7,4 @@ const (
|
||||
FlagMoreFragments Flag = 1 << 6
|
||||
FlagTLSStart Flag = 1 << 5
|
||||
FlagNone Flag = 0
|
||||
FlagLengthMore Flag = 0xc0
|
||||
)
|
||||
|
@ -105,16 +105,18 @@ func (p *Payload) Handle(stt any) (*Payload, *State) {
|
||||
}()
|
||||
} else if len(p.Data) > 0 {
|
||||
log.Debug("TLS: Updating buffer with new TLS data from packet")
|
||||
if p.Flags&FlagLengthMore != 0 && st.Conn.bufferIncomingBytesCount == 0 {
|
||||
if p.Flags&FlagLengthIncluded != 0 && st.Conn.expectedWriterByteCount == 0 {
|
||||
log.Debugf("TLS: Expecting %d total bytes, will buffer", p.Length)
|
||||
st.Conn.bufferIncomingBytesCount = p.Length
|
||||
st.Conn.expectedWriterByteCount = int(p.Length)
|
||||
} else if p.Flags&FlagLengthIncluded != 0 {
|
||||
log.Debug("TLS: No length included, not buffering")
|
||||
st.Conn.expectedWriterByteCount = 0
|
||||
}
|
||||
st.Conn.UpdateData(p.Data)
|
||||
return &Payload{
|
||||
Flags: FlagNone,
|
||||
Length: 0,
|
||||
Data: []byte{},
|
||||
}, st
|
||||
if !st.Conn.NeedsMoreData() {
|
||||
// Wait for outbound data to be available
|
||||
st.Conn.OutboundData()
|
||||
}
|
||||
}
|
||||
// If we need more data, send the client the go-ahead
|
||||
if st.Conn.NeedsMoreData() {
|
||||
@ -166,7 +168,7 @@ func (p *Payload) sendNextChunk(st *State) (*Payload, *State) {
|
||||
// Last chunk, reset the connection buffers and pending payload size
|
||||
defer func() {
|
||||
log.Debug("TLS: Sent last chunk")
|
||||
st.Conn.reader.Reset()
|
||||
st.Conn.writer.Reset()
|
||||
st.TotalPayloadSize = 0
|
||||
}()
|
||||
}
|
||||
|
Reference in New Issue
Block a user