diff --git a/internal/outpost/radius/eap/tls/buff_conn.go b/internal/outpost/radius/eap/tls/buff_conn.go index 7455763afa..05e303625f 100644 --- a/internal/outpost/radius/eap/tls/buff_conn.go +++ b/internal/outpost/radius/eap/tls/buff_conn.go @@ -3,9 +3,11 @@ package tls import ( "bytes" "context" + "errors" "net" "time" + "github.com/avast/retry-go/v4" log "github.com/sirupsen/logrus" ) @@ -17,6 +19,8 @@ type BuffConn struct { expectedWriterByteCount int writtenByteCount int + + retryOptions []retry.Option } func NewBuffConn(initialData []byte, ctx context.Context) *BuffConn { @@ -24,21 +28,34 @@ func NewBuffConn(initialData []byte, ctx context.Context) *BuffConn { reader: bytes.NewBuffer(initialData), writer: bytes.NewBuffer([]byte{}), ctx: ctx, + retryOptions: []retry.Option{ + retry.Context(ctx), + retry.Delay(10 * time.Microsecond), + retry.DelayType(retry.BackOffDelay), + retry.MaxDelay(100 * time.Millisecond), + retry.Attempts(0), + }, } return c } +var errStall = errors.New("Stall") + func (conn BuffConn) OutboundData() []byte { - for { - // TODO cancel with conn.ctx - b := conn.writer.Bytes() - if len(b) < 1 { - log.Debug("TLS(buffcon): Attempted retrieve from empty buffer, stalling...") - time.Sleep(1 * time.Second) - continue - } - return b + d, err := retry.DoWithData( + func() ([]byte, error) { + b := conn.writer.Bytes() + if len(b) < 1 { + return nil, errStall + } + return b, nil + }, + conn.retryOptions..., + ) + if err != nil { + return []byte{} } + return d } func (conn *BuffConn) UpdateData(data []byte) { @@ -55,28 +72,29 @@ func (conn BuffConn) NeedsMoreData() bool { } func (conn *BuffConn) Read(p []byte) (int, error) { - for { - // TODO cancel with conn.ctx - n, err := conn.reader.Read(p) - if n == 0 { - log.Debugf("TLS(buffcon): Attempted read %d from empty buffer, stalling...", len(p)) - time.Sleep(100 * time.Millisecond) - continue - } - if conn.expectedWriterByteCount > 0 && conn.writtenByteCount < int(conn.expectedWriterByteCount) { - log.Debugf("TLS(buffcon): Attempted read %d while waiting for bytes %d, stalling...", len(p), conn.expectedWriterByteCount-conn.reader.Len()) - time.Sleep(100 * time.Millisecond) - continue - } - if conn.expectedWriterByteCount > 0 && conn.writtenByteCount == int(conn.expectedWriterByteCount) { - conn.expectedWriterByteCount = 0 - } - if conn.reader.Len() == 0 { - conn.writtenByteCount = 0 - } - log.Debugf("TLS(buffcon): Read: %d from %d", len(p), n) - return n, err - } + d, err := retry.DoWithData( + func() (int, error) { + n, err := conn.reader.Read(p) + if n == 0 { + log.Debugf("TLS(buffcon): Attempted read %d from empty buffer, stalling...", len(p)) + return 0, errStall + } + if conn.expectedWriterByteCount > 0 && conn.writtenByteCount < int(conn.expectedWriterByteCount) { + log.Debugf("TLS(buffcon): Attempted read %d while waiting for bytes %d, stalling...", len(p), conn.expectedWriterByteCount-conn.reader.Len()) + return 0, errStall + } + if conn.expectedWriterByteCount > 0 && conn.writtenByteCount == int(conn.expectedWriterByteCount) { + conn.expectedWriterByteCount = 0 + } + if conn.reader.Len() == 0 { + conn.writtenByteCount = 0 + } + log.Debugf("TLS(buffcon): Read: %d from %d", len(p), n) + return n, err + }, + conn.retryOptions..., + ) + return d, err } func (conn BuffConn) Write(p []byte) (int, error) { diff --git a/internal/outpost/radius/eap/tls/payload.go b/internal/outpost/radius/eap/tls/payload.go index 589d10a7f1..0a199bac4d 100644 --- a/internal/outpost/radius/eap/tls/payload.go +++ b/internal/outpost/radius/eap/tls/payload.go @@ -90,8 +90,7 @@ func (p *Payload) Handle(stt any) (protocol.Payload, *State) { if st.TLS == nil { log.Debug("TLS: no TLS connection in state yet, starting connection") - ctx, cancel := context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second) - st.Context = ctx + st.Context, st.ContextCancel = context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second) st.Conn = NewBuffConn(p.Data, st.Context) st.TLS = tls.Server(st.Conn, &tls.Config{ GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { @@ -109,7 +108,6 @@ func (p *Payload) Handle(stt any) (protocol.Payload, *State) { }, }) go func() { - defer cancel() err := st.TLS.HandshakeContext(st.Context) if err != nil { log.WithError(err).Debug("TLS: Handshake error") @@ -145,6 +143,7 @@ func (p *Payload) Handle(stt any) (protocol.Payload, *State) { return p.sendNextChunk(st) } if st.Conn.writer.Len() == 0 && st.HandshakeDone { + defer st.ContextCancel() return protocol.EmptyPayload{ ModifyPacket: func(p *radius.Packet) *radius.Packet { p.Code = radius.CodeAccessAccept diff --git a/internal/outpost/radius/eap/tls/state.go b/internal/outpost/radius/eap/tls/state.go index cdd03714f6..adf23e918b 100644 --- a/internal/outpost/radius/eap/tls/state.go +++ b/internal/outpost/radius/eap/tls/state.go @@ -15,6 +15,7 @@ type State struct { TLS *tls.Conn Conn *BuffConn Context context.Context + ContextCancel context.CancelFunc } func NewState() *State {