diff --git a/internal/outpost/radius/eap/debug/debug.go b/internal/outpost/radius/eap/debug/debug.go index 3435bfd708..4729ee6059 100644 --- a/internal/outpost/radius/eap/debug/debug.go +++ b/internal/outpost/radius/eap/debug/debug.go @@ -23,5 +23,9 @@ func DebugPacket(p *radius.Packet) { } func FormatBytes(d []byte) string { - return fmt.Sprintf("% x", d) + b := d + if len(b) > 32 { + b = b[:32] + } + return fmt.Sprintf("% x", b) } diff --git a/internal/outpost/radius/eap/tls/buff_conn.go b/internal/outpost/radius/eap/tls/buff_conn.go new file mode 100644 index 0000000000..7455763afa --- /dev/null +++ b/internal/outpost/radius/eap/tls/buff_conn.go @@ -0,0 +1,92 @@ +package tls + +import ( + "bytes" + "context" + "net" + "time" + + log "github.com/sirupsen/logrus" +) + +type BuffConn struct { + reader *bytes.Buffer + writer *bytes.Buffer + + ctx context.Context + + expectedWriterByteCount int + writtenByteCount int +} + +func NewBuffConn(initialData []byte, ctx context.Context) *BuffConn { + c := &BuffConn{ + reader: bytes.NewBuffer(initialData), + writer: bytes.NewBuffer([]byte{}), + ctx: ctx, + } + return c +} + +func (conn BuffConn) OutboundData() []byte { + for { + // TODO cancel with conn.ctx + b := conn.writer.Bytes() + if len(b) < 1 { + log.Debug("TLS(buffcon): Attempted retrieve from empty buffer, stalling...") + time.Sleep(1 * time.Second) + continue + } + return b + } +} + +func (conn *BuffConn) UpdateData(data []byte) { + conn.reader.Write(data) + conn.writtenByteCount += len(data) + log.Debugf("TLS(buffcon): Appending new data %d (total %d, expecting %d)", len(data), conn.writtenByteCount, conn.expectedWriterByteCount) +} + +func (conn BuffConn) NeedsMoreData() bool { + if conn.expectedWriterByteCount > 0 { + return conn.reader.Len() < int(conn.expectedWriterByteCount) + } + return false +} + +func (conn *BuffConn) Read(p []byte) (int, error) { + for { + // TODO cancel with conn.ctx + n, err := conn.reader.Read(p) + if n == 0 { + log.Debugf("TLS(buffcon): Attempted read %d from empty buffer, stalling...", len(p)) + time.Sleep(100 * time.Millisecond) + continue + } + if conn.expectedWriterByteCount > 0 && conn.writtenByteCount < int(conn.expectedWriterByteCount) { + log.Debugf("TLS(buffcon): Attempted read %d while waiting for bytes %d, stalling...", len(p), conn.expectedWriterByteCount-conn.reader.Len()) + time.Sleep(100 * time.Millisecond) + continue + } + if conn.expectedWriterByteCount > 0 && conn.writtenByteCount == int(conn.expectedWriterByteCount) { + conn.expectedWriterByteCount = 0 + } + if conn.reader.Len() == 0 { + conn.writtenByteCount = 0 + } + log.Debugf("TLS(buffcon): Read: %d from %d", len(p), n) + return n, err + } +} + +func (conn BuffConn) Write(p []byte) (int, error) { + log.Debugf("TLS(buffcon): Write: %d", len(p)) + return conn.writer.Write(p) +} + +func (conn BuffConn) Close() error { return nil } +func (conn BuffConn) LocalAddr() net.Addr { return nil } +func (conn BuffConn) RemoteAddr() net.Addr { return nil } +func (conn BuffConn) SetDeadline(t time.Time) error { return nil } +func (conn BuffConn) SetReadDeadline(t time.Time) error { return nil } +func (conn BuffConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/internal/outpost/radius/eap/tls/conn.go b/internal/outpost/radius/eap/tls/conn.go deleted file mode 100644 index a72e2d01c0..0000000000 --- a/internal/outpost/radius/eap/tls/conn.go +++ /dev/null @@ -1,92 +0,0 @@ -package tls - -import ( - "bytes" - "context" - "net" - "time" - - log "github.com/sirupsen/logrus" -) - -type TLSConnection struct { - reader *bytes.Buffer - writer *bytes.Buffer - - ctx context.Context - - expectedWriterByteCount int - writtenByteCount int -} - -func NewTLSConnection(initialData []byte, ctx context.Context) *TLSConnection { - c := &TLSConnection{ - reader: bytes.NewBuffer(initialData), - writer: bytes.NewBuffer([]byte{}), - ctx: ctx, - } - return c -} - -func (conn TLSConnection) OutboundData() []byte { - for { - // TODO cancel with conn.ctx - b := conn.writer.Bytes() - if len(b) < 1 { - log.Debug("TLS(buffer): Attempted retrieve from empty buffer, stalling...") - time.Sleep(1 * time.Second) - continue - } - return b - } -} - -func (conn *TLSConnection) UpdateData(data []byte) { - conn.reader.Write(data) - conn.writtenByteCount += len(data) - log.Debugf("TLS(buffer): Appending new data %d (total %d, expecting %d)", len(data), conn.writtenByteCount, conn.expectedWriterByteCount) -} - -func (conn TLSConnection) NeedsMoreData() bool { - if conn.expectedWriterByteCount > 0 { - return conn.reader.Len() < int(conn.expectedWriterByteCount) - } - return false -} - -func (conn *TLSConnection) Read(p []byte) (int, error) { - for { - // TODO cancel with conn.ctx - n, err := conn.reader.Read(p) - if n == 0 { - log.Debugf("TLS(buffer): Attempted read %d from empty buffer, stalling...", len(p)) - time.Sleep(100 * time.Millisecond) - continue - } - if conn.expectedWriterByteCount > 0 && conn.writtenByteCount < int(conn.expectedWriterByteCount) { - log.Debugf("TLS(buffer): Attempted read %d while waiting for bytes %d, stalling...", len(p), conn.expectedWriterByteCount-conn.reader.Len()) - time.Sleep(100 * time.Millisecond) - continue - } - if conn.expectedWriterByteCount > 0 && conn.writtenByteCount == int(conn.expectedWriterByteCount) { - conn.expectedWriterByteCount = 0 - } - if conn.reader.Len() == 0 { - conn.writtenByteCount = 0 - } - log.Debugf("TLS(buffer): Read: %d from %d", len(p), n) - return n, err - } -} - -func (conn TLSConnection) Write(p []byte) (int, error) { - log.Debugf("TLS(buffer): Write: %d", len(p)) - return conn.writer.Write(p) -} - -func (conn TLSConnection) Close() error { return nil } -func (conn TLSConnection) LocalAddr() net.Addr { return nil } -func (conn TLSConnection) RemoteAddr() net.Addr { return nil } -func (conn TLSConnection) SetDeadline(t time.Time) error { return nil } -func (conn TLSConnection) SetReadDeadline(t time.Time) error { return nil } -func (conn TLSConnection) SetWriteDeadline(t time.Time) error { return nil } diff --git a/internal/outpost/radius/eap/tls/payload.go b/internal/outpost/radius/eap/tls/payload.go index a1e71c2af9..589d10a7f1 100644 --- a/internal/outpost/radius/eap/tls/payload.go +++ b/internal/outpost/radius/eap/tls/payload.go @@ -92,7 +92,7 @@ func (p *Payload) Handle(stt any) (protocol.Payload, *State) { log.Debug("TLS: no TLS connection in state yet, starting connection") ctx, cancel := context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second) st.Context = ctx - st.Conn = NewTLSConnection(p.Data, st.Context) + st.Conn = NewBuffConn(p.Data, st.Context) st.TLS = tls.Server(st.Conn, &tls.Config{ GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { log.Debugf("TLS: ClientHello: %+v\n", ch) diff --git a/internal/outpost/radius/eap/tls/state.go b/internal/outpost/radius/eap/tls/state.go index 883349a1f8..cdd03714f6 100644 --- a/internal/outpost/radius/eap/tls/state.go +++ b/internal/outpost/radius/eap/tls/state.go @@ -13,7 +13,7 @@ type State struct { MPPEKey []byte TotalPayloadSize int TLS *tls.Conn - Conn *TLSConnection + Conn *BuffConn Context context.Context }