use tighter retry that cancels and backs off
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
@ -3,9 +3,11 @@ package tls
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/avast/retry-go/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -17,6 +19,8 @@ type BuffConn struct {
|
|||||||
|
|
||||||
expectedWriterByteCount int
|
expectedWriterByteCount int
|
||||||
writtenByteCount int
|
writtenByteCount int
|
||||||
|
|
||||||
|
retryOptions []retry.Option
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBuffConn(initialData []byte, ctx context.Context) *BuffConn {
|
func NewBuffConn(initialData []byte, ctx context.Context) *BuffConn {
|
||||||
@ -24,21 +28,34 @@ func NewBuffConn(initialData []byte, ctx context.Context) *BuffConn {
|
|||||||
reader: bytes.NewBuffer(initialData),
|
reader: bytes.NewBuffer(initialData),
|
||||||
writer: bytes.NewBuffer([]byte{}),
|
writer: bytes.NewBuffer([]byte{}),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
retryOptions: []retry.Option{
|
||||||
|
retry.Context(ctx),
|
||||||
|
retry.Delay(10 * time.Microsecond),
|
||||||
|
retry.DelayType(retry.BackOffDelay),
|
||||||
|
retry.MaxDelay(100 * time.Millisecond),
|
||||||
|
retry.Attempts(0),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var errStall = errors.New("Stall")
|
||||||
|
|
||||||
func (conn BuffConn) OutboundData() []byte {
|
func (conn BuffConn) OutboundData() []byte {
|
||||||
for {
|
d, err := retry.DoWithData(
|
||||||
// TODO cancel with conn.ctx
|
func() ([]byte, error) {
|
||||||
b := conn.writer.Bytes()
|
b := conn.writer.Bytes()
|
||||||
if len(b) < 1 {
|
if len(b) < 1 {
|
||||||
log.Debug("TLS(buffcon): Attempted retrieve from empty buffer, stalling...")
|
return nil, errStall
|
||||||
time.Sleep(1 * time.Second)
|
}
|
||||||
continue
|
return b, nil
|
||||||
}
|
},
|
||||||
return b
|
conn.retryOptions...,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return []byte{}
|
||||||
}
|
}
|
||||||
|
return d
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *BuffConn) UpdateData(data []byte) {
|
func (conn *BuffConn) UpdateData(data []byte) {
|
||||||
@ -55,28 +72,29 @@ func (conn BuffConn) NeedsMoreData() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (conn *BuffConn) Read(p []byte) (int, error) {
|
func (conn *BuffConn) Read(p []byte) (int, error) {
|
||||||
for {
|
d, err := retry.DoWithData(
|
||||||
// TODO cancel with conn.ctx
|
func() (int, error) {
|
||||||
n, err := conn.reader.Read(p)
|
n, err := conn.reader.Read(p)
|
||||||
if n == 0 {
|
if n == 0 {
|
||||||
log.Debugf("TLS(buffcon): Attempted read %d from empty buffer, stalling...", len(p))
|
log.Debugf("TLS(buffcon): Attempted read %d from empty buffer, stalling...", len(p))
|
||||||
time.Sleep(100 * time.Millisecond)
|
return 0, errStall
|
||||||
continue
|
}
|
||||||
}
|
if conn.expectedWriterByteCount > 0 && conn.writtenByteCount < int(conn.expectedWriterByteCount) {
|
||||||
if conn.expectedWriterByteCount > 0 && conn.writtenByteCount < int(conn.expectedWriterByteCount) {
|
log.Debugf("TLS(buffcon): Attempted read %d while waiting for bytes %d, stalling...", len(p), conn.expectedWriterByteCount-conn.reader.Len())
|
||||||
log.Debugf("TLS(buffcon): Attempted read %d while waiting for bytes %d, stalling...", len(p), conn.expectedWriterByteCount-conn.reader.Len())
|
return 0, errStall
|
||||||
time.Sleep(100 * time.Millisecond)
|
}
|
||||||
continue
|
if conn.expectedWriterByteCount > 0 && conn.writtenByteCount == int(conn.expectedWriterByteCount) {
|
||||||
}
|
conn.expectedWriterByteCount = 0
|
||||||
if conn.expectedWriterByteCount > 0 && conn.writtenByteCount == int(conn.expectedWriterByteCount) {
|
}
|
||||||
conn.expectedWriterByteCount = 0
|
if conn.reader.Len() == 0 {
|
||||||
}
|
conn.writtenByteCount = 0
|
||||||
if conn.reader.Len() == 0 {
|
}
|
||||||
conn.writtenByteCount = 0
|
log.Debugf("TLS(buffcon): Read: %d from %d", len(p), n)
|
||||||
}
|
return n, err
|
||||||
log.Debugf("TLS(buffcon): Read: %d from %d", len(p), n)
|
},
|
||||||
return n, err
|
conn.retryOptions...,
|
||||||
}
|
)
|
||||||
|
return d, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn BuffConn) Write(p []byte) (int, error) {
|
func (conn BuffConn) Write(p []byte) (int, error) {
|
||||||
|
|||||||
@ -90,8 +90,7 @@ func (p *Payload) Handle(stt any) (protocol.Payload, *State) {
|
|||||||
|
|
||||||
if st.TLS == nil {
|
if st.TLS == nil {
|
||||||
log.Debug("TLS: no TLS connection in state yet, starting connection")
|
log.Debug("TLS: no TLS connection in state yet, starting connection")
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second)
|
st.Context, st.ContextCancel = context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second)
|
||||||
st.Context = ctx
|
|
||||||
st.Conn = NewBuffConn(p.Data, st.Context)
|
st.Conn = NewBuffConn(p.Data, st.Context)
|
||||||
st.TLS = tls.Server(st.Conn, &tls.Config{
|
st.TLS = tls.Server(st.Conn, &tls.Config{
|
||||||
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
|
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
@ -109,7 +108,6 @@ func (p *Payload) Handle(stt any) (protocol.Payload, *State) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
go func() {
|
go func() {
|
||||||
defer cancel()
|
|
||||||
err := st.TLS.HandshakeContext(st.Context)
|
err := st.TLS.HandshakeContext(st.Context)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Debug("TLS: Handshake error")
|
log.WithError(err).Debug("TLS: Handshake error")
|
||||||
@ -145,6 +143,7 @@ func (p *Payload) Handle(stt any) (protocol.Payload, *State) {
|
|||||||
return p.sendNextChunk(st)
|
return p.sendNextChunk(st)
|
||||||
}
|
}
|
||||||
if st.Conn.writer.Len() == 0 && st.HandshakeDone {
|
if st.Conn.writer.Len() == 0 && st.HandshakeDone {
|
||||||
|
defer st.ContextCancel()
|
||||||
return protocol.EmptyPayload{
|
return protocol.EmptyPayload{
|
||||||
ModifyPacket: func(p *radius.Packet) *radius.Packet {
|
ModifyPacket: func(p *radius.Packet) *radius.Packet {
|
||||||
p.Code = radius.CodeAccessAccept
|
p.Code = radius.CodeAccessAccept
|
||||||
|
|||||||
@ -15,6 +15,7 @@ type State struct {
|
|||||||
TLS *tls.Conn
|
TLS *tls.Conn
|
||||||
Conn *BuffConn
|
Conn *BuffConn
|
||||||
Context context.Context
|
Context context.Context
|
||||||
|
ContextCancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewState() *State {
|
func NewState() *State {
|
||||||
|
|||||||
Reference in New Issue
Block a user