| @ -23,5 +23,9 @@ func DebugPacket(p *radius.Packet) { | ||||
| } | ||||
|  | ||||
| func FormatBytes(d []byte) string { | ||||
| 	return fmt.Sprintf("% x", d) | ||||
| 	b := d | ||||
| 	if len(b) > 32 { | ||||
| 		b = b[:32] | ||||
| 	} | ||||
| 	return fmt.Sprintf("% x", b) | ||||
| } | ||||
|  | ||||
							
								
								
									
										92
									
								
								internal/outpost/radius/eap/tls/buff_conn.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										92
									
								
								internal/outpost/radius/eap/tls/buff_conn.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,92 @@ | ||||
| package tls | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"net" | ||||
| 	"time" | ||||
|  | ||||
| 	log "github.com/sirupsen/logrus" | ||||
| ) | ||||
|  | ||||
| type BuffConn struct { | ||||
| 	reader *bytes.Buffer | ||||
| 	writer *bytes.Buffer | ||||
|  | ||||
| 	ctx context.Context | ||||
|  | ||||
| 	expectedWriterByteCount int | ||||
| 	writtenByteCount        int | ||||
| } | ||||
|  | ||||
| func NewBuffConn(initialData []byte, ctx context.Context) *BuffConn { | ||||
| 	c := &BuffConn{ | ||||
| 		reader: bytes.NewBuffer(initialData), | ||||
| 		writer: bytes.NewBuffer([]byte{}), | ||||
| 		ctx:    ctx, | ||||
| 	} | ||||
| 	return c | ||||
| } | ||||
|  | ||||
| func (conn BuffConn) OutboundData() []byte { | ||||
| 	for { | ||||
| 		// TODO cancel with conn.ctx | ||||
| 		b := conn.writer.Bytes() | ||||
| 		if len(b) < 1 { | ||||
| 			log.Debug("TLS(buffcon): Attempted retrieve from empty buffer, stalling...") | ||||
| 			time.Sleep(1 * time.Second) | ||||
| 			continue | ||||
| 		} | ||||
| 		return b | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (conn *BuffConn) UpdateData(data []byte) { | ||||
| 	conn.reader.Write(data) | ||||
| 	conn.writtenByteCount += len(data) | ||||
| 	log.Debugf("TLS(buffcon): Appending new data %d (total %d, expecting %d)", len(data), conn.writtenByteCount, conn.expectedWriterByteCount) | ||||
| } | ||||
|  | ||||
| func (conn BuffConn) NeedsMoreData() bool { | ||||
| 	if conn.expectedWriterByteCount > 0 { | ||||
| 		return conn.reader.Len() < int(conn.expectedWriterByteCount) | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| func (conn *BuffConn) Read(p []byte) (int, error) { | ||||
| 	for { | ||||
| 		// TODO cancel with conn.ctx | ||||
| 		n, err := conn.reader.Read(p) | ||||
| 		if n == 0 { | ||||
| 			log.Debugf("TLS(buffcon): Attempted read %d from empty buffer, stalling...", len(p)) | ||||
| 			time.Sleep(100 * time.Millisecond) | ||||
| 			continue | ||||
| 		} | ||||
| 		if conn.expectedWriterByteCount > 0 && conn.writtenByteCount < int(conn.expectedWriterByteCount) { | ||||
| 			log.Debugf("TLS(buffcon): Attempted read %d while waiting for bytes %d, stalling...", len(p), conn.expectedWriterByteCount-conn.reader.Len()) | ||||
| 			time.Sleep(100 * time.Millisecond) | ||||
| 			continue | ||||
| 		} | ||||
| 		if conn.expectedWriterByteCount > 0 && conn.writtenByteCount == int(conn.expectedWriterByteCount) { | ||||
| 			conn.expectedWriterByteCount = 0 | ||||
| 		} | ||||
| 		if conn.reader.Len() == 0 { | ||||
| 			conn.writtenByteCount = 0 | ||||
| 		} | ||||
| 		log.Debugf("TLS(buffcon): Read: %d from %d", len(p), n) | ||||
| 		return n, err | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (conn BuffConn) Write(p []byte) (int, error) { | ||||
| 	log.Debugf("TLS(buffcon): Write: %d", len(p)) | ||||
| 	return conn.writer.Write(p) | ||||
| } | ||||
|  | ||||
| func (conn BuffConn) Close() error                       { return nil } | ||||
| func (conn BuffConn) LocalAddr() net.Addr                { return nil } | ||||
| func (conn BuffConn) RemoteAddr() net.Addr               { return nil } | ||||
| func (conn BuffConn) SetDeadline(t time.Time) error      { return nil } | ||||
| func (conn BuffConn) SetReadDeadline(t time.Time) error  { return nil } | ||||
| func (conn BuffConn) SetWriteDeadline(t time.Time) error { return nil } | ||||
| @ -1,92 +0,0 @@ | ||||
| package tls | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"net" | ||||
| 	"time" | ||||
|  | ||||
| 	log "github.com/sirupsen/logrus" | ||||
| ) | ||||
|  | ||||
| type TLSConnection struct { | ||||
| 	reader *bytes.Buffer | ||||
| 	writer *bytes.Buffer | ||||
|  | ||||
| 	ctx context.Context | ||||
|  | ||||
| 	expectedWriterByteCount int | ||||
| 	writtenByteCount        int | ||||
| } | ||||
|  | ||||
| func NewTLSConnection(initialData []byte, ctx context.Context) *TLSConnection { | ||||
| 	c := &TLSConnection{ | ||||
| 		reader: bytes.NewBuffer(initialData), | ||||
| 		writer: bytes.NewBuffer([]byte{}), | ||||
| 		ctx:    ctx, | ||||
| 	} | ||||
| 	return c | ||||
| } | ||||
|  | ||||
| func (conn TLSConnection) OutboundData() []byte { | ||||
| 	for { | ||||
| 		// TODO cancel with conn.ctx | ||||
| 		b := conn.writer.Bytes() | ||||
| 		if len(b) < 1 { | ||||
| 			log.Debug("TLS(buffer): Attempted retrieve from empty buffer, stalling...") | ||||
| 			time.Sleep(1 * time.Second) | ||||
| 			continue | ||||
| 		} | ||||
| 		return b | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (conn *TLSConnection) UpdateData(data []byte) { | ||||
| 	conn.reader.Write(data) | ||||
| 	conn.writtenByteCount += len(data) | ||||
| 	log.Debugf("TLS(buffer): Appending new data %d (total %d, expecting %d)", len(data), conn.writtenByteCount, conn.expectedWriterByteCount) | ||||
| } | ||||
|  | ||||
| func (conn TLSConnection) NeedsMoreData() bool { | ||||
| 	if conn.expectedWriterByteCount > 0 { | ||||
| 		return conn.reader.Len() < int(conn.expectedWriterByteCount) | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| func (conn *TLSConnection) Read(p []byte) (int, error) { | ||||
| 	for { | ||||
| 		// TODO cancel with conn.ctx | ||||
| 		n, err := conn.reader.Read(p) | ||||
| 		if n == 0 { | ||||
| 			log.Debugf("TLS(buffer): Attempted read %d from empty buffer, stalling...", len(p)) | ||||
| 			time.Sleep(100 * time.Millisecond) | ||||
| 			continue | ||||
| 		} | ||||
| 		if conn.expectedWriterByteCount > 0 && conn.writtenByteCount < int(conn.expectedWriterByteCount) { | ||||
| 			log.Debugf("TLS(buffer): Attempted read %d while waiting for bytes %d, stalling...", len(p), conn.expectedWriterByteCount-conn.reader.Len()) | ||||
| 			time.Sleep(100 * time.Millisecond) | ||||
| 			continue | ||||
| 		} | ||||
| 		if conn.expectedWriterByteCount > 0 && conn.writtenByteCount == int(conn.expectedWriterByteCount) { | ||||
| 			conn.expectedWriterByteCount = 0 | ||||
| 		} | ||||
| 		if conn.reader.Len() == 0 { | ||||
| 			conn.writtenByteCount = 0 | ||||
| 		} | ||||
| 		log.Debugf("TLS(buffer): Read: %d from %d", len(p), n) | ||||
| 		return n, err | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (conn TLSConnection) Write(p []byte) (int, error) { | ||||
| 	log.Debugf("TLS(buffer): Write: %d", len(p)) | ||||
| 	return conn.writer.Write(p) | ||||
| } | ||||
|  | ||||
| func (conn TLSConnection) Close() error                       { return nil } | ||||
| func (conn TLSConnection) LocalAddr() net.Addr                { return nil } | ||||
| func (conn TLSConnection) RemoteAddr() net.Addr               { return nil } | ||||
| func (conn TLSConnection) SetDeadline(t time.Time) error      { return nil } | ||||
| func (conn TLSConnection) SetReadDeadline(t time.Time) error  { return nil } | ||||
| func (conn TLSConnection) SetWriteDeadline(t time.Time) error { return nil } | ||||
| @ -92,7 +92,7 @@ func (p *Payload) Handle(stt any) (protocol.Payload, *State) { | ||||
| 		log.Debug("TLS: no TLS connection in state yet, starting connection") | ||||
| 		ctx, cancel := context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second) | ||||
| 		st.Context = ctx | ||||
| 		st.Conn = NewTLSConnection(p.Data, st.Context) | ||||
| 		st.Conn = NewBuffConn(p.Data, st.Context) | ||||
| 		st.TLS = tls.Server(st.Conn, &tls.Config{ | ||||
| 			GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { | ||||
| 				log.Debugf("TLS: ClientHello: %+v\n", ch) | ||||
|  | ||||
| @ -13,7 +13,7 @@ type State struct { | ||||
| 	MPPEKey          []byte | ||||
| 	TotalPayloadSize int | ||||
| 	TLS              *tls.Conn | ||||
| 	Conn             *TLSConnection | ||||
| 	Conn             *BuffConn | ||||
| 	Context          context.Context | ||||
| } | ||||
|  | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	 Jens Langhammer
					Jens Langhammer