use tighter retry that cancels and backs off
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
		| @ -3,9 +3,11 @@ package tls | |||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"context" | 	"context" | ||||||
|  | 	"errors" | ||||||
| 	"net" | 	"net" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/avast/retry-go/v4" | ||||||
| 	log "github.com/sirupsen/logrus" | 	log "github.com/sirupsen/logrus" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @ -17,6 +19,8 @@ type BuffConn struct { | |||||||
|  |  | ||||||
| 	expectedWriterByteCount int | 	expectedWriterByteCount int | ||||||
| 	writtenByteCount        int | 	writtenByteCount        int | ||||||
|  |  | ||||||
|  | 	retryOptions []retry.Option | ||||||
| } | } | ||||||
|  |  | ||||||
| func NewBuffConn(initialData []byte, ctx context.Context) *BuffConn { | func NewBuffConn(initialData []byte, ctx context.Context) *BuffConn { | ||||||
| @ -24,21 +28,34 @@ func NewBuffConn(initialData []byte, ctx context.Context) *BuffConn { | |||||||
| 		reader: bytes.NewBuffer(initialData), | 		reader: bytes.NewBuffer(initialData), | ||||||
| 		writer: bytes.NewBuffer([]byte{}), | 		writer: bytes.NewBuffer([]byte{}), | ||||||
| 		ctx:    ctx, | 		ctx:    ctx, | ||||||
|  | 		retryOptions: []retry.Option{ | ||||||
|  | 			retry.Context(ctx), | ||||||
|  | 			retry.Delay(10 * time.Microsecond), | ||||||
|  | 			retry.DelayType(retry.BackOffDelay), | ||||||
|  | 			retry.MaxDelay(100 * time.Millisecond), | ||||||
|  | 			retry.Attempts(0), | ||||||
|  | 		}, | ||||||
| 	} | 	} | ||||||
| 	return c | 	return c | ||||||
| } | } | ||||||
|  |  | ||||||
|  | var errStall = errors.New("Stall") | ||||||
|  |  | ||||||
| func (conn BuffConn) OutboundData() []byte { | func (conn BuffConn) OutboundData() []byte { | ||||||
| 	for { | 	d, err := retry.DoWithData( | ||||||
| 		// TODO cancel with conn.ctx | 		func() ([]byte, error) { | ||||||
| 			b := conn.writer.Bytes() | 			b := conn.writer.Bytes() | ||||||
| 			if len(b) < 1 { | 			if len(b) < 1 { | ||||||
| 			log.Debug("TLS(buffcon): Attempted retrieve from empty buffer, stalling...") | 				return nil, errStall | ||||||
| 			time.Sleep(1 * time.Second) |  | ||||||
| 			continue |  | ||||||
| 			} | 			} | ||||||
| 		return b | 			return b, nil | ||||||
|  | 		}, | ||||||
|  | 		conn.retryOptions..., | ||||||
|  | 	) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return []byte{} | ||||||
| 	} | 	} | ||||||
|  | 	return d | ||||||
| } | } | ||||||
|  |  | ||||||
| func (conn *BuffConn) UpdateData(data []byte) { | func (conn *BuffConn) UpdateData(data []byte) { | ||||||
| @ -55,18 +72,16 @@ func (conn BuffConn) NeedsMoreData() bool { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (conn *BuffConn) Read(p []byte) (int, error) { | func (conn *BuffConn) Read(p []byte) (int, error) { | ||||||
| 	for { | 	d, err := retry.DoWithData( | ||||||
| 		// TODO cancel with conn.ctx | 		func() (int, error) { | ||||||
| 			n, err := conn.reader.Read(p) | 			n, err := conn.reader.Read(p) | ||||||
| 			if n == 0 { | 			if n == 0 { | ||||||
| 				log.Debugf("TLS(buffcon): Attempted read %d from empty buffer, stalling...", len(p)) | 				log.Debugf("TLS(buffcon): Attempted read %d from empty buffer, stalling...", len(p)) | ||||||
| 			time.Sleep(100 * time.Millisecond) | 				return 0, errStall | ||||||
| 			continue |  | ||||||
| 			} | 			} | ||||||
| 			if conn.expectedWriterByteCount > 0 && conn.writtenByteCount < int(conn.expectedWriterByteCount) { | 			if conn.expectedWriterByteCount > 0 && conn.writtenByteCount < int(conn.expectedWriterByteCount) { | ||||||
| 				log.Debugf("TLS(buffcon): Attempted read %d while waiting for bytes %d, stalling...", len(p), conn.expectedWriterByteCount-conn.reader.Len()) | 				log.Debugf("TLS(buffcon): Attempted read %d while waiting for bytes %d, stalling...", len(p), conn.expectedWriterByteCount-conn.reader.Len()) | ||||||
| 			time.Sleep(100 * time.Millisecond) | 				return 0, errStall | ||||||
| 			continue |  | ||||||
| 			} | 			} | ||||||
| 			if conn.expectedWriterByteCount > 0 && conn.writtenByteCount == int(conn.expectedWriterByteCount) { | 			if conn.expectedWriterByteCount > 0 && conn.writtenByteCount == int(conn.expectedWriterByteCount) { | ||||||
| 				conn.expectedWriterByteCount = 0 | 				conn.expectedWriterByteCount = 0 | ||||||
| @ -76,7 +91,10 @@ func (conn *BuffConn) Read(p []byte) (int, error) { | |||||||
| 			} | 			} | ||||||
| 			log.Debugf("TLS(buffcon): Read: %d from %d", len(p), n) | 			log.Debugf("TLS(buffcon): Read: %d from %d", len(p), n) | ||||||
| 			return n, err | 			return n, err | ||||||
| 	} | 		}, | ||||||
|  | 		conn.retryOptions..., | ||||||
|  | 	) | ||||||
|  | 	return d, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func (conn BuffConn) Write(p []byte) (int, error) { | func (conn BuffConn) Write(p []byte) (int, error) { | ||||||
|  | |||||||
| @ -90,8 +90,7 @@ func (p *Payload) Handle(stt any) (protocol.Payload, *State) { | |||||||
|  |  | ||||||
| 	if st.TLS == nil { | 	if st.TLS == nil { | ||||||
| 		log.Debug("TLS: no TLS connection in state yet, starting connection") | 		log.Debug("TLS: no TLS connection in state yet, starting connection") | ||||||
| 		ctx, cancel := context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second) | 		st.Context, st.ContextCancel = context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second) | ||||||
| 		st.Context = ctx |  | ||||||
| 		st.Conn = NewBuffConn(p.Data, st.Context) | 		st.Conn = NewBuffConn(p.Data, st.Context) | ||||||
| 		st.TLS = tls.Server(st.Conn, &tls.Config{ | 		st.TLS = tls.Server(st.Conn, &tls.Config{ | ||||||
| 			GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { | 			GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { | ||||||
| @ -109,7 +108,6 @@ func (p *Payload) Handle(stt any) (protocol.Payload, *State) { | |||||||
| 			}, | 			}, | ||||||
| 		}) | 		}) | ||||||
| 		go func() { | 		go func() { | ||||||
| 			defer cancel() |  | ||||||
| 			err := st.TLS.HandshakeContext(st.Context) | 			err := st.TLS.HandshakeContext(st.Context) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				log.WithError(err).Debug("TLS: Handshake error") | 				log.WithError(err).Debug("TLS: Handshake error") | ||||||
| @ -145,6 +143,7 @@ func (p *Payload) Handle(stt any) (protocol.Payload, *State) { | |||||||
| 		return p.sendNextChunk(st) | 		return p.sendNextChunk(st) | ||||||
| 	} | 	} | ||||||
| 	if st.Conn.writer.Len() == 0 && st.HandshakeDone { | 	if st.Conn.writer.Len() == 0 && st.HandshakeDone { | ||||||
|  | 		defer st.ContextCancel() | ||||||
| 		return protocol.EmptyPayload{ | 		return protocol.EmptyPayload{ | ||||||
| 			ModifyPacket: func(p *radius.Packet) *radius.Packet { | 			ModifyPacket: func(p *radius.Packet) *radius.Packet { | ||||||
| 				p.Code = radius.CodeAccessAccept | 				p.Code = radius.CodeAccessAccept | ||||||
|  | |||||||
| @ -15,6 +15,7 @@ type State struct { | |||||||
| 	TLS              *tls.Conn | 	TLS              *tls.Conn | ||||||
| 	Conn             *BuffConn | 	Conn             *BuffConn | ||||||
| 	Context          context.Context | 	Context          context.Context | ||||||
|  | 	ContextCancel    context.CancelFunc | ||||||
| } | } | ||||||
|  |  | ||||||
| func NewState() *State { | func NewState() *State { | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	 Jens Langhammer
					Jens Langhammer