@ -2,6 +2,7 @@ package tls
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -12,20 +13,24 @@ type TLSConnection struct {
|
|||||||
reader *bytes.Buffer
|
reader *bytes.Buffer
|
||||||
writer *bytes.Buffer
|
writer *bytes.Buffer
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
|
||||||
expectedWriterByteCount int
|
expectedWriterByteCount int
|
||||||
writtenByteCount int
|
writtenByteCount int
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTLSConnection(initialData []byte) *TLSConnection {
|
func NewTLSConnection(initialData []byte, ctx context.Context) *TLSConnection {
|
||||||
c := &TLSConnection{
|
c := &TLSConnection{
|
||||||
reader: bytes.NewBuffer(initialData),
|
reader: bytes.NewBuffer(initialData),
|
||||||
writer: bytes.NewBuffer([]byte{}),
|
writer: bytes.NewBuffer([]byte{}),
|
||||||
|
ctx: ctx,
|
||||||
}
|
}
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn TLSConnection) OutboundData() []byte {
|
func (conn TLSConnection) OutboundData() []byte {
|
||||||
for {
|
for {
|
||||||
|
// TODO cancel with conn.ctx
|
||||||
b := conn.writer.Bytes()
|
b := conn.writer.Bytes()
|
||||||
if len(b) < 1 {
|
if len(b) < 1 {
|
||||||
log.Debug("TLS(buffer): Attempted retrieve from empty buffer, stalling...")
|
log.Debug("TLS(buffer): Attempted retrieve from empty buffer, stalling...")
|
||||||
@ -51,6 +56,7 @@ func (conn TLSConnection) NeedsMoreData() bool {
|
|||||||
|
|
||||||
func (conn *TLSConnection) Read(p []byte) (int, error) {
|
func (conn *TLSConnection) Read(p []byte) (int, error) {
|
||||||
for {
|
for {
|
||||||
|
// TODO cancel with conn.ctx
|
||||||
n, err := conn.reader.Read(p)
|
n, err := conn.reader.Read(p)
|
||||||
if n == 0 {
|
if n == 0 {
|
||||||
log.Debugf("TLS(buffer): Attempted read %d from empty buffer, stalling...", len(p))
|
log.Debugf("TLS(buffer): Attempted read %d from empty buffer, stalling...", len(p))
|
||||||
|
@ -16,6 +16,21 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const maxChunkSize = 1000
|
const maxChunkSize = 1000
|
||||||
|
const staleConnectionTimeout = 10
|
||||||
|
|
||||||
|
var certs = []tls.Certificate{}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// Testing
|
||||||
|
cert, err := tls.LoadX509KeyPair(
|
||||||
|
"../t/ca/out/cert_jens-mbp.lab.beryju.org.pem",
|
||||||
|
"../t/ca/out/cert_jens-mbp.lab.beryju.org.key",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
certs = append(certs, cert)
|
||||||
|
}
|
||||||
|
|
||||||
type Payload struct {
|
type Payload struct {
|
||||||
Flags Flag
|
Flags Flag
|
||||||
@ -58,20 +73,6 @@ func (p *Payload) Encode() ([]byte, error) {
|
|||||||
return buff, nil
|
return buff, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var certs = []tls.Certificate{}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
// Testing
|
|
||||||
cert, err := tls.LoadX509KeyPair(
|
|
||||||
"../t/ca/out/cert_jens-mbp.lab.beryju.org.pem",
|
|
||||||
"../t/ca/out/cert_jens-mbp.lab.beryju.org.key",
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
certs = append(certs, cert)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Payload) Handle(stt any) (protocol.Payload, *State) {
|
func (p *Payload) Handle(stt any) (protocol.Payload, *State) {
|
||||||
if stt == nil {
|
if stt == nil {
|
||||||
log.Debug("TLS: new state")
|
log.Debug("TLS: new state")
|
||||||
@ -88,7 +89,9 @@ func (p *Payload) Handle(stt any) (protocol.Payload, *State) {
|
|||||||
|
|
||||||
if st.TLS == nil {
|
if st.TLS == nil {
|
||||||
log.Debug("TLS: no TLS connection in state yet, starting connection")
|
log.Debug("TLS: no TLS connection in state yet, starting connection")
|
||||||
st.Conn = NewTLSConnection(p.Data)
|
ctx, cancel := context.WithTimeout(context.Background(), staleConnectionTimeout*time.Second)
|
||||||
|
st.Context = ctx
|
||||||
|
st.Conn = NewTLSConnection(p.Data, st.Context)
|
||||||
st.TLS = tls.Server(st.Conn, &tls.Config{
|
st.TLS = tls.Server(st.Conn, &tls.Config{
|
||||||
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
|
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
log.Debugf("TLS: ClientHello: %+v\n", ch)
|
log.Debugf("TLS: ClientHello: %+v\n", ch)
|
||||||
@ -98,8 +101,6 @@ func (p *Payload) Handle(stt any) (protocol.Payload, *State) {
|
|||||||
ClientAuth: tls.RequireAnyClientCert,
|
ClientAuth: tls.RequireAnyClientCert,
|
||||||
Certificates: certs,
|
Certificates: certs,
|
||||||
})
|
})
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
||||||
st.Context = ctx
|
|
||||||
go func() {
|
go func() {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
err := st.TLS.HandshakeContext(st.Context)
|
err := st.TLS.HandshakeContext(st.Context)
|
||||||
@ -171,11 +172,11 @@ func (p *Payload) handshakeFinished(st *State) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *Payload) startChunkedTransfer(data []byte, st *State) (*Payload, *State) {
|
func (p *Payload) startChunkedTransfer(data []byte, st *State) (*Payload, *State) {
|
||||||
flags := FlagNone
|
flags := FlagLengthIncluded
|
||||||
var dataToSend []byte
|
var dataToSend []byte
|
||||||
if len(data) > maxChunkSize {
|
if len(data) > maxChunkSize {
|
||||||
log.WithField("length", len(data)).Debug("TLS: Data needs to be chunked")
|
log.WithField("length", len(data)).Debug("TLS: Data needs to be chunked")
|
||||||
flags += FlagMoreFragments + FlagLengthIncluded
|
flags += FlagMoreFragments
|
||||||
// Chunk data into correct chunks and add them to the list
|
// Chunk data into correct chunks and add them to the list
|
||||||
st.RemainingChunks = append(st.RemainingChunks, slices.Collect(slices.Chunk(data, maxChunkSize))...)
|
st.RemainingChunks = append(st.RemainingChunks, slices.Collect(slices.Chunk(data, maxChunkSize))...)
|
||||||
dataToSend = st.RemainingChunks[0]
|
dataToSend = st.RemainingChunks[0]
|
||||||
|
Reference in New Issue
Block a user