diff --git a/mtglib/internal/tls/fake/client_side.go b/mtglib/internal/tls/fake/client_side.go index 737ab8494..66542bc8c 100644 --- a/mtglib/internal/tls/fake/client_side.go +++ b/mtglib/internal/tls/fake/client_side.go @@ -6,14 +6,11 @@ import ( "crypto/sha256" "crypto/subtle" "encoding/binary" - "errors" "fmt" "io" "net" "slices" "time" - - "github.com/9seconds/mtg/v2/mtglib/internal/tls" ) const ( @@ -24,11 +21,6 @@ const ( RandomOffset = 1 + 2 + 2 + 1 + 3 + 2 sniDNSNamesListType = 0 - - // maxContinuationRecords limits the number of continuation TLS records - // that reassembleTLSHandshake will read. This prevents resource exhaustion - // from adversarial fragmentation. - maxContinuationRecords = 10 ) var ( @@ -62,31 +54,17 @@ func ReadClientHello( // 4. New digest should be all 0 except of last 4 bytes // 5. Last 4 bytes are little endian uint32 of UNIX timestamp when // this message was created. - reassembled, err := reassembleTLSHandshake(conn) + clientHelloCopy, handshakeReader, err := parseClientHello(conn) if err != nil { - return nil, fmt.Errorf("cannot reassemble TLS records: %w", err) - } - - handshakeCopyBuf := &bytes.Buffer{} - reader := io.TeeReader(reassembled, handshakeCopyBuf) - - // Skip the TLS record header (validated during reassembly). - // The header still flows through TeeReader into handshakeCopyBuf for HMAC. - if _, err = io.CopyN(io.Discard, reader, tls.SizeHeader); err != nil { - return nil, fmt.Errorf("cannot skip tls header: %w", err) + return nil, fmt.Errorf("cannot read client hello: %w", err) } - reader, err = parseHandshakeHeader(reader) - if err != nil { - return nil, fmt.Errorf("cannot parse handshake header: %w", err) - } - - hello, err := parseHandshake(reader) + hello, err := parseHandshake(handshakeReader) if err != nil { return nil, fmt.Errorf("cannot parse handshake: %w", err) } - sniHostnames, err := parseSNI(reader) + sniHostnames, err := parseSNI(handshakeReader) if err != nil { return nil, fmt.Errorf("cannot parse SNI: %w", err) } @@ -97,10 +75,10 @@ func ReadClientHello( digest := hmac.New(sha256.New, secret) // we write a copy of the handshake with client random all nullified. - digest.Write(handshakeCopyBuf.Next(RandomOffset)) - handshakeCopyBuf.Next(RandomLen) + digest.Write(clientHelloCopy.Next(RandomOffset)) + clientHelloCopy.Next(RandomLen) digest.Write(emptyRandom[:]) - digest.Write(handshakeCopyBuf.Bytes()) + digest.Write(clientHelloCopy.Bytes()) computed := digest.Sum(nil) @@ -122,152 +100,6 @@ func ReadClientHello( return hello, nil } -// reassembleTLSHandshake reads one or more TLS records from conn, -// validates the record type and version, and reassembles fragmented -// handshake payloads into a single TLS record. -// -// Per RFC 5246 Section 6.2.1, handshake messages may be fragmented -// across multiple TLS records. DPI bypass tools like ByeDPI use this -// to evade censorship. -// -// The returned buffer contains the full TLS record (header + payload) -// so that callers can include the header in HMAC computation. -func reassembleTLSHandshake(conn io.Reader) (*bytes.Buffer, error) { - header := [tls.SizeHeader]byte{} - - if _, err := io.ReadFull(conn, header[:]); err != nil { - return nil, fmt.Errorf("cannot read record header: %w", err) - } - - length := int64(binary.BigEndian.Uint16(header[3:])) - payload := &bytes.Buffer{} - - if _, err := io.CopyN(payload, conn, length); err != nil { - return nil, fmt.Errorf("cannot read record payload: %w", err) - } - - if header[0] != tls.TypeHandshake { - return nil, fmt.Errorf("unexpected record type %#x", header[0]) - } - - if header[1] != 3 || header[2] != 1 { - return nil, fmt.Errorf("unexpected protocol version %#x %#x", header[1], header[2]) - } - - // Reassemble fragmented payload. continuationCount caps the total - // number of continuation records across both phases below. - continuationCount := 0 - - // Phase 1: read continuation records until we have at least the - // 4-byte handshake header (type + uint24 length) to determine the - // expected total size. - for ; payload.Len() < 4 && continuationCount < maxContinuationRecords; continuationCount++ { - prevLen := payload.Len() - - if err := readContinuationRecord(conn, payload); err != nil { - payload.Truncate(prevLen) // discard partial data on error - - if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { - break // no more records — let downstream parsing handle what we have - } - - return nil, err - } - } - - // Phase 2: we know the expected handshake size — read remaining - // continuation records until the payload is complete. - if payload.Len() >= 4 { - p := payload.Bytes() - expectedTotal := 4 + (int(p[1])<<16 | int(p[2])<<8 | int(p[3])) - - if expectedTotal > 0xFFFF { - return nil, fmt.Errorf("handshake message too large: %d bytes", expectedTotal) - } - - for ; payload.Len() < expectedTotal && continuationCount < maxContinuationRecords; continuationCount++ { - if err := readContinuationRecord(conn, payload); err != nil { - return nil, err - } - } - - if payload.Len() < expectedTotal { - return nil, fmt.Errorf("cannot reassemble handshake: too many continuation records") - } - - payload.Truncate(expectedTotal) - } - - if payload.Len() > 0xFFFF { - return nil, fmt.Errorf("reassembled payload too large: %d bytes", payload.Len()) - } - - // Reconstruct a single TLS record with the reassembled payload. - result := &bytes.Buffer{} - result.Grow(tls.SizeHeader + payload.Len()) - result.Write(header[:3]) - binary.Write(result, binary.BigEndian, uint16(payload.Len())) //nolint:errcheck // bytes.Buffer.Write never fails - result.Write(payload.Bytes()) - - return result, nil -} - -// readContinuationRecord reads the next TLS record header and appends its -// full payload to dst. It returns an error if the record is not a handshake -// record. -func readContinuationRecord(conn io.Reader, dst *bytes.Buffer) error { - nextHeader := [tls.SizeHeader]byte{} - - if _, err := io.ReadFull(conn, nextHeader[:]); err != nil { - return fmt.Errorf("cannot read continuation record header: %w", err) - } - - if nextHeader[0] != tls.TypeHandshake { - return fmt.Errorf("unexpected continuation record type %#x", nextHeader[0]) - } - - if nextHeader[1] != 3 || nextHeader[2] != 1 { - return fmt.Errorf("unexpected continuation record version %#x %#x", nextHeader[1], nextHeader[2]) - } - - nextLength := int64(binary.BigEndian.Uint16(nextHeader[3:])) - - if nextLength == 0 { - return fmt.Errorf("zero-length continuation record") - } - - if _, err := io.CopyN(dst, conn, nextLength); err != nil { - return fmt.Errorf("cannot read continuation record payload: %w", err) - } - - return nil -} - -func parseHandshakeHeader(r io.Reader) (io.Reader, error) { - // type(1) + size(3 / uint24) - // 01 - handshake message type 0x01 (client hello) - // 00 00 f4 - 0xF4 (244) bytes of client hello data follows - header := [1 + 3]byte{} - - if _, err := io.ReadFull(r, header[:]); err != nil { - return nil, fmt.Errorf("cannot read handshake header: %w", err) - } - - if header[0] != TypeHandshakeClient { - return nil, fmt.Errorf("incorrect handshake type: %#x", header[0]) - } - - // unfortunately there is not uint24 in golang, so we just reust header - header[0] = 0 - - length := int64(binary.BigEndian.Uint32(header[:])) - buf := &bytes.Buffer{} - - _, err := io.CopyN(buf, r, length) - - return buf, err -} - func parseHandshake(r io.Reader) (*ClientHello, error) { // A protocol version of "3,3" (meaning TLS 1.2) is given. header := [2]byte{} diff --git a/mtglib/internal/tls/fake/client_side_test.go b/mtglib/internal/tls/fake/client_side_test.go index ed30eaa0d..bf23409e4 100644 --- a/mtglib/internal/tls/fake/client_side_test.go +++ b/mtglib/internal/tls/fake/client_side_test.go @@ -543,7 +543,7 @@ func (s *ParseClientHelloFragmentedTestSuite) TestReassemblyErrors() { buf.Write(payload[10:]) return buf.Bytes() }, - errMsg: "unexpected continuation record type", + errMsg: "unexpected record type", }, { name: "too many continuation records", @@ -563,7 +563,7 @@ func (s *ParseClientHelloFragmentedTestSuite) TestReassemblyErrors() { } return buf.Bytes() }, - errMsg: "too many continuation records", + errMsg: "too many fragments", }, { name: "zero-length continuation record", @@ -579,7 +579,7 @@ func (s *ParseClientHelloFragmentedTestSuite) TestReassemblyErrors() { require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(0))) return buf.Bytes() }, - errMsg: "zero-length continuation record", + errMsg: "cannot read record header", }, { name: "wrong continuation record version", @@ -596,7 +596,7 @@ func (s *ParseClientHelloFragmentedTestSuite) TestReassemblyErrors() { buf.Write(payload[10:]) return buf.Bytes() }, - errMsg: "unexpected continuation record version", + errMsg: "unexpected protocol version", }, { name: "handshake message too large", @@ -610,7 +610,7 @@ func (s *ParseClientHelloFragmentedTestSuite) TestReassemblyErrors() { buf.Write(handshakePayload) return buf.Bytes() }, - errMsg: "handshake message too large", + errMsg: "cannot read record header", }, { name: "truncated continuation record header", @@ -625,7 +625,7 @@ func (s *ParseClientHelloFragmentedTestSuite) TestReassemblyErrors() { buf.WriteByte(3) return buf.Bytes() }, - errMsg: "cannot read continuation record header", + errMsg: "cannot read record header", }, { name: "truncated continuation record payload", @@ -641,7 +641,7 @@ func (s *ParseClientHelloFragmentedTestSuite) TestReassemblyErrors() { require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(100))) return buf.Bytes() }, - errMsg: "cannot read continuation record payload", + errMsg: "EOF", }, } diff --git a/mtglib/internal/tls/fake/utils.go b/mtglib/internal/tls/fake/utils.go new file mode 100644 index 000000000..5e9271128 --- /dev/null +++ b/mtglib/internal/tls/fake/utils.go @@ -0,0 +1,158 @@ +package fake + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/9seconds/mtg/v2/mtglib/internal/tls" +) + +const ( + maxFragmentsCount = 10 +) + +var ErrTooManyFragments = errors.New("too many fragments") + +// https://datatracker.ietf.org/doc/html/rfc5246#section-6.2.1 +// client hello can be fragmented in a series of packets: +// +// Bytes on the wire: +// +// 16 03 01 00 F8 01 00 00 F4 03 03 [32 bytes random] [session_id] [ciphers] [SNI...] +// ├─────────────┤├──────────────────────────────────────────────────────────────────┤ +// +// TLS record Payload (248 bytes) +// header (5B) +// +// 16 = Handshake +// 03 01 = TLS 1.0 (record layer version) +// 00 F8 = 248 bytes follow +// +// 01 = ClientHello (handshake type) +// 00 00 F4 = 244 bytes of handshake body +// 03 03 = TLS 1.2 (actual protocol version) +// ...rest of ClientHello... +// +// Fragmented record look like: +// +// Record 1: +// +// 16 03 01 00 03 01 00 00 +// ├─────────────┤├──────┤ +// +// TLS header 3 bytes of payload +// +// 16 = Handshake +// 03 01 = TLS 1.0 +// 00 03 = only 3 bytes follow +// +// 01 = ClientHello type +// 00 00 = first 2 bytes of the uint24 length (INCOMPLETE!) +// +// Record 2: +// 16 03 01 00 F5 F4 03 03 [32 bytes random] [session_id] [ciphers] [SNI...] +// ├─────────────┤├────────────────────────────────────────────────────────────┤ +// +// TLS header remaining 245 bytes of payload +// +// 16 = Handshake +// 03 01 = TLS 1.0 +// 00 F5 = 245 bytes follow +// +// F4 = last byte of uint24 length (now complete: 00 00 F4 = 244) +// 03 03 = TLS 1.2 +// ...rest of ClientHello continues... +// +// So it means that there could be a series of handshake packets of different +// lengths. The goal of this function is to concatenate these fragments. +type fragmentedHandshakeReader struct { + r io.Reader + buf bytes.Buffer + readFragments int +} + +func (f *fragmentedHandshakeReader) Read(p []byte) (int, error) { + if n, err := f.buf.Read(p); err == nil { + return n, nil + } + + f.buf.Reset() + + for f.buf.Len() == 0 { + if f.readFragments > maxFragmentsCount { + return 0, ErrTooManyFragments + } + + if err := f.parseNextFragment(); err != nil { + return 0, err + } + + f.readFragments++ + } + + return f.buf.Read(p) +} + +func (f *fragmentedHandshakeReader) parseNextFragment() error { + // record_type(1) + version(2) + size(2) + // 16 - type is 0x16 (handshake record) + // 03 01 - protocol version is "3,1" (also known as TLS 1.0) + // 00 f8 - 0xF8 (248) bytes of handshake message follows + header := [1 + 2 + 2]byte{} + + if _, err := io.ReadFull(f.r, header[:]); err != nil { + return fmt.Errorf("cannot read record header: %w", err) + } + + if header[0] != tls.TypeHandshake { + return fmt.Errorf("unexpected record type %#x", header[0]) + } + + if header[1] != 3 || header[2] != 1 { + return fmt.Errorf("unexpected protocol version %#x %#x", header[1], header[2]) + } + + length := int64(binary.BigEndian.Uint16(header[3:])) + _, err := io.CopyN(&f.buf, f.r, length) + + return err +} + +func parseClientHello(r io.Reader) (*bytes.Buffer, *bytes.Buffer, error) { + r = &fragmentedHandshakeReader{r: r} + header := [1 + 3]byte{} + + if _, err := io.ReadFull(r, header[:]); err != nil { + return nil, nil, fmt.Errorf("cannot read handshake header: %w", err) + } + + if header[0] != TypeHandshakeClient { + return nil, nil, fmt.Errorf("incorrect handshake type: %#x", header[0]) + } + + // unfortunately there is not uint24 in golang, so we just reuse header + header[0] = 0 + length := int64(binary.BigEndian.Uint32(header[:])) + + clientHelloCopy := &bytes.Buffer{} + clientHelloCopy.Write([]byte{tls.TypeHandshake, 3, 1}) + binary.Write( //nolint: errcheck + clientHelloCopy, + binary.BigEndian, + // 1 for handshake type + // 3 for handshake length + uint16(1+3+length), + ) + clientHelloCopy.WriteByte(TypeHandshakeClient) + clientHelloCopy.Write(header[1:]) + + handshakeCopy := &bytes.Buffer{} + writer := io.MultiWriter(clientHelloCopy, handshakeCopy) + + _, err := io.CopyN(writer, r, length) + + return clientHelloCopy, handshakeCopy, err +}