Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 7 additions & 175 deletions mtglib/internal/tls/fake/client_side.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,11 @@ import (
"crypto/sha256"
"crypto/subtle"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"slices"
"time"

"github.com/9seconds/mtg/v2/mtglib/internal/tls"
)

const (
Expand All @@ -24,11 +21,6 @@ const (
RandomOffset = 1 + 2 + 2 + 1 + 3 + 2

sniDNSNamesListType = 0

// maxContinuationRecords limits the number of continuation TLS records
// that reassembleTLSHandshake will read. This prevents resource exhaustion
// from adversarial fragmentation.
maxContinuationRecords = 10
)

var (
Expand Down Expand Up @@ -62,31 +54,17 @@ func ReadClientHello(
// 4. New digest should be all 0 except of last 4 bytes
// 5. Last 4 bytes are little endian uint32 of UNIX timestamp when
// this message was created.
reassembled, err := reassembleTLSHandshake(conn)
clientHelloCopy, handshakeReader, err := parseClientHello(conn)
if err != nil {
return nil, fmt.Errorf("cannot reassemble TLS records: %w", err)
}

handshakeCopyBuf := &bytes.Buffer{}
reader := io.TeeReader(reassembled, handshakeCopyBuf)

// Skip the TLS record header (validated during reassembly).
// The header still flows through TeeReader into handshakeCopyBuf for HMAC.
if _, err = io.CopyN(io.Discard, reader, tls.SizeHeader); err != nil {
return nil, fmt.Errorf("cannot skip tls header: %w", err)
return nil, fmt.Errorf("cannot read client hello: %w", err)
}

reader, err = parseHandshakeHeader(reader)
if err != nil {
return nil, fmt.Errorf("cannot parse handshake header: %w", err)
}

hello, err := parseHandshake(reader)
hello, err := parseHandshake(handshakeReader)
if err != nil {
return nil, fmt.Errorf("cannot parse handshake: %w", err)
}

sniHostnames, err := parseSNI(reader)
sniHostnames, err := parseSNI(handshakeReader)
if err != nil {
return nil, fmt.Errorf("cannot parse SNI: %w", err)
}
Expand All @@ -97,10 +75,10 @@ func ReadClientHello(

digest := hmac.New(sha256.New, secret)
// we write a copy of the handshake with client random all nullified.
digest.Write(handshakeCopyBuf.Next(RandomOffset))
handshakeCopyBuf.Next(RandomLen)
digest.Write(clientHelloCopy.Next(RandomOffset))
clientHelloCopy.Next(RandomLen)
digest.Write(emptyRandom[:])
digest.Write(handshakeCopyBuf.Bytes())
digest.Write(clientHelloCopy.Bytes())

computed := digest.Sum(nil)

Expand All @@ -122,152 +100,6 @@ func ReadClientHello(
return hello, nil
}

// reassembleTLSHandshake reads one or more TLS records from conn,
// validates the record type and version, and reassembles fragmented
// handshake payloads into a single TLS record.
//
// Per RFC 5246 Section 6.2.1, handshake messages may be fragmented
// across multiple TLS records. DPI bypass tools like ByeDPI use this
// to evade censorship.
//
// The returned buffer contains the full TLS record (header + payload)
// so that callers can include the header in HMAC computation.
func reassembleTLSHandshake(conn io.Reader) (*bytes.Buffer, error) {
header := [tls.SizeHeader]byte{}

if _, err := io.ReadFull(conn, header[:]); err != nil {
return nil, fmt.Errorf("cannot read record header: %w", err)
}

length := int64(binary.BigEndian.Uint16(header[3:]))
payload := &bytes.Buffer{}

if _, err := io.CopyN(payload, conn, length); err != nil {
return nil, fmt.Errorf("cannot read record payload: %w", err)
}

if header[0] != tls.TypeHandshake {
return nil, fmt.Errorf("unexpected record type %#x", header[0])
}

if header[1] != 3 || header[2] != 1 {
return nil, fmt.Errorf("unexpected protocol version %#x %#x", header[1], header[2])
}

// Reassemble fragmented payload. continuationCount caps the total
// number of continuation records across both phases below.
continuationCount := 0

// Phase 1: read continuation records until we have at least the
// 4-byte handshake header (type + uint24 length) to determine the
// expected total size.
for ; payload.Len() < 4 && continuationCount < maxContinuationRecords; continuationCount++ {
prevLen := payload.Len()

if err := readContinuationRecord(conn, payload); err != nil {
payload.Truncate(prevLen) // discard partial data on error

if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
break // no more records — let downstream parsing handle what we have
}

return nil, err
}
}

// Phase 2: we know the expected handshake size — read remaining
// continuation records until the payload is complete.
if payload.Len() >= 4 {
p := payload.Bytes()
expectedTotal := 4 + (int(p[1])<<16 | int(p[2])<<8 | int(p[3]))

if expectedTotal > 0xFFFF {
return nil, fmt.Errorf("handshake message too large: %d bytes", expectedTotal)
}

for ; payload.Len() < expectedTotal && continuationCount < maxContinuationRecords; continuationCount++ {
if err := readContinuationRecord(conn, payload); err != nil {
return nil, err
}
}

if payload.Len() < expectedTotal {
return nil, fmt.Errorf("cannot reassemble handshake: too many continuation records")
}

payload.Truncate(expectedTotal)
}

if payload.Len() > 0xFFFF {
return nil, fmt.Errorf("reassembled payload too large: %d bytes", payload.Len())
}

// Reconstruct a single TLS record with the reassembled payload.
result := &bytes.Buffer{}
result.Grow(tls.SizeHeader + payload.Len())
result.Write(header[:3])
binary.Write(result, binary.BigEndian, uint16(payload.Len())) //nolint:errcheck // bytes.Buffer.Write never fails
result.Write(payload.Bytes())

return result, nil
}

// readContinuationRecord reads the next TLS record header and appends its
// full payload to dst. It returns an error if the record is not a handshake
// record.
func readContinuationRecord(conn io.Reader, dst *bytes.Buffer) error {
nextHeader := [tls.SizeHeader]byte{}

if _, err := io.ReadFull(conn, nextHeader[:]); err != nil {
return fmt.Errorf("cannot read continuation record header: %w", err)
}

if nextHeader[0] != tls.TypeHandshake {
return fmt.Errorf("unexpected continuation record type %#x", nextHeader[0])
}

if nextHeader[1] != 3 || nextHeader[2] != 1 {
return fmt.Errorf("unexpected continuation record version %#x %#x", nextHeader[1], nextHeader[2])
}

nextLength := int64(binary.BigEndian.Uint16(nextHeader[3:]))

if nextLength == 0 {
return fmt.Errorf("zero-length continuation record")
}

if _, err := io.CopyN(dst, conn, nextLength); err != nil {
return fmt.Errorf("cannot read continuation record payload: %w", err)
}

return nil
}

func parseHandshakeHeader(r io.Reader) (io.Reader, error) {
// type(1) + size(3 / uint24)
// 01 - handshake message type 0x01 (client hello)
// 00 00 f4 - 0xF4 (244) bytes of client hello data follows
header := [1 + 3]byte{}

if _, err := io.ReadFull(r, header[:]); err != nil {
return nil, fmt.Errorf("cannot read handshake header: %w", err)
}

if header[0] != TypeHandshakeClient {
return nil, fmt.Errorf("incorrect handshake type: %#x", header[0])
}

// unfortunately there is not uint24 in golang, so we just reust header
header[0] = 0

length := int64(binary.BigEndian.Uint32(header[:]))
buf := &bytes.Buffer{}

_, err := io.CopyN(buf, r, length)

return buf, err
}

func parseHandshake(r io.Reader) (*ClientHello, error) {
// A protocol version of "3,3" (meaning TLS 1.2) is given.
header := [2]byte{}
Expand Down
14 changes: 7 additions & 7 deletions mtglib/internal/tls/fake/client_side_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ func (s *ParseClientHelloFragmentedTestSuite) TestReassemblyErrors() {
buf.Write(payload[10:])
return buf.Bytes()
},
errMsg: "unexpected continuation record type",
errMsg: "unexpected record type",
},
{
name: "too many continuation records",
Expand All @@ -563,7 +563,7 @@ func (s *ParseClientHelloFragmentedTestSuite) TestReassemblyErrors() {
}
return buf.Bytes()
},
errMsg: "too many continuation records",
errMsg: "too many fragments",
},
{
name: "zero-length continuation record",
Expand All @@ -579,7 +579,7 @@ func (s *ParseClientHelloFragmentedTestSuite) TestReassemblyErrors() {
require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(0)))
return buf.Bytes()
},
errMsg: "zero-length continuation record",
errMsg: "cannot read record header",
},
{
name: "wrong continuation record version",
Expand All @@ -596,7 +596,7 @@ func (s *ParseClientHelloFragmentedTestSuite) TestReassemblyErrors() {
buf.Write(payload[10:])
return buf.Bytes()
},
errMsg: "unexpected continuation record version",
errMsg: "unexpected protocol version",
},
{
name: "handshake message too large",
Expand All @@ -610,7 +610,7 @@ func (s *ParseClientHelloFragmentedTestSuite) TestReassemblyErrors() {
buf.Write(handshakePayload)
return buf.Bytes()
},
errMsg: "handshake message too large",
errMsg: "cannot read record header",
},
{
name: "truncated continuation record header",
Expand All @@ -625,7 +625,7 @@ func (s *ParseClientHelloFragmentedTestSuite) TestReassemblyErrors() {
buf.WriteByte(3)
return buf.Bytes()
},
errMsg: "cannot read continuation record header",
errMsg: "cannot read record header",
},
{
name: "truncated continuation record payload",
Expand All @@ -641,7 +641,7 @@ func (s *ParseClientHelloFragmentedTestSuite) TestReassemblyErrors() {
require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(100)))
return buf.Bytes()
},
errMsg: "cannot read continuation record payload",
errMsg: "EOF",
},
}

Expand Down
Loading
Loading