we're getting somewhere
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
@ -12,7 +12,8 @@ type TLSConnection struct {
|
|||||||
reader *bytes.Buffer
|
reader *bytes.Buffer
|
||||||
writer *bytes.Buffer
|
writer *bytes.Buffer
|
||||||
|
|
||||||
bufferIncomingBytesCount uint32
|
expectedWriterByteCount int
|
||||||
|
writtenByteCount int
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTLSConnection(initialData []byte) *TLSConnection {
|
func NewTLSConnection(initialData []byte) *TLSConnection {
|
||||||
@ -35,37 +36,19 @@ func (conn TLSConnection) OutboundData() []byte {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn TLSConnection) UpdateData(data []byte) {
|
func (conn *TLSConnection) UpdateData(data []byte) {
|
||||||
conn.reader.Write(data)
|
conn.reader.Write(data)
|
||||||
if conn.bufferIncomingBytesCount > 0 && conn.reader.Len() == int(conn.bufferIncomingBytesCount) {
|
conn.writtenByteCount += len(data)
|
||||||
conn.bufferIncomingBytesCount = 0
|
log.Debugf("TLS(buffer): Appending new data %d (total %d, expecting %d)", len(data), conn.writtenByteCount, conn.expectedWriterByteCount)
|
||||||
}
|
|
||||||
log.Debugf("TLS(buffer): Appending new data %d (total %d, expecting %d)", len(data), conn.reader.Len(), conn.bufferIncomingBytesCount)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// func (conn TLSConnection) Reset() {
|
|
||||||
// log.Debug("TLS(buffer): reset")
|
|
||||||
// conn.reader.Reset()
|
|
||||||
// conn.writer.Reset()
|
|
||||||
// }
|
|
||||||
|
|
||||||
func (conn TLSConnection) NeedsMoreData() bool {
|
func (conn TLSConnection) NeedsMoreData() bool {
|
||||||
if conn.bufferIncomingBytesCount > 0 {
|
if conn.expectedWriterByteCount > 0 {
|
||||||
return conn.reader.Len() < int(conn.bufferIncomingBytesCount)
|
return conn.reader.Len() < int(conn.expectedWriterByteCount)
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// func (conn TLSConnection) WaitForAttemptedRead() int {
|
|
||||||
// for {
|
|
||||||
// // log.Debug("TLS(buffer): waiting for attempted read")
|
|
||||||
// if conn.missingBytes == 0 {
|
|
||||||
// continue
|
|
||||||
// }
|
|
||||||
// return conn.missingBytes
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
func (conn *TLSConnection) Read(p []byte) (int, error) {
|
func (conn *TLSConnection) Read(p []byte) (int, error) {
|
||||||
for {
|
for {
|
||||||
n, err := conn.reader.Read(p)
|
n, err := conn.reader.Read(p)
|
||||||
@ -74,11 +57,17 @@ func (conn *TLSConnection) Read(p []byte) (int, error) {
|
|||||||
time.Sleep(500 * time.Millisecond)
|
time.Sleep(500 * time.Millisecond)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if conn.reader.Len() < int(conn.bufferIncomingBytesCount) {
|
if conn.expectedWriterByteCount > 0 && conn.writtenByteCount < int(conn.expectedWriterByteCount) {
|
||||||
log.Debugf("TLS(buffer): Attempted read %d while waiting for bytes %d, stalling...", len(p), conn.reader.Len()-int(conn.bufferIncomingBytesCount))
|
log.Debugf("TLS(buffer): Attempted read %d while waiting for bytes %d, stalling...", len(p), conn.expectedWriterByteCount-conn.reader.Len())
|
||||||
time.Sleep(500 * time.Millisecond)
|
time.Sleep(500 * time.Millisecond)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if conn.expectedWriterByteCount > 0 && conn.writtenByteCount == int(conn.expectedWriterByteCount) {
|
||||||
|
conn.expectedWriterByteCount = 0
|
||||||
|
}
|
||||||
|
if conn.reader.Len() == 0 {
|
||||||
|
conn.writtenByteCount = 0
|
||||||
|
}
|
||||||
log.Debugf("TLS(buffer): Read: %d from %d", len(p), n)
|
log.Debugf("TLS(buffer): Read: %d from %d", len(p), n)
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|||||||
@ -7,5 +7,4 @@ const (
|
|||||||
FlagMoreFragments Flag = 1 << 6
|
FlagMoreFragments Flag = 1 << 6
|
||||||
FlagTLSStart Flag = 1 << 5
|
FlagTLSStart Flag = 1 << 5
|
||||||
FlagNone Flag = 0
|
FlagNone Flag = 0
|
||||||
FlagLengthMore Flag = 0xc0
|
|
||||||
)
|
)
|
||||||
|
|||||||
@ -105,16 +105,18 @@ func (p *Payload) Handle(stt any) (*Payload, *State) {
|
|||||||
}()
|
}()
|
||||||
} else if len(p.Data) > 0 {
|
} else if len(p.Data) > 0 {
|
||||||
log.Debug("TLS: Updating buffer with new TLS data from packet")
|
log.Debug("TLS: Updating buffer with new TLS data from packet")
|
||||||
if p.Flags&FlagLengthMore != 0 && st.Conn.bufferIncomingBytesCount == 0 {
|
if p.Flags&FlagLengthIncluded != 0 && st.Conn.expectedWriterByteCount == 0 {
|
||||||
log.Debugf("TLS: Expecting %d total bytes, will buffer", p.Length)
|
log.Debugf("TLS: Expecting %d total bytes, will buffer", p.Length)
|
||||||
st.Conn.bufferIncomingBytesCount = p.Length
|
st.Conn.expectedWriterByteCount = int(p.Length)
|
||||||
|
} else if p.Flags&FlagLengthIncluded != 0 {
|
||||||
|
log.Debug("TLS: No length included, not buffering")
|
||||||
|
st.Conn.expectedWriterByteCount = 0
|
||||||
}
|
}
|
||||||
st.Conn.UpdateData(p.Data)
|
st.Conn.UpdateData(p.Data)
|
||||||
return &Payload{
|
if !st.Conn.NeedsMoreData() {
|
||||||
Flags: FlagNone,
|
// Wait for outbound data to be available
|
||||||
Length: 0,
|
st.Conn.OutboundData()
|
||||||
Data: []byte{},
|
}
|
||||||
}, st
|
|
||||||
}
|
}
|
||||||
// If we need more data, send the client the go-ahead
|
// If we need more data, send the client the go-ahead
|
||||||
if st.Conn.NeedsMoreData() {
|
if st.Conn.NeedsMoreData() {
|
||||||
@ -166,7 +168,7 @@ func (p *Payload) sendNextChunk(st *State) (*Payload, *State) {
|
|||||||
// Last chunk, reset the connection buffers and pending payload size
|
// Last chunk, reset the connection buffers and pending payload size
|
||||||
defer func() {
|
defer func() {
|
||||||
log.Debug("TLS: Sent last chunk")
|
log.Debug("TLS: Sent last chunk")
|
||||||
st.Conn.reader.Reset()
|
st.Conn.writer.Reset()
|
||||||
st.TotalPayloadSize = 0
|
st.TotalPayloadSize = 0
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user