diff --git a/Dockerfile b/Dockerfile index e39ac6f79..394601a18 100644 --- a/Dockerfile +++ b/Dockerfile @@ -33,7 +33,7 @@ RUN go mod download COPY . /app RUN set -x \ - && version="$(git describe --exact-match HEAD 2>/dev/null || git describe --tags --always)" \ + && version="$(git describe --exact-match HEAD 2>/dev/null || git describe --tags --always 2>/dev/null || echo dev)" \ && go build \ -trimpath \ -mod=readonly \ diff --git a/mtglib/internal/tls/fake/client_side.go b/mtglib/internal/tls/fake/client_side.go index 3b7e5a052..737ab8494 100644 --- a/mtglib/internal/tls/fake/client_side.go +++ b/mtglib/internal/tls/fake/client_side.go @@ -6,6 +6,7 @@ import ( "crypto/sha256" "crypto/subtle" "encoding/binary" + "errors" "fmt" "io" "net" @@ -23,6 +24,11 @@ const ( RandomOffset = 1 + 2 + 2 + 1 + 3 + 2 sniDNSNamesListType = 0 + + // maxContinuationRecords limits the number of continuation TLS records + // that reassembleTLSHandshake will read. This prevents resource exhaustion + // from adversarial fragmentation. + maxContinuationRecords = 10 ) var ( @@ -56,12 +62,18 @@ func ReadClientHello( // 4. New digest should be all 0 except of last 4 bytes // 5. Last 4 bytes are little endian uint32 of UNIX timestamp when // this message was created. + reassembled, err := reassembleTLSHandshake(conn) + if err != nil { + return nil, fmt.Errorf("cannot reassemble TLS records: %w", err) + } + handshakeCopyBuf := &bytes.Buffer{} - reader := io.TeeReader(conn, handshakeCopyBuf) + reader := io.TeeReader(reassembled, handshakeCopyBuf) - reader, err := parseTLSHeader(reader) - if err != nil { - return nil, fmt.Errorf("cannot parse tls header: %w", err) + // Skip the TLS record header (validated during reassembly). + // The header still flows through TeeReader into handshakeCopyBuf for HMAC. + if _, err = io.CopyN(io.Discard, reader, tls.SizeHeader); err != nil { + return nil, fmt.Errorf("cannot skip tls header: %w", err) } reader, err = parseHandshakeHeader(reader) @@ -110,17 +122,30 @@ func ReadClientHello( return hello, nil } -func parseTLSHeader(r io.Reader) (io.Reader, error) { - // record_type(1) + version(2) + size(2) - // 16 - type is 0x16 (handshake record) - // 03 01 - protocol version is "3,1" (also known as TLS 1.0) - // 00 f8 - 0xF8 (248) bytes of handshake message follows - header := [1 + 2 + 2]byte{} - - if _, err := io.ReadFull(r, header[:]); err != nil { +// reassembleTLSHandshake reads one or more TLS records from conn, +// validates the record type and version, and reassembles fragmented +// handshake payloads into a single TLS record. +// +// Per RFC 5246 Section 6.2.1, handshake messages may be fragmented +// across multiple TLS records. DPI bypass tools like ByeDPI use this +// to evade censorship. +// +// The returned buffer contains the full TLS record (header + payload) +// so that callers can include the header in HMAC computation. +func reassembleTLSHandshake(conn io.Reader) (*bytes.Buffer, error) { + header := [tls.SizeHeader]byte{} + + if _, err := io.ReadFull(conn, header[:]); err != nil { return nil, fmt.Errorf("cannot read record header: %w", err) } + length := int64(binary.BigEndian.Uint16(header[3:])) + payload := &bytes.Buffer{} + + if _, err := io.CopyN(payload, conn, length); err != nil { + return nil, fmt.Errorf("cannot read record payload: %w", err) + } + if header[0] != tls.TypeHandshake { return nil, fmt.Errorf("unexpected record type %#x", header[0]) } @@ -129,12 +154,93 @@ func parseTLSHeader(r io.Reader) (io.Reader, error) { return nil, fmt.Errorf("unexpected protocol version %#x %#x", header[1], header[2]) } - length := int64(binary.BigEndian.Uint16(header[3:])) - buf := &bytes.Buffer{} + // Reassemble fragmented payload. continuationCount caps the total + // number of continuation records across both phases below. + continuationCount := 0 - _, err := io.CopyN(buf, r, length) + // Phase 1: read continuation records until we have at least the + // 4-byte handshake header (type + uint24 length) to determine the + // expected total size. + for ; payload.Len() < 4 && continuationCount < maxContinuationRecords; continuationCount++ { + prevLen := payload.Len() - return buf, err + if err := readContinuationRecord(conn, payload); err != nil { + payload.Truncate(prevLen) // discard partial data on error + + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + break // no more records — let downstream parsing handle what we have + } + + return nil, err + } + } + + // Phase 2: we know the expected handshake size — read remaining + // continuation records until the payload is complete. + if payload.Len() >= 4 { + p := payload.Bytes() + expectedTotal := 4 + (int(p[1])<<16 | int(p[2])<<8 | int(p[3])) + + if expectedTotal > 0xFFFF { + return nil, fmt.Errorf("handshake message too large: %d bytes", expectedTotal) + } + + for ; payload.Len() < expectedTotal && continuationCount < maxContinuationRecords; continuationCount++ { + if err := readContinuationRecord(conn, payload); err != nil { + return nil, err + } + } + + if payload.Len() < expectedTotal { + return nil, fmt.Errorf("cannot reassemble handshake: too many continuation records") + } + + payload.Truncate(expectedTotal) + } + + if payload.Len() > 0xFFFF { + return nil, fmt.Errorf("reassembled payload too large: %d bytes", payload.Len()) + } + + // Reconstruct a single TLS record with the reassembled payload. + result := &bytes.Buffer{} + result.Grow(tls.SizeHeader + payload.Len()) + result.Write(header[:3]) + binary.Write(result, binary.BigEndian, uint16(payload.Len())) //nolint:errcheck // bytes.Buffer.Write never fails + result.Write(payload.Bytes()) + + return result, nil +} + +// readContinuationRecord reads the next TLS record header and appends its +// full payload to dst. It returns an error if the record is not a handshake +// record. +func readContinuationRecord(conn io.Reader, dst *bytes.Buffer) error { + nextHeader := [tls.SizeHeader]byte{} + + if _, err := io.ReadFull(conn, nextHeader[:]); err != nil { + return fmt.Errorf("cannot read continuation record header: %w", err) + } + + if nextHeader[0] != tls.TypeHandshake { + return fmt.Errorf("unexpected continuation record type %#x", nextHeader[0]) + } + + if nextHeader[1] != 3 || nextHeader[2] != 1 { + return fmt.Errorf("unexpected continuation record version %#x %#x", nextHeader[1], nextHeader[2]) + } + + nextLength := int64(binary.BigEndian.Uint16(nextHeader[3:])) + + if nextLength == 0 { + return fmt.Errorf("zero-length continuation record") + } + + if _, err := io.CopyN(dst, conn, nextLength); err != nil { + return fmt.Errorf("cannot read continuation record payload: %w", err) + } + + return nil } func parseHandshakeHeader(r io.Reader) (io.Reader, error) { diff --git a/mtglib/internal/tls/fake/client_side_test.go b/mtglib/internal/tls/fake/client_side_test.go index 2e66b6c69..ed30eaa0d 100644 --- a/mtglib/internal/tls/fake/client_side_test.go +++ b/mtglib/internal/tls/fake/client_side_test.go @@ -3,8 +3,10 @@ package fake_test import ( "bytes" "encoding/binary" + "encoding/json" "errors" "io" + "os" "testing" "time" @@ -393,3 +395,273 @@ func TestParseClientHelloSNI(t *testing.T) { t.Parallel() suite.Run(t, &ParseClientHelloSNITestSuite{}) } + +// fragmentTLSRecord splits a single TLS record into n TLS records by +// dividing the payload into roughly equal parts. Each part gets its own +// TLS record header with the same record type and version. +func fragmentTLSRecord(t testing.TB, full []byte, n int) []byte { + t.Helper() + + recordType := full[0] + version := full[1:3] + payload := full[tls.SizeHeader:] + + chunkSize := len(payload) / n + result := &bytes.Buffer{} + + for i := 0; i < n; i++ { + start := i * chunkSize + end := start + chunkSize + + if i == n-1 { + end = len(payload) + } + + chunk := payload[start:end] + result.WriteByte(recordType) + result.Write(version) + require.NoError(t, binary.Write(result, binary.BigEndian, uint16(len(chunk)))) + result.Write(chunk) + } + + return result.Bytes() +} + +// splitPayloadAt creates two TLS records from a single record by splitting +// the payload at the given byte position. +func splitPayloadAt(t testing.TB, full []byte, pos int) []byte { + t.Helper() + + payload := full[tls.SizeHeader:] + buf := &bytes.Buffer{} + + buf.WriteByte(tls.TypeHandshake) + buf.Write(full[1:3]) + require.NoError(t, binary.Write(buf, binary.BigEndian, uint16(pos))) + buf.Write(payload[:pos]) + + buf.WriteByte(tls.TypeHandshake) + buf.Write(full[1:3]) + require.NoError(t, binary.Write(buf, binary.BigEndian, uint16(len(payload)-pos))) + buf.Write(payload[pos:]) + + return buf.Bytes() +} + +type ParseClientHelloFragmentedTestSuite struct { + suite.Suite + + secret mtglib.Secret + snapshot *clientHelloSnapshot +} + +func (s *ParseClientHelloFragmentedTestSuite) SetupSuite() { + parsed, err := mtglib.ParseSecret( + "ee367a189aee18fa31c190054efd4a8e9573746f726167652e676f6f676c65617069732e636f6d", + ) + require.NoError(s.T(), err) + + s.secret = parsed + + fileData, err := os.ReadFile("testdata/client-hello-ok-19dfe38384b9884b.json") + require.NoError(s.T(), err) + + s.snapshot = &clientHelloSnapshot{} + require.NoError(s.T(), json.Unmarshal(fileData, s.snapshot)) +} + +func (s *ParseClientHelloFragmentedTestSuite) makeConn(data []byte) *parseClientHelloConnMock { + readBuf := &bytes.Buffer{} + readBuf.Write(data) + + connMock := &parseClientHelloConnMock{ + readBuf: readBuf, + } + + connMock. + On("SetReadDeadline", mock.AnythingOfType("time.Time")). + Twice(). + Return(nil) + + return connMock +} + +func (s *ParseClientHelloFragmentedTestSuite) TestReassemblySuccess() { + full := s.snapshot.GetFull() + + tests := []struct { + name string + data []byte + }{ + {"two equal fragments", fragmentTLSRecord(s.T(), full, 2)}, + {"three equal fragments", fragmentTLSRecord(s.T(), full, 3)}, + {"single byte first fragment", splitPayloadAt(s.T(), full, 1)}, + {"three byte first fragment", splitPayloadAt(s.T(), full, 3)}, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + connMock := s.makeConn(tt.data) + defer connMock.AssertExpectations(s.T()) + + hello, err := fake.ReadClientHello( + connMock, + s.secret.Key[:], + s.secret.Host, + TolerateTime, + ) + s.Require().NoError(err) + + s.Equal(s.snapshot.GetRandom(), hello.Random[:]) + s.Equal(s.snapshot.GetSessionID(), hello.SessionID) + s.Equal(uint16(s.snapshot.CipherSuite), hello.CipherSuite) + }) + } +} + +func (s *ParseClientHelloFragmentedTestSuite) TestReassemblyErrors() { + full := s.snapshot.GetFull() + payload := full[tls.SizeHeader:] + + tests := []struct { + name string + buildData func() []byte + errMsg string + }{ + { + name: "wrong continuation record type", + buildData: func() []byte { + buf := &bytes.Buffer{} + buf.WriteByte(tls.TypeHandshake) + buf.Write(full[1:3]) + require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(10))) + buf.Write(payload[:10]) + // Wrong type: application data instead of handshake + buf.WriteByte(tls.TypeApplicationData) + buf.Write(full[1:3]) + require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(len(payload)-10))) + buf.Write(payload[10:]) + return buf.Bytes() + }, + errMsg: "unexpected continuation record type", + }, + { + name: "too many continuation records", + buildData: func() []byte { + // Handshake header claiming 256 bytes, but we only send 1 byte per continuation + handshakePayload := []byte{0x01, 0x00, 0x01, 0x00} + buf := &bytes.Buffer{} + buf.WriteByte(tls.TypeHandshake) + buf.Write([]byte{3, 1}) + require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(len(handshakePayload)))) + buf.Write(handshakePayload) + for range 11 { + buf.WriteByte(tls.TypeHandshake) + buf.Write([]byte{3, 1}) + require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(1))) + buf.WriteByte(0xAB) + } + return buf.Bytes() + }, + errMsg: "too many continuation records", + }, + { + name: "zero-length continuation record", + buildData: func() []byte { + buf := &bytes.Buffer{} + buf.WriteByte(tls.TypeHandshake) + buf.Write(full[1:3]) + require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(10))) + buf.Write(payload[:10]) + // Valid header but zero-length payload + buf.WriteByte(tls.TypeHandshake) + buf.Write(full[1:3]) + require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(0))) + return buf.Bytes() + }, + errMsg: "zero-length continuation record", + }, + { + name: "wrong continuation record version", + buildData: func() []byte { + buf := &bytes.Buffer{} + buf.WriteByte(tls.TypeHandshake) + buf.Write(full[1:3]) + require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(10))) + buf.Write(payload[:10]) + // Wrong version: 3.3 instead of 3.1 + buf.WriteByte(tls.TypeHandshake) + buf.Write([]byte{3, 3}) + require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(len(payload)-10))) + buf.Write(payload[10:]) + return buf.Bytes() + }, + errMsg: "unexpected continuation record version", + }, + { + name: "handshake message too large", + buildData: func() []byte { + // Handshake header claiming 0x010000 (65536) bytes — exceeds 0xFFFF limit + handshakePayload := []byte{0x01, 0x01, 0x00, 0x00} + buf := &bytes.Buffer{} + buf.WriteByte(tls.TypeHandshake) + buf.Write([]byte{3, 1}) + require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(len(handshakePayload)))) + buf.Write(handshakePayload) + return buf.Bytes() + }, + errMsg: "handshake message too large", + }, + { + name: "truncated continuation record header", + buildData: func() []byte { + buf := &bytes.Buffer{} + buf.WriteByte(tls.TypeHandshake) + buf.Write(full[1:3]) + require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(10))) + buf.Write(payload[:10]) + // Connection ends mid-header (only 2 bytes) + buf.WriteByte(tls.TypeHandshake) + buf.WriteByte(3) + return buf.Bytes() + }, + errMsg: "cannot read continuation record header", + }, + { + name: "truncated continuation record payload", + buildData: func() []byte { + buf := &bytes.Buffer{} + buf.WriteByte(tls.TypeHandshake) + buf.Write(full[1:3]) + require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(10))) + buf.Write(payload[:10]) + // Claims 100 bytes but no payload follows + buf.WriteByte(tls.TypeHandshake) + buf.Write(full[1:3]) + require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(100))) + return buf.Bytes() + }, + errMsg: "cannot read continuation record payload", + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + connMock := s.makeConn(tt.buildData()) + defer connMock.AssertExpectations(s.T()) + + _, err := fake.ReadClientHello( + connMock, + s.secret.Key[:], + s.secret.Host, + TolerateTime, + ) + s.ErrorContains(err, tt.errMsg) + }) + } +} + +func TestParseClientHelloFragmented(t *testing.T) { + t.Parallel() + suite.Run(t, &ParseClientHelloFragmentedTestSuite{}) +}