diff --git a/.github/workflows/skit.yml b/.github/workflows/skit.yml index 4b15ead6..ab599d2c 100644 --- a/.github/workflows/skit.yml +++ b/.github/workflows/skit.yml @@ -36,6 +36,9 @@ jobs: bun install --frozen-lockfile bun run build + - name: Install system dependencies + run: sudo apt-get update && sudo apt-get install -y libvpx-dev + - name: Install Rust toolchain uses: dtolnay/rust-toolchain@master with: @@ -89,6 +92,9 @@ jobs: bun install --frozen-lockfile bun run build + - name: Install system dependencies + run: sudo apt-get update && sudo apt-get install -y libvpx-dev + - name: Install Rust toolchain uses: dtolnay/rust-toolchain@master with: @@ -129,6 +135,9 @@ jobs: bun install --frozen-lockfile bun run build + - name: Install system dependencies + run: sudo apt-get update && sudo apt-get install -y libvpx-dev + - name: Install Rust toolchain uses: dtolnay/rust-toolchain@master with: diff --git a/AGENTS.md b/AGENTS.md index 08e1452c..5472c084 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -27,6 +27,111 @@ Agent-assisted contributions are welcome, but should be **supervised** and **rev - Follow `CONTRIBUTING.md` (DCO sign-off, Conventional Commits, SPDX headers where applicable). - **Linting discipline**: Do not blindly suppress lint warnings or errors with ignore/exception rules. Instead, consider refactoring or improving the code to address the underlying issue. If an exception is truly necessary, it **must** include a comment explaining the rationale. +## Running E2E tests + +End-to-end tests live in `e2e/` and use Playwright (Chromium, headless). + +1. **Build the UI** and **start the server** in one terminal: + + ```bash + just build-ui && SK_SERVER__MOQ_GATEWAY_URL=http://127.0.0.1:4545/moq SK_SERVER__ADDRESS=127.0.0.1:4545 just skit + ``` + +2. **Run the tests** in a second terminal: + + ```bash + just e2e-external http://localhost:4545 + ``` + +### Headless-browser pitfalls + +- Playwright runs headless Chromium with a default 1280×720 viewport. 
+ Elements rendered below the fold are **not visible** to + `IntersectionObserver`. If a test relies on an element being observed + (e.g. the `<canvas>` used by the MoQ video renderer), scroll it into + view first: + + ```ts + const canvas = page.locator('canvas'); + await canvas.scrollIntoViewIfNeeded(); + ``` + +- The `@moq/watch` `Video.Renderer` enables the `Video.Decoder` (and + therefore the `video/data` MoQ subscription) **only** when the canvas is + intersecting. Forgetting to scroll will result in a permanently black + canvas. + +## Render performance profiling + +StreamKit ships a two-layer profiling infrastructure for detecting render +regressions — particularly **cascade re-renders** where a slider interaction +(opacity, rotation) triggers expensive re-renders in unrelated memoized +components (`UnifiedLayerList`, `OpacityControl`, `RotationControl`, etc.). + +### When to use this + +- **After touching compositor hooks or components** (`useCompositorLayers`, + `CompositorNode`, or any `React.memo`'d sub-component): run the perf tests + to verify you haven't broken memoization barriers. +- **When optimising render performance**: use the baseline comparison to + measure before/after render counts and durations. +- **In CI**: Layer 1 tests run automatically via `just perf-ui` and will fail + if render counts regress beyond the 2σ threshold stored in the baseline. + +### Layer 1 — Component-level regression tests (Vitest) + +Fast, deterministic tests that measure hook/component render counts in +happy-dom. No browser required. 
+ +```bash +just perf-ui # runs all *.perf.test.* files +``` + +Key files: + +| File | Purpose | +|------|---------| +| `ui/src/test/perf/measure.ts` | `measureRenders()` (components) and `measureHookRenders()` (hooks) | +| `ui/src/test/perf/compare.ts` | Baseline read/write, 2σ comparison, report formatting | +| `ui/src/hooks/useCompositorLayers.render-perf.test.ts` | Cascade re-render regression tests | +| `perf-baselines.json` (repo root) | Baseline snapshot — committed to track regressions over time | + +**Cascade detection pattern**: the render-perf tests simulate rapid slider +drags (20 ticks of opacity/rotation) and assert that total render count stays +within a budget (currently ≤ 30). If callback references become unstable +(e.g. `layers` array in deps instead of `selectedLayerKind`), React.memo +barriers break and the render count will blow past the budget, failing the +test. + +### Layer 2 — Interaction-level profiling (Playwright + React.Profiler) + +Real-browser profiling for dev builds. Components wrapped with +`React.Profiler` push metrics to `window.__PERF_DATA__` which Playwright +tests can read via `page.evaluate()`. + +```bash +just perf-e2e # requires: just skit + just ui (dev server at :3045) +``` + +Key files: + +| File | Purpose | +|------|---------| +| `ui/src/perf/profiler.ts` | Dev-only `PerfProfiler` wrapper + `window.__PERF_DATA__` store | +| `e2e/tests/perf-helpers.ts` | `capturePerfData()` / `resetPerfData()` Playwright utilities | +| `e2e/tests/compositor-perf.spec.ts` | E2E test: creates PiP session, drags all sliders, asserts render budget | + +Use Layer 2 when you need real paint/layout timing or want to profile +interactions end-to-end with actual browser rendering. + +### Updating the baseline + +Run `just perf-ui` — the last test in the render-perf suite writes a fresh +`perf-baselines.json` (gated behind `UPDATE_PERF_BASELINE=1`, which the +`test:perf` script sets automatically). 
Regular `just test-ui` runs compare +against the baseline but never overwrite it. Commit the updated baseline +alongside your changes so future runs compare against the new numbers. + ## Docker notes - Official images are built from `Dockerfile` (CPU) and `Dockerfile.gpu` (GPU-tagged) via `.github/workflows/docker.yml`. diff --git a/Cargo.lock b/Cargo.lock index 5cc7091f..4f24fab1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -159,6 +159,12 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7d902e3d592a523def97af8f317b08ce16b7ab854c1985a0c671e6f15cebc236" +[[package]] +name = "arrayref" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" + [[package]] name = "arrayvec" version = "0.7.6" @@ -204,18 +210,6 @@ dependencies = [ "syn", ] -[[package]] -name = "async-channel" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "924ed96dd52d1b75e9c1a3e6275715fd320f5f9439fb5a4a11fa51f4221158d2" -dependencies = [ - "concurrent-queue", - "event-listener-strategy", - "futures-core", - "pin-project-lite", -] - [[package]] name = "async-lock" version = "3.4.2" @@ -504,6 +498,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +[[package]] +name = "byteorder-lite" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" + [[package]] name = "bytes" version = "1.11.1" @@ -1244,6 +1244,15 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" +[[package]] +name = "env-libvpx-sys" +version = "5.1.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "26ecdc636a02003406cc821aa9d703c888a966a3fd9bbdae9f7cf27d71720147" +dependencies = [ + "pkg-config", +] + [[package]] name = "equator" version = "0.4.2" @@ -1293,7 +1302,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e13b66accf52311f30a0db42147dadea9850cb48cd070028831ae5f5d4b856ab" dependencies = [ "concurrent-queue", - "parking", "pin-project-lite", ] @@ -1348,6 +1356,15 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "fdeflate" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e6853b52649d4ac5c0bd02320cddc5ba956bdb407c4b75a2c6b75bf51500f8c" +dependencies = [ + "simd-adler32", +] + [[package]] name = "figment" version = "0.10.19" @@ -1425,6 +1442,16 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" +[[package]] +name = "fontdue" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e57e16b3fe8ff4364c0661fdaac543fb38b29ea9bc9c2f45612d90adf931d2b" +dependencies = [ + "hashbrown 0.15.5", + "ttf-parser", +] + [[package]] name = "foreign-types" version = "0.3.2" @@ -1654,9 +1681,9 @@ dependencies = [ [[package]] name = "hang" -version = "0.13.0" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c05205d948a355a7b260b82a652d64bc502a4ae42d523e25a6931220fdd5f0a0" +checksum = "10631856b75596bcfca1c82c654fc006d1814d6eded004f6b170bcbddfe4bc97" dependencies = [ "buf-list", "bytes", @@ -1687,6 +1714,8 @@ version = "0.15.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" dependencies = [ + "allocator-api2", + "equivalent", "foldhash 0.1.5", "serde", ] @@ -2047,6 +2076,21 @@ dependencies = [ "version_check", ] 
+[[package]] +name = "image" +version = "0.25.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6506c6c10786659413faa717ceebcb8f70731c0a60cbae39795fdf114519c1a" +dependencies = [ + "bytemuck", + "byteorder-lite", + "moxcms", + "num-traits", + "png", + "zune-core", + "zune-jpeg", +] + [[package]] name = "indexmap" version = "1.9.3" @@ -2469,11 +2513,10 @@ dependencies = [ [[package]] name = "moq-lite" -version = "0.13.0" +version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4462d2c8aabd99cd8be3f6704ff92b37446cdd5c576997cae5d74fda41705834" +checksum = "22e2b5cbf085a08eca461384c1242a437c4505f41a315e50eec10d538bb910b1" dependencies = [ - "async-channel", "bytes", "futures", "hex", @@ -2489,9 +2532,9 @@ dependencies = [ [[package]] name = "moq-native" -version = "0.12.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d5a0aa1d9d584fd867f8e107d7e487cccc9eef47f27648baf2ebc3731c75706" +checksum = "5754d77ace0048c8349d76a2f12f25a093d6f9358e7e24cf86f84202c0d2a547" dependencies = [ "anyhow", "clap", @@ -2516,27 +2559,18 @@ dependencies = [ "tracing", "tracing-subscriber", "url", - "web-transport-quinn 0.10.2", + "web-transport-quinn", "web-transport-ws", ] [[package]] -name = "moq-transport" -version = "0.12.2" +name = "moxcms" +version = "0.7.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50d2f896962af0634a5b71f274a07590fbbe21f30c89d986066479078644b477" +checksum = "ac9557c559cd6fc9867e122e20d2cbefc9ca29d80d027a8e39310920ed2f0a97" dependencies = [ - "bytes", - "futures", - "log", - "paste", - "serde", - "serde_json", - "serde_with", - "thiserror 1.0.69", - "tokio", - "uuid", - "web-transport", + "num-traits", + "pxfm", ] [[package]] @@ -2958,12 +2992,6 @@ dependencies = [ "audiopus_sys", ] -[[package]] -name = "parking" -version = "2.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" - [[package]] name = "parking_lot" version = "0.12.5" @@ -3072,6 +3100,19 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "png" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60769b8b31b2a9f263dae2776c37b1b28ae246943cf719eb6946a1db05128a61" +dependencies = [ + "bitflags 2.10.0", + "crc32fast", + "fdeflate", + "flate2", + "miniz_oxide", +] + [[package]] name = "postcard" version = "1.1.3" @@ -3329,6 +3370,15 @@ dependencies = [ "syn", ] +[[package]] +name = "pxfm" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7186d3822593aa4393561d186d1393b3923e9d6163d3fbfd6e825e3e6cf3e6a8" +dependencies = [ + "num-traits", +] + [[package]] name = "quick-xml" version = "0.26.0" @@ -4474,7 +4524,7 @@ checksum = "9091b6114800a5f2141aee1d1b9d6ca3592ac062dc5decb3764ec5895a47b4eb" [[package]] name = "streamkit-api" -version = "0.1.1" +version = "0.2.0" dependencies = [ "indexmap 2.13.0", "serde", @@ -4512,7 +4562,7 @@ dependencies = [ [[package]] name = "streamkit-core" -version = "0.1.0" +version = "0.2.0" dependencies = [ "async-trait", "base64 0.22.1", @@ -4530,10 +4580,11 @@ dependencies = [ [[package]] name = "streamkit-engine" -version = "0.1.0" +version = "0.2.0" dependencies = [ "bytes", "futures", + "indexmap 2.13.0", "opentelemetry", "serde", "serde-saphyr 0.0.15", @@ -4550,21 +4601,25 @@ dependencies = [ [[package]] name = "streamkit-nodes" -version = "0.1.0" +version = "0.2.0" dependencies = [ "async-trait", "audioadapter-buffers", "axum", + "base64 0.22.1", "bytes", + "env-libvpx-sys", + "fontdue", "futures", "futures-util", "hang", + "image", "moq-lite", "moq-native", - "moq-transport", "ogg", "opentelemetry", "opus", + "rayon", "reqwest 0.13.2", "rquickjs", "rubato", 
@@ -4572,9 +4627,11 @@ dependencies = [ "serde", "serde-saphyr 0.0.15", "serde_json", + "smallvec 1.15.1", "streamkit-core", "symphonia", "tempfile", + "tiny-skia", "tokio", "tokio-util 0.7.18", "tower", @@ -4587,7 +4644,7 @@ dependencies = [ [[package]] name = "streamkit-plugin-native" -version = "0.1.0" +version = "0.2.0" dependencies = [ "anyhow", "async-trait", @@ -4603,7 +4660,7 @@ dependencies = [ [[package]] name = "streamkit-plugin-sdk-native" -version = "0.1.0" +version = "0.2.0" dependencies = [ "async-trait", "bytes", @@ -4622,7 +4679,7 @@ dependencies = [ [[package]] name = "streamkit-plugin-wasm" -version = "0.1.0" +version = "0.2.0" dependencies = [ "anyhow", "async-trait", @@ -4720,6 +4777,12 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" +[[package]] +name = "strict-num" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6637bab7722d379c8b41ba849228d680cc12d0a45ba1fa2b48f2a30577a06731" + [[package]] name = "strsim" version = "0.11.1" @@ -5102,6 +5165,32 @@ dependencies = [ "time-core", ] +[[package]] +name = "tiny-skia" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47ffee5eaaf5527f630fb0e356b90ebdec84d5d18d937c5e440350f88c5a91ea" +dependencies = [ + "arrayref", + "arrayvec", + "bytemuck", + "cfg-if", + "log", + "png", + "tiny-skia-path", +] + +[[package]] +name = "tiny-skia-path" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edca365c3faccca67d06593c5980fa6c57687de727a03131735bb85f01fdeeb9" +dependencies = [ + "arrayref", + "bytemuck", + "strict-num", +] + [[package]] name = "tinystr" version = "0.8.2" @@ -5589,6 +5678,12 @@ dependencies = [ "termcolor", ] +[[package]] +name = "ttf-parser" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"2c591d83f69777866b9126b24c6dd9a18351f177e49d625920d19f989fd31cf8" + [[package]] name = "tungstenite" version = "0.24.0" @@ -6287,9 +6382,9 @@ dependencies = [ [[package]] name = "web-async" -version = "0.1.1" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6b2260b739b0e95cf9b78f22a64704af7ed9760ea12baa3745b4b97899dc89a" +checksum = "f5414b65d9a5094649bb99987bb74db71febfdfa3677b7954a0a05c99d0424e8" dependencies = [ "tokio", "tracing", @@ -6316,43 +6411,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "web-transport" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4703a5ad424f8eca7860903b94f6ed747cf58bebba3081ede78e84493a12440c" -dependencies = [ - "bytes", - "thiserror 1.0.69", - "web-transport-quinn 0.3.4", - "web-transport-wasm", -] - -[[package]] -name = "web-transport-proto" -version = "0.2.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "974fa1e325e6cc5327de8887f189a441fcff4f8eedcd31ec87f0ef0cc5283fbc" -dependencies = [ - "bytes", - "http", - "thiserror 2.0.18", - "url", -] - -[[package]] -name = "web-transport-proto" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "660175a6d1643adb93b71c4f853d4f20f0fce47f53ae579afe9f7711fe84870d" -dependencies = [ - "bytes", - "http", - "thiserror 2.0.18", - "tokio", - "url", -] - [[package]] name = "web-transport-proto" version = "0.4.0" @@ -6368,28 +6426,24 @@ dependencies = [ ] [[package]] -name = "web-transport-quinn" -version = "0.3.4" +name = "web-transport-proto" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3020b51cda10472a365e42d9a701916d4f04d74cc743de08246ef6a421c2d137" +checksum = "5afe275c02f899650c5497b946552fcc04f3f378bd2c3bc1e8005ff915772b97" dependencies = [ "bytes", - "futures", "http", - "log", - "quinn", - "quinn-proto", - "thiserror 1.0.69", + "sfv", + "thiserror 
2.0.18", "tokio", "url", - "web-transport-proto 0.2.8", ] [[package]] name = "web-transport-quinn" -version = "0.10.2" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1f44b4e68a3e7adb790793e24ec8b5923a610a8c2df1d6cd58849f9e4759d04" +checksum = "6d356bedff779480f8d88d94bf50a2eb8dedabb84414442f73c16dfee9db55b8" dependencies = [ "bytes", "futures", @@ -6401,32 +6455,19 @@ dependencies = [ "tokio", "tracing", "url", - "web-transport-proto 0.3.1", + "web-transport-proto 0.5.4", "web-transport-trait", ] [[package]] name = "web-transport-trait" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2615c30ed29953bb3727391850279a25c948c0b7a4ed2343d3a78e1d3cce2f7c" +checksum = "802d6aa508f2c63c9050ceabc17265bbf90ed4d6f4e4357e987583883628e79c" dependencies = [ "bytes", ] -[[package]] -name = "web-transport-wasm" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66e8f572ad133af04a5aa4a207d48d3f6a2f1f3006aa1b8f0d774d28c085d699" -dependencies = [ - "bytes", - "js-sys", - "wasm-bindgen", - "wasm-bindgen-futures", - "web-sys", -] - [[package]] name = "web-transport-ws" version = "0.2.4" @@ -7232,3 +7273,18 @@ dependencies = [ "cc", "pkg-config", ] + +[[package]] +name = "zune-core" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb8a0807f7c01457d0379ba880ba6322660448ddebc890ce29bb64da71fb40f9" + +[[package]] +name = "zune-jpeg" +version = "0.5.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "410e9ecef634c709e3831c2cfdb8d9c32164fae1c67496d5b68fff728eec37fe" +dependencies = [ + "zune-core", +] diff --git a/Cargo.toml b/Cargo.toml index 499b7fc4..098fe3d6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,16 +15,16 @@ members = [ exclude = ["examples", "tests", "plugins"] [workspace.dependencies] -streamkit-core = { version = "0.1", path = 
"crates/core" } -streamkit-nodes = { version = "0.1", path = "crates/nodes" } -streamkit-engine = { version = "0.1", path = "crates/engine" } +streamkit-core = { version = "0.2", path = "crates/core" } +streamkit-nodes = { version = "0.2", path = "crates/nodes" } +streamkit-engine = { version = "0.2", path = "crates/engine" } streamkit-server = { version = "0.2", path = "apps/skit" } streamkit-client = { version = "0.1", path = "apps/skit-cli" } -streamkit-api = { version = "0.1", path = "crates/api" } -streamkit-plugin-wasm = { version = "0.1", path = "crates/plugin-wasm" } -streamkit-plugin-native = { version = "0.1", path = "crates/plugin-native" } +streamkit-api = { version = "0.2", path = "crates/api" } +streamkit-plugin-wasm = { version = "0.2", path = "crates/plugin-wasm" } +streamkit-plugin-native = { version = "0.2", path = "crates/plugin-native" } streamkit-plugin-sdk-wasm = { version = "0.1", path = "sdks/plugin-sdk/wasm/rust" } -streamkit-plugin-sdk-native = { version = "0.1", path = "sdks/plugin-sdk/native" } +streamkit-plugin-sdk-native = { version = "0.2", path = "sdks/plugin-sdk/native" } tracing = "0.1.44" @@ -45,10 +45,17 @@ indexmap = { version = "2.13", features = ["serde"] } opentelemetry = "0.31.0" -# Profile settings for better profiling experience +# Profile settings [profile.release] debug = 1 # Include line tables for better stack traces in profiling +# Maximally-optimised release build — use `cargo build --profile release-lto`. +# CI sticks with the default `release` profile for faster builds. +[profile.release-lto] +inherits = "release" +lto = "fat" # Full cross-crate LTO — eliminates core::ub_checks precondition overhead observed in profiling (vs. 
"thin") +codegen-units = 1 # Maximum LLVM optimisation visibility + [profile.dev] debug = 2 # Full debug info for development profiling diff --git a/LICENSES/Bitstream-Vera.txt b/LICENSES/Bitstream-Vera.txt new file mode 100644 index 00000000..f353aa2d --- /dev/null +++ b/LICENSES/Bitstream-Vera.txt @@ -0,0 +1,15 @@ +Copyright Copyright (c) 2003 by Bitstream, Inc. All Rights Reserved. Bitstream Vera is a trademark of Bitstream, Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy of the fonts accompanying this license ("Fonts") and associated documentation files (the "Font Software"), to reproduce and distribute the Font Software, including without limitation the rights to use, copy, merge, publish, distribute, and/or sell copies of the Font Software, and to permit persons to whom the Font Software is furnished to do so, subject to the following conditions: + +The above copyright and trademark notices and this permission notice shall be included in all copies of one or more of the Font Software typefaces. + +The Font Software may be modified, altered, or added to, and in particular the designs of glyphs or characters in the Fonts may be modified and additional glyphs or characters may be added to the Fonts, only if the fonts are renamed to names not containing either the words "Bitstream" or the word "Vera". + +This License becomes null and void to the extent applicable to Fonts or Font Software that has been modified and is distributed under the "Bitstream Vera" names. + +The Font Software may be sold as part of a larger software package but no copy of one or more of the Font Software typefaces may be sold by itself. + +THE FONT SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF COPYRIGHT, PATENT, TRADEMARK, OR OTHER RIGHT. 
IN NO EVENT SHALL BITSTREAM OR THE GNOME FOUNDATION BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, INCLUDING ANY GENERAL, SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF THE USE OR INABILITY TO USE THE FONT SOFTWARE OR FROM OTHER DEALINGS IN THE FONT SOFTWARE. + +Except as contained in this notice, the names of Gnome, the Gnome Foundation, and Bitstream Inc., shall not be used in advertising or otherwise to promote the sale, use or other dealings in this Font Software without prior written authorization from the Gnome Foundation or Bitstream Inc., respectively. For further information, contact: fonts at gnome dot org. diff --git a/REUSE.toml b/REUSE.toml index 081d66ca..f69cffc6 100644 --- a/REUSE.toml +++ b/REUSE.toml @@ -69,6 +69,19 @@ precedence = "aggregate" SPDX-FileCopyrightText = "© 2025 StreamKit Contributors" SPDX-License-Identifier = "MPL-2.0" +# Bundled DejaVu fonts (Bitstream Vera / DejaVu permissive license) +[[annotations]] +path = "assets/fonts/*.ttf" +precedence = "override" +SPDX-FileCopyrightText = "© 2003 Bitstream, Inc. All Rights Reserved. DejaVu changes are in public domain." +SPDX-License-Identifier = "Bitstream-Vera" + +[[annotations]] +path = "assets/fonts/LICENSE-DejaVu.txt" +precedence = "override" +SPDX-FileCopyrightText = "© 2003 Bitstream, Inc. All Rights Reserved. DejaVu changes are in public domain." +SPDX-License-Identifier = "Bitstream-Vera" + # Lock files (auto-generated, licensed as MPL-2.0 for simplicity) [[annotations]] path = [ @@ -82,6 +95,20 @@ precedence = "aggregate" SPDX-FileCopyrightText = "© 2025 StreamKit Contributors" SPDX-License-Identifier = "MPL-2.0" +# Bun dependency patches. These are unified-diff files generated by `bun patch`; +# inline SPDX headers would corrupt the patch format. Context/removed lines are +# verbatim from the upstream package (see the package's own license declared in +# its package.json). 
Added lines are StreamKit contributions under MPL-2.0. +# Annotated with the upstream @moq/hang license for the derived content. +[[annotations]] +path = "ui/patches/**" +precedence = "aggregate" +SPDX-FileCopyrightText = [ + "© 2025 StreamKit Contributors", + "© Luke Curley and moq-dev contributors", +] +SPDX-License-Identifier = "MIT OR Apache-2.0" + # Build artifacts and external dependencies (excluded from REUSE checks) [[annotations]] path = [ diff --git a/ROADMAP.md b/ROADMAP.md index fc41d3ee..8e8732bb 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -44,17 +44,17 @@ These are in place today and will be iterated on (not “added from scratch”): ### Security & Auth (P0) -- **Built-in authentication (JWT)** — First-class authn/authz for **HTTP + WebSocket control + WebTransport/MoQ** - - Local dev: **no auth on loopback** by default - - Real deployments: require auth when binding non-loopback (secure-by-default) - - StreamKit-managed keys by default (auto-generate, store securely, and support rotation) - - **Token issuance flow** for MoQ gateways (so users don’t need external tooling), compatible with the MoQ ecosystem token shape (root-scoped + publish/subscribe permissions) - - UX helpers (UI/CLI) for “copy/paste” publisher/watch URLs with tokens embedded where required by WebTransport today - - **No secret logging**, especially `?jwt=`-style tokens used by WebTransport today +- ~~**Built-in authentication (JWT)** — First-class authn/authz for **HTTP + WebSocket control + WebTransport/MoQ**~~ + - ~~Local dev: **no auth on loopback** by default~~ + - ~~Real deployments: require auth when binding non-loopback (secure-by-default)~~ + - ~~StreamKit-managed keys by default (auto-generate, store securely, and support rotation)~~ + - ~~**Token issuance flow** for MoQ gateways (so users don’t need external tooling), compatible with the MoQ ecosystem token shape (root-scoped + publish/subscribe permissions)~~ + - ~~UX helpers (UI/CLI) for “copy/paste” publisher/watch URLs 
with tokens embedded where required by WebTransport today~~ + - ~~**No secret logging**, especially `?jwt=`-style tokens used by WebTransport today~~ ### Timing & A/V Sync (P0) -- **Timing contract** — Define canonical semantics for packet timing (`timestamp_us`, `duration_us`) and how nodes/engines must preserve/transform it +- ~~**Timing contract** — Define canonical semantics for packet timing (`timestamp_us`, `duration_us`) and how nodes/engines must preserve/transform it~~ - **A/V sync** — Jitter/drift strategy, drop/late-frame policy, and regression tests (dynamic pipelines) - **Hang/MoQ alignment** — Clear mapping between StreamKit timing metadata and Hang/MoQ timestamps/groups @@ -83,7 +83,7 @@ These are in place today and will be iterated on (not “added from scratch”): ### Capabilities (use-case driven) - **VAD streaming mode** — Zero-latency audio passthrough with per-frame voice activity metadata, enabling downstream nodes to make real-time decisions without buffering delays -- **Multi-input HTTP oneshot** — Accept multiple input files in a single batch request (e.g., multiple audio tracks for mixing, or audio + subtitles for muxing) +- ~~**Multi-input HTTP oneshot** — Accept multiple input files in a single batch request (e.g., multiple audio tracks for mixing, or audio + subtitles for muxing)~~ - **S3 sink node** — Write pipeline output directly to S3-compatible storage - **RTMP input node** — Ingest live streams from OBS, encoders, and other RTMP sources diff --git a/apps/skit/Cargo.toml b/apps/skit/Cargo.toml index 67b4a14b..c2a77b83 100644 --- a/apps/skit/Cargo.toml +++ b/apps/skit/Cargo.toml @@ -115,7 +115,7 @@ jemalloc_pprof = { version = "0.8", features = ["symbolize"], optional = true } dhat = { version = "0.3", optional = true } # MoQ support (optional) -moq-native = { version = "0.12.1", optional = true } +moq-native = { version = "0.13.2", optional = true } async-trait = { workspace = true } # For glob pattern matching in permissions @@ 
-130,7 +130,7 @@ getrandom = "0.3" aws-lc-rs = "1" # For MoQ auth path matching (optional, with moq feature) -moq-lite = { version = "0.13.0", optional = true } +moq-lite = { version = "0.15.0", optional = true } blake2 = "0.10.6" [features] diff --git a/apps/skit/src/samples.rs b/apps/skit/src/samples.rs index 2b7f9010..dcf98bb6 100644 --- a/apps/skit/src/samples.rs +++ b/apps/skit/src/samples.rs @@ -192,7 +192,7 @@ fn parse_pipeline_metadata( yaml: &str, path: &std::path::Path, ) -> (Option, Option, streamkit_api::EngineMode) { - serde_saphyr::from_str::(yaml).map_or_else( + streamkit_api::yaml::parse_yaml(yaml).map_or_else( |e| { warn!("Failed to parse pipeline metadata from {}: {}", path.display(), e); (None, None, streamkit_api::EngineMode::default()) diff --git a/apps/skit/src/server.rs b/apps/skit/src/server.rs index f914046e..6e621241 100644 --- a/apps/skit/src/server.rs +++ b/apps/skit/src/server.rs @@ -1621,8 +1621,8 @@ async fn create_session_handler( } // Parse and compile the YAML pipeline - let user_pipeline: UserPipeline = serde_saphyr::from_str(&req.yaml) - .map_err(|e| (StatusCode::BAD_REQUEST, format!("Invalid YAML: {e}")))?; + let user_pipeline: UserPipeline = + streamkit_api::yaml::parse_yaml(&req.yaml).map_err(|e| (StatusCode::BAD_REQUEST, e))?; let engine_pipeline = compile(user_pipeline) .map_err(|e| (StatusCode::BAD_REQUEST, format!("Invalid pipeline: {e}")))?; @@ -1868,13 +1868,10 @@ async fn destroy_session_handler( }; let destroyed_id = session.id.clone(); - if let Err(e) = session.shutdown_and_wait().await { - warn!(session_id = %destroyed_id, error = %e, "Error during engine shutdown"); - } - - info!(session_id = %destroyed_id, "Session destroyed successfully via HTTP"); - // Broadcast event to all WebSocket clients + // Broadcast event to all WebSocket clients BEFORE starting shutdown + // so clients are notified immediately. The session has already been + // removed from the manager so ListSessions will no longer include it. 
let event = ApiEvent { message_type: MessageType::Event, correlation_id: None, @@ -1884,6 +1881,17 @@ async fn destroy_session_handler( error!("Failed to broadcast SessionDestroyed event: {}", e); } + // Run engine shutdown in a background task so the HTTP response + // returns immediately (shutdown_and_wait has a 10-second timeout). + let shutdown_id = destroyed_id.clone(); + tokio::spawn(async move { + if let Err(e) = session.shutdown_and_wait().await { + warn!(session_id = %shutdown_id, error = %e, "Error during engine shutdown"); + } else { + info!(session_id = %shutdown_id, "Session destroyed successfully via HTTP"); + } + }); + (StatusCode::OK, Json(serde_json::json!({ "session_id": destroyed_id }))).into_response() } @@ -1970,7 +1978,9 @@ async fn parse_config_field( .bytes() .await .map_err(|e| AppError::BadRequest(format!("Failed to read config field: {e}")))?; - serde_saphyr::from_slice(&config_bytes).map_err(Into::into) + let yaml_str = std::str::from_utf8(&config_bytes) + .map_err(|e| AppError::BadRequest(format!("Config is not valid UTF-8: {e}")))?; + streamkit_api::yaml::parse_yaml(yaml_str).map_err(AppError::BadRequest) } /// Build http_input bindings from the pipeline definition. @@ -2288,6 +2298,11 @@ async fn route_multipart_fields( /// Validate that the pipeline has the required nodes for oneshot processing. /// Returns (has_http_input, has_file_read, has_http_output) for logging purposes. +/// +/// Pipelines must have `streamkit::http_output`. For input, they must have at least one of: +/// - `streamkit::http_input` (HTTP streaming mode) +/// - `core::file_reader` (file-based mode) +/// - Neither (generator mode — the pipeline produces its own data, e.g. 
video::colorbars) fn validate_pipeline_nodes(pipeline_def: &Pipeline) -> Result<(bool, bool, bool), AppError> { let has_http_input = pipeline_def.nodes.values().any(|node| node.kind == "streamkit::http_input"); @@ -2295,13 +2310,6 @@ fn validate_pipeline_nodes(pipeline_def: &Pipeline) -> Result<(bool, bool, bool) pipeline_def.nodes.values().any(|node| node.kind == "streamkit::http_output"); let has_file_read = pipeline_def.nodes.values().any(|node| node.kind == "core::file_reader"); - if !has_http_input && !has_file_read { - return Err(AppError::BadRequest( - "Pipeline must contain at least one 'streamkit::http_input' or 'core::file_reader' node for oneshot processing" - .to_string(), - )); - } - if !has_http_output { return Err(AppError::BadRequest( "Pipeline must contain one 'streamkit::http_output' node for oneshot processing" @@ -2592,7 +2600,7 @@ async fn process_oneshot_pipeline_handler( tracing::info!( "Pipeline validation passed: mode={}, has_http_input={}, has_file_read={}, has_http_output={}", - if has_http_input { "http-streaming" } else { "file-based" }, + if has_http_input { "http-streaming" } else if has_file_read { "file-based" } else { "generator" }, has_http_input, has_file_read, has_http_output @@ -3282,7 +3290,7 @@ fn start_moq_webtransport_acceptor( match validate_moq_auth(&auth_state, &path, jwt_param).await { Ok(ctx) => Some(ctx), Err(status) => { - let _ = request.reject(status).await; + let _ = request.close(status.as_u16()).await; return; }, } diff --git a/apps/skit/src/session.rs b/apps/skit/src/session.rs index 680d8331..ba94ec8c 100644 --- a/apps/skit/src/session.rs +++ b/apps/skit/src/session.rs @@ -262,6 +262,35 @@ impl Session { ); }); + // Subscribe to view data updates from the engine + let mut view_data_rx = engine_handle + .subscribe_view_data() + .await + .map_err(|e| format!("Failed to subscribe to view data updates: {e}"))?; + + // Spawn task to forward view data updates to WebSocket clients + let session_id_for_view_data = 
session_id.clone(); + let event_tx_for_view_data = event_tx.clone(); + tokio::spawn(async move { + while let Some(update) = view_data_rx.recv().await { + let event = ApiEvent { + message_type: MessageType::Event, + correlation_id: None, + payload: EventPayload::NodeViewDataUpdated { + session_id: session_id_for_view_data.clone(), + node_id: update.node_id, + data: update.data, + timestamp: system_time_to_rfc3339(update.timestamp), + }, + }; + let _ = event_tx_for_view_data.send(event); + } + tracing::debug!( + session_id = %session_id_for_view_data, + "View data forwarding task ended" + ); + }); + // Subscribe to telemetry events from the engine let mut telemetry_rx = engine_handle .subscribe_telemetry() diff --git a/apps/skit/src/websocket.rs b/apps/skit/src/websocket.rs index fd72ba88..a18d4966 100644 --- a/apps/skit/src/websocket.rs +++ b/apps/skit/src/websocket.rs @@ -274,7 +274,8 @@ pub async fn handle_websocket( | EventPayload::NodeRemoved { session_id, .. } | EventPayload::ConnectionAdded { session_id, .. } | EventPayload::ConnectionRemoved { session_id, .. } - | EventPayload::NodeTelemetry { session_id, .. } => { + | EventPayload::NodeTelemetry { session_id, .. } + | EventPayload::NodeViewDataUpdated { session_id, .. } => { visible_session_ids.contains(session_id) } } diff --git a/apps/skit/src/websocket_handlers.rs b/apps/skit/src/websocket_handlers.rs index e9b76f45..0472562c 100644 --- a/apps/skit/src/websocket_handlers.rs +++ b/apps/skit/src/websocket_handlers.rs @@ -223,13 +223,10 @@ async fn handle_destroy_session( }; let destroyed_id = session.id.clone(); - if let Err(e) = session.shutdown_and_wait().await { - warn!(session_id = %destroyed_id, error = %e, "Error during engine shutdown"); - } - - info!(session_id = %destroyed_id, "Session destroyed successfully"); - - // Broadcast event to all clients + // Broadcast event to all clients BEFORE starting shutdown so the + // response and event reach clients immediately. 
The session has + // already been removed from the manager so ListSessions will no + // longer include it. let event = ApiEvent { message_type: MessageType::Event, correlation_id: None, @@ -239,6 +236,19 @@ async fn handle_destroy_session( error!("Failed to broadcast SessionDestroyed event: {}", e); } + // Run engine shutdown in a background task so we don't block the + // WebSocket handler (shutdown_and_wait has a 10-second timeout which + // would stall the entire WS select loop and cause the client's + // 5-second request timeout to fire first). + let shutdown_id = destroyed_id.clone(); + tokio::spawn(async move { + if let Err(e) = session.shutdown_and_wait().await { + warn!(session_id = %shutdown_id, error = %e, "Error during engine shutdown"); + } else { + info!(session_id = %shutdown_id, "Session destroyed successfully"); + } + }); + Some(ResponsePayload::SessionDestroyed { session_id: destroyed_id }) } diff --git a/assets/fonts/DejaVuSans-Bold.ttf b/assets/fonts/DejaVuSans-Bold.ttf new file mode 100644 index 00000000..06db62c6 Binary files /dev/null and b/assets/fonts/DejaVuSans-Bold.ttf differ diff --git a/assets/fonts/DejaVuSans.ttf b/assets/fonts/DejaVuSans.ttf new file mode 100644 index 00000000..2fbbe69e Binary files /dev/null and b/assets/fonts/DejaVuSans.ttf differ diff --git a/assets/fonts/DejaVuSansMono-Bold.ttf b/assets/fonts/DejaVuSansMono-Bold.ttf new file mode 100644 index 00000000..b210eb53 Binary files /dev/null and b/assets/fonts/DejaVuSansMono-Bold.ttf differ diff --git a/assets/fonts/DejaVuSansMono.ttf b/assets/fonts/DejaVuSansMono.ttf new file mode 100644 index 00000000..041cffc4 Binary files /dev/null and b/assets/fonts/DejaVuSansMono.ttf differ diff --git a/assets/fonts/DejaVuSerif-Bold.ttf b/assets/fonts/DejaVuSerif-Bold.ttf new file mode 100644 index 00000000..8162112c Binary files /dev/null and b/assets/fonts/DejaVuSerif-Bold.ttf differ diff --git a/assets/fonts/DejaVuSerif.ttf b/assets/fonts/DejaVuSerif.ttf new file mode 100644 index 
00000000..1b629773 Binary files /dev/null and b/assets/fonts/DejaVuSerif.ttf differ diff --git a/assets/fonts/LICENSE-DejaVu.txt b/assets/fonts/LICENSE-DejaVu.txt new file mode 100644 index 00000000..8d05fc34 --- /dev/null +++ b/assets/fonts/LICENSE-DejaVu.txt @@ -0,0 +1,52 @@ +DejaVu Fonts — License +====================== + +Fonts are (c) Bitstream (see below). DejaVu changes are in public domain. +Glyphs imported from Arev fonts are (c) Tavmjong Bah (see below) + +Bitstream Vera Fonts Copyright +------------------------------ + +Copyright (c) 2003 by Bitstream, Inc. All Rights Reserved. Bitstream Vera is +a trademark of Bitstream, Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of the fonts accompanying this license ("Fonts") and associated +documentation files (the "Font Software"), to reproduce and distribute the +Font Software, including without limitation the rights to use, copy, merge, +publish, distribute, and/or sell copies of the Font Software, and to permit +persons to whom the Font Software is furnished to do so, subject to the +following conditions: + +The above copyright and trademark notices and this permission notice shall +be included in all copies of one or more of the Font Software typefaces. + +The Font Software may be modified, altered, or added to, and in particular +the designs of glyphs or characters in the Fonts may be modified and +additional glyphs or characters may be added to the Fonts, only if the fonts +are renamed to names not containing either the words "Bitstream" or the word +"Vera". + +This License becomes null and void to the extent applicable to Fonts or Font +Software that has been modified and is distributed under the "Bitstream +Vera" names. + +The Font Software may be sold as part of a larger software package but no +copy of one or more of the Font Software typefaces may be sold by itself. 
+ +THE FONT SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF COPYRIGHT, PATENT, +TRADEMARK, OR OTHER RIGHT. IN NO EVENT SHALL BITSTREAM OR THE GNOME +FOUNDATION BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, INCLUDING +ANY GENERAL, SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, +WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF +THE USE OR INABILITY TO USE THE FONT SOFTWARE OR FROM OTHER DEALINGS IN THE +FONT SOFTWARE. + +Except as contained in this notice, the names of Gnome, the Gnome +Foundation, and Bitstream Inc., shall not be used in advertising or +otherwise to promote the sale, use or other dealings in this Font Software +without prior written authorization from the Gnome Foundation or Bitstream +Inc., respectively. For further information, contact: fonts at gnome dot +org. diff --git a/crates/api/Cargo.toml b/crates/api/Cargo.toml index 72e1269a..96d4aa98 100644 --- a/crates/api/Cargo.toml +++ b/crates/api/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "streamkit-api" -version = "0.1.1" +version = "0.2.0" edition = "2021" rust-version = "1.92" authors = ["Claudio Costa ", "StreamKit Contributors"] @@ -13,7 +13,7 @@ categories = ["multimedia", "network-programming", "api-bindings"] readme = "README.md" [dependencies] -streamkit-core = { version = "0.1.0", path = "../core" } +streamkit-core = { version = "0.2.0", path = "../core" } serde = { version = "1.0.228", features = ["derive", "rc"] } serde_json = "1.0" serde-saphyr = "0.0.11" diff --git a/crates/api/src/bin/generate_ts_types.rs b/crates/api/src/bin/generate_ts_types.rs index 640a6c2d..5c663c6d 100644 --- a/crates/api/src/bin/generate_ts_types.rs +++ b/crates/api/src/bin/generate_ts_types.rs @@ -10,7 +10,9 @@ use std::fs; use std::path::Path; use streamkit_core::control::NodeControlMessage; use 
streamkit_core::types::{ - AudioFormat, PacketMetadata, PacketType, SampleFormat, TranscriptionData, TranscriptionSegment, + AudioCodec, AudioFormat, EncodedAudioFormat, EncodedVideoFormat, PacketMetadata, PacketType, + PixelFormat, SampleFormat, TranscriptionData, TranscriptionSegment, VideoBitstreamFormat, + VideoCodec, VideoFormat, }; use ts_rs::TS; @@ -19,6 +21,13 @@ fn main() -> Result<(), Box> { // streamkit-core types format!("// streamkit-core\nexport {}", SampleFormat::decl()), format!("export {}", AudioFormat::decl()), + format!("export {}", PixelFormat::decl()), + format!("export {}", VideoFormat::decl()), + format!("export {}", AudioCodec::decl()), + format!("export {}", VideoCodec::decl()), + format!("export {}", VideoBitstreamFormat::decl()), + format!("export {}", EncodedAudioFormat::decl()), + format!("export {}", EncodedVideoFormat::decl()), format!("export {}", PacketMetadata::decl()), format!("export {}", TranscriptionSegment::decl()), format!("export {}", TranscriptionData::decl()), diff --git a/crates/api/src/lib.rs b/crates/api/src/lib.rs index 03781ea7..ed9be8cf 100644 --- a/crates/api/src/lib.rs +++ b/crates/api/src/lib.rs @@ -433,6 +433,17 @@ pub enum EventPayload { to_node: String, to_pin: String, }, + // --- View Data Events --- + /// A node's view data has been updated (e.g., compositor resolved layout). + /// View data carries structured JSON that the frontend interprets per-node-type. + NodeViewDataUpdated { + session_id: String, + node_id: String, + #[ts(type = "JsonValue")] + data: serde_json::Value, + /// ISO 8601 formatted timestamp + timestamp: String, + }, // --- Telemetry Events --- /// Telemetry event from a node (transcription results, VAD events, LLM responses, etc.). /// The data payload contains event-specific fields including event_type for filtering. 
diff --git a/crates/api/src/yaml.rs b/crates/api/src/yaml.rs index 6b194854..f4c5a13c 100644 --- a/crates/api/src/yaml.rs +++ b/crates/api/src/yaml.rs @@ -77,6 +77,14 @@ pub enum Needs { None, Single(NeedsDependency), Multiple(Vec), + /// Map variant: keys are **target input pin names**. + /// Enables explicit pin targeting, e.g. + /// ```yaml + /// needs: + /// video: vp9_encoder + /// audio: opus_encoder + /// ``` + Map(IndexMap), } /// The top-level structure for a user-facing pipeline definition. @@ -105,6 +113,22 @@ pub enum UserPipeline { }, } +/// Parse a YAML string into a [`UserPipeline`]. +/// +/// Uses a two-step approach (YAML → `serde_json::Value` → `UserPipeline`) +/// to work around a `serde_saphyr` limitation where deeply nested +/// structures fail to deserialize inside `#[serde(untagged)]` enums. +/// +/// # Errors +/// +/// Returns an error if the YAML is malformed or doesn't match the +/// `UserPipeline` schema. +pub fn parse_yaml(yaml: &str) -> Result { + let json_value: serde_json::Value = + serde_saphyr::from_str(yaml).map_err(|e| format!("Invalid YAML: {e}"))?; + serde_json::from_value(json_value).map_err(|e| format!("Invalid pipeline: {e}")) +} + /// "Compiles" the user-facing pipeline format into the explicit format the engine requires. /// /// # Errors @@ -227,6 +251,7 @@ fn detect_cycles(user_nodes: &IndexMap) -> Result<(), String> Needs::None => vec![], Needs::Single(dep) => vec![dep.node_and_pin().0], Needs::Multiple(deps) => deps.iter().map(|d| d.node_and_pin().0).collect(), + Needs::Map(map) => map.values().map(|d| d.node_and_pin().0).collect(), }; for dep_name in dependencies { @@ -277,13 +302,35 @@ fn compile_dag( let mut connections = Vec::new(); for (node_name, node_def) in &user_nodes { - let dependencies: Vec<&NeedsDependency> = match &node_def.needs { + // Collect dependencies and resolve target pin names. + // For Map variant, the map key is the explicit target pin name. 
+ // For Single/Multiple, pin names are auto-generated ("in" / "in_N"). + enum DepEntry<'a> { + Auto { idx: usize, total: usize, dep: &'a NeedsDependency }, + Named { pin: &'a str, dep: &'a NeedsDependency }, + } + + let entries: Vec> = match &node_def.needs { Needs::None => vec![], - Needs::Single(dep) => vec![dep], - Needs::Multiple(deps) => deps.iter().collect(), + Needs::Single(dep) => vec![DepEntry::Auto { idx: 0, total: 1, dep }], + Needs::Multiple(deps) => deps + .iter() + .enumerate() + .map(|(idx, dep)| DepEntry::Auto { idx, total: deps.len(), dep }) + .collect(), + Needs::Map(map) => { + map.iter().map(|(pin, dep)| DepEntry::Named { pin: pin.as_str(), dep }).collect() + }, }; - for (idx, dep) in dependencies.iter().enumerate() { + for entry in &entries { + let (dep, to_pin) = match entry { + DepEntry::Auto { idx, total, dep } => { + let pin = if *total > 1 { format!("in_{idx}") } else { "in".to_string() }; + (*dep, pin) + }, + DepEntry::Named { pin, dep } => (*dep, (*pin).to_string()), + }; let (dep_name, from_pin) = dep.node_and_pin(); // Validate that the referenced node exists @@ -293,10 +340,6 @@ fn compile_dag( )); } - // Use numbered input pins (in_0, in_1, etc.) 
when there are multiple inputs - let to_pin = - if dependencies.len() > 1 { format!("in_{idx}") } else { "in".to_string() }; - connections.push(Connection { from_node: dep_name.to_string(), from_pin: from_pin.unwrap_or("out").to_string(), diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 49e798dd..b10edac8 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "streamkit-core" -version = "0.1.0" +version = "0.2.0" edition = "2021" rust-version = "1.92" authors = ["Claudio Costa ", "StreamKit Contributors"] diff --git a/crates/core/src/frame_pool.rs b/crates/core/src/frame_pool.rs index 63ecd614..9489df8e 100644 --- a/crates/core/src/frame_pool.rs +++ b/crates/core/src/frame_pool.rs @@ -139,6 +139,11 @@ impl FramePool { /// Get pooled storage for at least `min_len` elements. /// /// If `min_len` doesn't fit in any bucket, returns a non-pooled buffer of exact size. + /// + /// On the first miss for a given bucket (cold start), an extra buffer is + /// allocated and placed into the pool so that the *next* `get()` at the + /// same size is a hit. This amortises cold-start allocation cost without + /// pre-allocating every bucket size up front. pub fn get(&self, min_len: usize) -> PooledFrameData { let (handle, bucket_idx, bucket_size, maybe_buf) = { let Ok(mut guard) = self.inner.lock() else { @@ -154,6 +159,12 @@ impl FramePool { guard.hits += 1; } else { guard.misses += 1; + // Lazy preallocate: on first miss for this bucket, seed the + // pool with one extra buffer so subsequent gets are hits. 
+ if guard.buckets[bucket_idx].is_empty() && guard.buckets[bucket_idx].capacity() == 0 + { + guard.buckets[bucket_idx].push(vec![T::default(); bucket_size]); + } } (self.handle(), bucket_idx, bucket_size, buf) }; @@ -316,6 +327,24 @@ impl FramePool { } } +pub type VideoFramePool = FramePool; +pub type PooledVideoData = PooledFrameData; + +pub const DEFAULT_VIDEO_BUCKET_SIZES: &[usize] = &[ + 86_400, 230_400, 345_600, 921_600, 1_382_400, 3_110_400, 3_686_400, 8_294_400, 12_441_600, + 33_177_600, +]; +pub const DEFAULT_VIDEO_MAX_BUFFERS_PER_BUCKET: usize = 16; + +impl FramePool { + pub fn video_default() -> Self { + Self::with_buckets( + DEFAULT_VIDEO_BUCKET_SIZES.to_vec(), + DEFAULT_VIDEO_MAX_BUFFERS_PER_BUCKET, + ) + } +} + #[cfg(test)] #[allow(clippy::unwrap_used)] mod tests { diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index 1bae09fa..e9e0fd95 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -66,6 +66,7 @@ pub mod stats; pub mod telemetry; pub mod timing; pub mod types; +pub mod view_data; // Convenience re-exports for commonly used types // These are the most frequently used types in node implementations @@ -93,6 +94,9 @@ pub use stats::{NodeStats, NodeStatsUpdate}; // Telemetry pub use telemetry::{TelemetryConfig, TelemetryEmitter, TelemetryEvent}; +// View data +pub use view_data::NodeViewDataUpdate; + // Timing helpers pub use timing::*; @@ -103,9 +107,12 @@ pub use pins::{InputPin, OutputPin, PinCardinality}; pub use helpers::{config_helpers, packet_helpers}; pub use state::state_helpers; pub use telemetry::telemetry_helpers; +pub use view_data::view_data_helpers; // Frame pooling (optional hot-path optimization) -pub use frame_pool::{AudioFramePool, FramePool, PooledFrameData, PooledSamples}; +pub use frame_pool::{ + AudioFramePool, FramePool, PooledFrameData, PooledSamples, PooledVideoData, VideoFramePool, +}; // Node buffer configuration pub use node_config::{ diff --git a/crates/core/src/node.rs 
b/crates/core/src/node.rs index b24f331e..4527f123 100644 --- a/crates/core/src/node.rs +++ b/crates/core/src/node.rs @@ -16,8 +16,9 @@ use crate::pins::{InputPin, OutputPin, PinManagementMessage, PinUpdate}; use crate::state::NodeStateUpdate; use crate::stats::NodeStatsUpdate; use crate::telemetry::TelemetryEvent; -use crate::types::Packet; -use crate::AudioFramePool; +use crate::types::{Packet, PacketType}; +use crate::view_data::NodeViewDataUpdate; +use crate::{AudioFramePool, VideoFramePool}; use async_trait::async_trait; use std::collections::HashMap; use std::sync::Arc; @@ -58,6 +59,10 @@ pub enum OutputSendError { /// The downstream channel (direct) or engine channel (routed) is closed. #[error("output channel closed for pin '{pin_name}' on node '{node_name}'")] ChannelClosed { node_name: String, pin_name: String }, + + /// The downstream channel is full (non-blocking send). + #[error("output channel full for pin '{pin_name}' on node '{node_name}'")] + ChannelFull { node_name: String, pin_name: String }, } impl OutputSender { @@ -84,6 +89,67 @@ impl OutputSender { } } + /// Non-blocking send from a specific output pin. + /// + /// Returns [`OutputSendError::ChannelFull`] when the downstream channel + /// has no capacity — callers may drop the packet and continue. + /// Returns [`OutputSendError::ChannelClosed`] or [`OutputSendError::PinNotFound`] + /// for permanent errors — callers should stop processing. + /// + /// Used by real-time nodes (e.g. compositor) that prefer dropping a frame + /// over stalling and accumulating latency. 
+ pub fn try_send(&mut self, pin_name: &str, packet: Packet) -> Result<(), OutputSendError> { + use tokio::sync::mpsc::error::TrySendError; + + match &self.routing { + OutputRouting::Direct(senders) => { + if let Some(sender) = senders.get(pin_name) { + match sender.try_send(packet) { + Ok(()) => {}, + Err(TrySendError::Full(_)) => { + return Err(OutputSendError::ChannelFull { + node_name: self.node_name.to_string(), + pin_name: pin_name.to_string(), + }); + }, + Err(TrySendError::Closed(_)) => { + return Err(OutputSendError::ChannelClosed { + node_name: self.node_name.to_string(), + pin_name: pin_name.to_string(), + }); + }, + } + } else { + return Err(OutputSendError::PinNotFound { + node_name: self.node_name.to_string(), + pin_name: pin_name.to_string(), + }); + } + }, + OutputRouting::Routed(engine_tx) => { + let engine_tx = engine_tx.clone(); + let cached_pin = self.get_cached_pin_name(pin_name); + let message = (self.node_name.clone(), cached_pin, packet); + match engine_tx.try_send(message) { + Ok(()) => {}, + Err(TrySendError::Full(_)) => { + return Err(OutputSendError::ChannelFull { + node_name: self.node_name.to_string(), + pin_name: pin_name.to_string(), + }); + }, + Err(TrySendError::Closed(_)) => { + return Err(OutputSendError::ChannelClosed { + node_name: self.node_name.to_string(), + pin_name: pin_name.to_string(), + }); + }, + } + }, + } + Ok(()) + } + /// Sends a packet from a specific output pin of this node. /// Returns `Ok(())` if sent successfully. /// @@ -190,6 +256,15 @@ pub struct InitContext { /// The context provided by the engine to a node when it is run. pub struct NodeContext { pub inputs: HashMap>, + /// The [`PacketType`] that each connected input pin will receive, keyed by + /// pin name. Populated by the graph builder from the upstream node's + /// output type so that nodes can make decisions based on the connected + /// media type without having to inspect packets at runtime. 
+ /// + /// Only contains entries for *connected* pins (unconnected pins are absent). + /// May be empty for dynamic pipelines where connections are made after the + /// node is already running. + pub input_types: HashMap, pub control_rx: mpsc::Receiver, pub output_sender: OutputSender, pub batch_size: usize, @@ -223,6 +298,14 @@ pub struct NodeContext { /// Nodes that produce audio frames (decoders, resamplers, mixers) may use this to /// amortize `Vec` allocations. If `None`, nodes should fall back to allocating. pub audio_pool: Option>, + /// Optional per-pipeline video buffer pool for hot-path allocations. + /// + /// Nodes that produce video frames (decoders, scalers, compositors) may use this to + /// amortize `Vec` allocations. If `None`, nodes should fall back to allocating. + pub video_pool: Option>, + /// Channel for the node to emit structured view data for frontend consumption. + /// Like stats_tx, this is optional and best-effort. + pub view_data_tx: Option>, } impl NodeContext { diff --git a/crates/core/src/packet_meta.rs b/crates/core/src/packet_meta.rs index 6f5b2c15..35754364 100644 --- a/crates/core/src/packet_meta.rs +++ b/crates/core/src/packet_meta.rs @@ -37,7 +37,7 @@ pub enum Compatibility { #[derive(Debug, Clone, Serialize, Deserialize, TS)] #[ts(export)] pub struct PacketTypeMeta { - /// Variant identifier (e.g., "RawAudio", "OpusAudio", "Binary", "Any"). + /// Variant identifier (e.g., "RawAudio", "EncodedAudio", "Binary", "Any"). pub id: String, /// Human-friendly default label. 
pub label: String, @@ -80,13 +80,6 @@ pub fn packet_type_registry() -> &'static [PacketTypeMeta] { display_template: None, compatibility: Compatibility::Exact, }, - PacketTypeMeta { - id: "OpusAudio".into(), - label: "Opus Audio".into(), - color: "#ff6b6b".into(), - display_template: None, - compatibility: Compatibility::Exact, - }, PacketTypeMeta { id: "RawAudio".into(), label: "Raw Audio".into(), @@ -108,6 +101,67 @@ pub fn packet_type_registry() -> &'static [PacketTypeMeta] { ], }, }, + PacketTypeMeta { + id: "RawVideo".into(), + label: "Raw Video".into(), + color: "#1abc9c".into(), + display_template: Some("Raw Video ({width|*}x{height|*}, {pixel_format})".into()), + compatibility: Compatibility::StructFieldWildcard { + fields: vec![ + FieldRule { + name: "width".into(), + wildcard_value: Some(serde_json::Value::Null), + }, + FieldRule { + name: "height".into(), + wildcard_value: Some(serde_json::Value::Null), + }, + FieldRule { name: "pixel_format".into(), wildcard_value: None }, + ], + }, + }, + PacketTypeMeta { + id: "EncodedAudio".into(), + label: "Encoded Audio".into(), + color: "#ff6b6b".into(), + display_template: Some("Encoded Audio ({codec})".into()), + compatibility: Compatibility::StructFieldWildcard { + fields: vec![ + FieldRule { name: "codec".into(), wildcard_value: None }, + FieldRule { + name: "codec_private".into(), + wildcard_value: Some(serde_json::Value::Null), + }, + ], + }, + }, + PacketTypeMeta { + id: "EncodedVideo".into(), + label: "Encoded Video".into(), + color: "#2980b9".into(), + display_template: Some("Encoded Video ({codec})".into()), + compatibility: Compatibility::StructFieldWildcard { + fields: vec![ + FieldRule { name: "codec".into(), wildcard_value: None }, + FieldRule { + name: "bitstream_format".into(), + wildcard_value: Some(serde_json::Value::Null), + }, + FieldRule { + name: "codec_private".into(), + wildcard_value: Some(serde_json::Value::Null), + }, + FieldRule { + name: "profile".into(), + wildcard_value: 
Some(serde_json::Value::Null), + }, + FieldRule { + name: "level".into(), + wildcard_value: Some(serde_json::Value::Null), + }, + ], + }, + }, PacketTypeMeta { id: "Transcription".into(), label: "Transcription".into(), @@ -190,15 +244,24 @@ pub fn can_connect(output: &PacketType, input: &PacketType, registry: &[PacketTy }; fields.iter().all(|f| { - let Some(av) = out_map.get(&f.name) else { - return false; + let wildcard = f.wildcard_value.as_ref(); + let av = match out_map.get(&f.name) { + Some(value) => value, + None => match wildcard { + Some(value) => value, + None => return false, + }, }; - let Some(bv) = in_map.get(&f.name) else { - return false; + let bv = match in_map.get(&f.name) { + Some(value) => value, + None => match wildcard { + Some(value) => value, + None => return false, + }, }; // If either equals the wildcard, it matches - if let Some(wild) = &f.wildcard_value { + if let Some(wild) = wildcard { if av == wild || bv == wild { return true; } @@ -219,3 +282,59 @@ pub fn can_connect_any( ) -> bool { inputs.iter().any(|inp| can_connect(output, inp, registry)) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::{AudioCodec, EncodedAudioFormat, PixelFormat, VideoFormat}; + + #[test] + fn raw_video_wildcard_dimensions() { + let registry = packet_type_registry(); + let exact = + VideoFormat { width: Some(1920), height: Some(1080), pixel_format: PixelFormat::I420 }; + let wildcard = VideoFormat { width: None, height: None, pixel_format: PixelFormat::I420 }; + let mismatched = + VideoFormat { width: Some(1280), height: Some(720), pixel_format: PixelFormat::I420 }; + let different_format = + VideoFormat { width: Some(1920), height: Some(1080), pixel_format: PixelFormat::Rgba8 }; + + assert!(can_connect( + &PacketType::RawVideo(exact.clone()), + &PacketType::RawVideo(exact.clone()), + registry + )); + assert!(can_connect( + &PacketType::RawVideo(exact.clone()), + &PacketType::RawVideo(wildcard.clone()), + registry + )); + assert!(can_connect( + 
&PacketType::RawVideo(wildcard), + &PacketType::RawVideo(exact.clone()), + registry + )); + assert!(!can_connect( + &PacketType::RawVideo(exact.clone()), + &PacketType::RawVideo(mismatched), + registry + )); + assert!(!can_connect( + &PacketType::RawVideo(exact), + &PacketType::RawVideo(different_format), + registry + )); + } + + #[test] + fn encoded_audio_optional_fields() { + let registry = packet_type_registry(); + let format = EncodedAudioFormat { codec: AudioCodec::Opus, codec_private: None }; + + assert!(can_connect( + &PacketType::EncodedAudio(format.clone()), + &PacketType::EncodedAudio(format), + registry + )); + } +} diff --git a/crates/core/src/telemetry.rs b/crates/core/src/telemetry.rs index 33d38ec9..c0d11540 100644 --- a/crates/core/src/telemetry.rs +++ b/crates/core/src/telemetry.rs @@ -99,6 +99,7 @@ impl TelemetryEvent { timestamp_us: Some(timestamp_us), duration_us: None, sequence: None, + keyframe: None, }), }, } diff --git a/crates/core/src/timing.rs b/crates/core/src/timing.rs index 670046a8..cea5b9b6 100644 --- a/crates/core/src/timing.rs +++ b/crates/core/src/timing.rs @@ -104,6 +104,8 @@ pub fn merge_metadata<'a, I: Iterator>( let mut ts = None; let mut dur = None; let mut seq = None; + let mut keyframe = None; + let mut keyframe_conflict = false; for m in iter { if let Some(t) = m.timestamp_us { ts = Some(ts.map_or(t, |prev: u64| prev.min(t))); @@ -114,9 +116,21 @@ pub fn merge_metadata<'a, I: Iterator>( if let Some(s) = m.sequence { seq = Some(seq.map_or(s, |prev: u64| prev.max(s))); } + if !keyframe_conflict { + if let Some(k) = m.keyframe { + match keyframe { + None => keyframe = Some(k), + Some(existing) if existing == k => {}, + Some(_) => { + keyframe = None; + keyframe_conflict = true; + }, + } + } + } } - if ts.is_some() || dur.is_some() || seq.is_some() { - Some(PacketMetadata { timestamp_us: ts, duration_us: dur, sequence: seq }) + if ts.is_some() || dur.is_some() || seq.is_some() || keyframe.is_some() { + Some(PacketMetadata { 
timestamp_us: ts, duration_us: dur, sequence: seq, keyframe }) } else { None } diff --git a/crates/core/src/types.rs b/crates/core/src/types.rs index d6aeef6c..4999087b 100644 --- a/crates/core/src/types.rs +++ b/crates/core/src/types.rs @@ -12,7 +12,8 @@ //! - Transcription types for speech processing //! - Extensible custom packet types for plugins -use crate::frame_pool::PooledSamples; +use crate::error::StreamKitError; +use crate::frame_pool::{PooledSamples, PooledVideoData}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_json::Value as JsonValue; @@ -37,6 +38,72 @@ pub struct AudioFormat { pub sample_format: SampleFormat, } +/// Describes the pixel format of raw video frames. +#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, JsonSchema, TS)] +#[ts(export)] +pub enum PixelFormat { + Rgba8, + I420, + Nv12, +} + +/// Contains the detailed metadata for a raw video stream. +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, TS)] +#[ts(export)] +pub struct VideoFormat { + pub width: Option, + pub height: Option, + pub pixel_format: PixelFormat, +} + +/// Supported encoded audio codecs. +#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, JsonSchema, TS)] +#[ts(export)] +pub enum AudioCodec { + Opus, +} + +/// Supported encoded video codecs. +#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, JsonSchema, TS)] +#[ts(export)] +pub enum VideoCodec { + Vp9, + H264, + Av1, +} + +/// Bitstream format hints for video codecs (primarily H264). +#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, JsonSchema, TS)] +#[ts(export)] +pub enum VideoBitstreamFormat { + AnnexB, + Avcc, +} + +/// Encoded audio format details (extensible for codec-specific config). 
+#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, TS)] +#[ts(export)] +pub struct EncodedAudioFormat { + pub codec: AudioCodec, + #[serde(skip_serializing_if = "Option::is_none")] + pub codec_private: Option>, +} + +/// Encoded video format details (extensible for codec-specific config). +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, TS)] +#[ts(export)] +pub struct EncodedVideoFormat { + pub codec: VideoCodec, + #[serde(skip_serializing_if = "Option::is_none")] + pub bitstream_format: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub codec_private: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub profile: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub level: Option, +} + /// Optional timing and sequencing metadata that can be attached to packets. /// Used for pacing, synchronization, and A/V alignment. See `timing` module for /// canonical semantics (media-time epoch, monotonicity, and preservation rules). @@ -49,6 +116,8 @@ pub struct PacketMetadata { pub duration_us: Option, /// Sequence number for ordering and detecting loss pub sequence: Option, + /// Keyframe flag for encoded video packets (and raw frames if applicable) + pub keyframe: Option, } /// Describes the *type* of data, used for pre-flight pipeline validation. @@ -57,8 +126,12 @@ pub struct PacketMetadata { pub enum PacketType { /// Raw, uncompressed audio with a specific format. RawAudio(AudioFormat), - /// Compressed Opus audio. - OpusAudio, + /// Raw, uncompressed video with a specific format. + RawVideo(VideoFormat), + /// Encoded audio with codec metadata. + EncodedAudio(EncodedAudioFormat), + /// Encoded video with codec metadata. + EncodedVideo(EncodedVideoFormat), /// Plain text. Text, /// Structured transcription data with timestamps and metadata. 
@@ -93,6 +166,7 @@ pub enum PacketType { #[derive(Debug, Clone, Serialize)] pub enum Packet { Audio(AudioFrame), + Video(VideoFrame), /// Text payload (Arc-backed to make fan-out cloning cheap). Text(Arc), /// Transcription payload (Arc-backed to make fan-out cloning cheap). @@ -175,6 +249,138 @@ pub struct TranscriptionData { pub metadata: Option, } +/// Layout information for a single video plane. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub struct VideoPlane { + pub offset: usize, + pub stride: usize, + pub width: u32, + pub height: u32, +} + +/// Packed layout for a video frame. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub struct VideoLayout { + plane_count: usize, + planes: [VideoPlane; 3], + total_bytes: usize, + stride_align: u32, +} + +impl VideoLayout { + pub fn packed(width: u32, height: u32, pixel_format: PixelFormat) -> Self { + Self::aligned(width, height, pixel_format, 1) + } + + pub fn aligned(width: u32, height: u32, pixel_format: PixelFormat, stride_align: u32) -> Self { + const EMPTY_PLANE: VideoPlane = VideoPlane { offset: 0, stride: 0, width: 0, height: 0 }; + let mut planes = [EMPTY_PLANE; 3]; + let stride_align = stride_align.max(1); + let stride_align_usize = stride_align as usize; + + let (plane_count, total_bytes) = match pixel_format { + PixelFormat::Rgba8 => { + let stride = align_up(width as usize * 4, stride_align_usize); + let size = stride * height as usize; + planes[0] = VideoPlane { offset: 0, stride, width, height }; + (1, size) + }, + PixelFormat::I420 => { + let luma_stride = align_up(width as usize, stride_align_usize); + let luma_size = luma_stride * height as usize; + let chroma_width = (width + 1) as usize / 2; + let chroma_height = (height + 1) as usize / 2; + let chroma_stride = align_up(chroma_width, stride_align_usize); + let chroma_size = chroma_stride * chroma_height; + + planes[0] = VideoPlane { offset: 0, stride: luma_stride, width, height }; + planes[1] = 
VideoPlane { + offset: luma_size, + stride: chroma_stride, + width: chroma_width as u32, + height: chroma_height as u32, + }; + planes[2] = VideoPlane { + offset: luma_size + chroma_size, + stride: chroma_stride, + width: chroma_width as u32, + height: chroma_height as u32, + }; + + (3, luma_size + chroma_size * 2) + }, + PixelFormat::Nv12 => { + let luma_stride = align_up(width as usize, stride_align_usize); + let luma_size = luma_stride * height as usize; + let chroma_width = (width + 1) as usize / 2 * 2; // interleaved UV pairs + let chroma_height = (height + 1) as usize / 2; + let chroma_stride = align_up(chroma_width, stride_align_usize); + let chroma_size = chroma_stride * chroma_height; + + planes[0] = VideoPlane { offset: 0, stride: luma_stride, width, height }; + planes[1] = VideoPlane { + offset: luma_size, + stride: chroma_stride, + width: chroma_width as u32, + height: chroma_height as u32, + }; + + (2, luma_size + chroma_size) + }, + }; + + Self { plane_count, planes, total_bytes, stride_align } + } + + pub const fn plane_count(&self) -> usize { + self.plane_count + } + + pub fn planes(&self) -> &[VideoPlane] { + &self.planes[..self.plane_count] + } + + pub fn plane(&self, index: usize) -> Option { + if index < self.plane_count { + Some(self.planes[index]) + } else { + None + } + } + + pub const fn total_bytes(&self) -> usize { + self.total_bytes + } + + pub const fn stride_align(&self) -> u32 { + self.stride_align + } +} + +fn align_up(value: usize, align: usize) -> usize { + if align <= 1 { + value + } else { + value.div_ceil(align) * align + } +} + +/// A view into a single video plane. +pub struct VideoPlaneRef<'a> { + pub data: &'a [u8], + pub stride: usize, + pub width: u32, + pub height: u32, +} + +/// A mutable view into a single video plane. +pub struct VideoPlaneMut<'a> { + pub data: &'a mut [u8], + pub stride: usize, + pub width: u32, + pub height: u32, +} + /// A single frame or packet of raw audio data, using f32 as the internal standard. 
/// /// Audio samples are stored in an `Arc` for efficient zero-copy cloning when packets @@ -216,6 +422,19 @@ pub struct AudioFrame { pub metadata: Option<PacketMetadata>, } +/// Custom serializer for Arc<PooledVideoData> - serializes as base64 +fn serialize_arc_pooled_video_bytes<S>( + arc: &Arc<PooledVideoData>, + serializer: S, +) -> Result<S::Ok, S::Error> +where + S: serde::Serializer, +{ + use serde::Serialize; + base64::Engine::encode(&base64::engine::general_purpose::STANDARD, arc.as_slice()) + .serialize(serializer) +} + /// Custom serializer for Arc<PooledSamples> - serializes as a slice fn serialize_arc_pooled_samples<S>( arc: &Arc<PooledSamples>, @@ -265,6 +484,7 @@ impl AudioFrame { /// timestamp_us: Some(1000), /// duration_us: Some(20_000), /// sequence: Some(42), + /// keyframe: None, /// }; /// let frame = AudioFrame::with_metadata(48000, 2, vec![0.5, -0.5], Some(metadata)); /// assert_eq!(frame.metadata.unwrap().sequence, Some(42)); @@ -380,3 +600,207 @@ impl AudioFrame { Some((frames * 1_000_000) / u64::from(self.sample_rate)) } } + +/// A single frame of raw video data, stored in an Arc for zero-copy fan-out. 
+#[derive(Debug, Clone, Serialize)] +pub struct VideoFrame { + pub width: u32, + pub height: u32, + pub pixel_format: PixelFormat, + pub layout: VideoLayout, + #[serde(serialize_with = "serialize_arc_pooled_video_bytes")] + pub data: Arc, + pub metadata: Option, +} + +impl VideoFrame { + pub fn from_pooled( + width: u32, + height: u32, + pixel_format: PixelFormat, + data: PooledVideoData, + metadata: Option, + ) -> Result { + let layout = VideoLayout::packed(width, height, pixel_format); + Self::from_pooled_with_layout(width, height, pixel_format, layout, data, metadata) + } + + pub fn from_pooled_with_layout( + width: u32, + height: u32, + pixel_format: PixelFormat, + layout: VideoLayout, + mut data: PooledVideoData, + metadata: Option, + ) -> Result { + let expected_layout = + VideoLayout::aligned(width, height, pixel_format, layout.stride_align()); + if layout != expected_layout { + return Err(StreamKitError::Runtime(format!( + "VideoFrame layout mismatch: expected {expected_layout:?}, got {layout:?}" + ))); + } + if data.len() < layout.total_bytes() { + return Err(StreamKitError::Runtime(format!( + "VideoFrame data buffer too small: need {} bytes, have {}", + layout.total_bytes(), + data.len() + ))); + } + data.truncate(layout.total_bytes()); + Ok(Self { width, height, pixel_format, layout, data: Arc::new(data), metadata }) + } + + pub fn new( + width: u32, + height: u32, + pixel_format: PixelFormat, + data: Vec, + ) -> Result { + Self::from_pooled(width, height, pixel_format, PooledVideoData::from_vec(data), None) + } + + pub fn with_metadata( + width: u32, + height: u32, + pixel_format: PixelFormat, + data: Vec, + metadata: Option, + ) -> Result { + Self::from_pooled(width, height, pixel_format, PooledVideoData::from_vec(data), metadata) + } + + pub fn from_arc( + width: u32, + height: u32, + pixel_format: PixelFormat, + data: Arc, + metadata: Option, + ) -> Result { + let layout = VideoLayout::packed(width, height, pixel_format); + 
Self::from_arc_with_layout(width, height, pixel_format, layout, data, metadata) + } + + pub fn from_arc_with_layout( + width: u32, + height: u32, + pixel_format: PixelFormat, + layout: VideoLayout, + data: Arc, + metadata: Option, + ) -> Result { + let expected_layout = + VideoLayout::aligned(width, height, pixel_format, layout.stride_align()); + if layout != expected_layout { + return Err(StreamKitError::Runtime(format!( + "VideoFrame layout mismatch: expected {expected_layout:?}, got {layout:?}" + ))); + } + if data.len() < layout.total_bytes() { + return Err(StreamKitError::Runtime(format!( + "VideoFrame data buffer too small: need {} bytes, have {}", + layout.total_bytes(), + data.len() + ))); + } + Ok(Self { width, height, pixel_format, layout, data, metadata }) + } + + pub fn data(&self) -> &[u8] { + self.data.as_slice() + } + + pub fn make_data_mut(&mut self) -> &mut [u8] { + Arc::make_mut(&mut self.data).as_mut_slice() + } + + pub fn has_unique_data(&self) -> bool { + Arc::strong_count(&self.data) == 1 + } + + pub fn data_len(&self) -> usize { + self.data.len() + } + + #[allow(clippy::len_without_is_empty)] // is_empty provided explicitly + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + pub fn layout(&self) -> VideoLayout { + self.layout + } + + pub fn plane(&self, index: usize) -> Option> { + let layout = self.layout(); + let plane = layout.plane(index)?; + let start = plane.offset; + let end = start + plane.stride * plane.height as usize; + if end <= self.data.len() { + Some(VideoPlaneRef { + data: &self.data.as_slice()[start..end], + stride: plane.stride, + width: plane.width, + height: plane.height, + }) + } else { + None + } + } + + pub fn plane_mut(&mut self, index: usize) -> Option> { + let layout = self.layout(); + let plane = layout.plane(index)?; + let start = plane.offset; + let end = start + plane.stride * plane.height as usize; + let data = Arc::make_mut(&mut self.data); + if end <= data.len() { + Some(VideoPlaneMut { + data: 
&mut data.as_mut_slice()[start..end], + stride: plane.stride, + width: plane.width, + height: plane.height, + }) + } else { + None + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::frame_pool::FramePool; + + #[test] + fn video_frame_copy_on_write() { + let frame_a = VideoFrame::new(2, 1, PixelFormat::Rgba8, vec![0u8; 8]).unwrap(); + let mut frame_b = frame_a.clone(); + + assert!(!frame_a.has_unique_data()); + assert!(!frame_b.has_unique_data()); + + frame_b.make_data_mut()[0] = 7; + + assert_eq!(frame_a.data()[0], 0); + assert_eq!(frame_b.data()[0], 7); + assert!(frame_a.has_unique_data()); + assert!(frame_b.has_unique_data()); + } + + #[test] + fn video_frame_pool_returns_on_drop() { + let pool = FramePool::::preallocated(&[8], 1); + assert_eq!(pool.stats().buckets[0].available, 1); + + { + let data = pool.get(8); + let frame = VideoFrame::from_pooled(2, 1, PixelFormat::Rgba8, data, None).unwrap(); + assert_eq!(frame.data_len(), 8); + assert_eq!(pool.stats().buckets[0].available, 0); + drop(frame); + } + + assert_eq!(pool.stats().buckets[0].available, 1); + } +} diff --git a/crates/core/src/view_data.rs b/crates/core/src/view_data.rs new file mode 100644 index 00000000..eb84d952 --- /dev/null +++ b/crates/core/src/view_data.rs @@ -0,0 +1,52 @@ +// SPDX-FileCopyrightText: © 2025 StreamKit Contributors +// +// SPDX-License-Identifier: MPL-2.0 + +//! Per-node view data channel for emitting UI-relevant structured data. +//! +//! This module provides types and helpers for nodes to emit view data +//! (e.g., resolved compositor layout) that the frontend can consume. +//! View data is best-effort and follows the same pattern as stats/telemetry. + +use std::time::SystemTime; + +/// A view data update message sent by a node to report structured UI-relevant data. +/// +/// Unlike stats (which are numeric counters), view data carries arbitrary JSON +/// that the frontend interprets per-node-type. 
For example, the compositor emits +/// its resolved layout so the frontend can render overlays at server-computed positions. +#[derive(Debug, Clone)] +pub struct NodeViewDataUpdate { +    /// The unique identifier of the node reporting the view data +    pub node_id: String, +    /// Structured view data payload (node-type-specific) +    pub data: serde_json::Value, +    /// When this update was produced +    pub timestamp: SystemTime, +} + +/// Helper functions for emitting node view data updates. +/// These functions reduce boilerplate when sending view data from nodes. +pub mod view_data_helpers { +    use super::{NodeViewDataUpdate, SystemTime}; +    use tokio::sync::mpsc; + +    /// Emits a view data update to the provided channel. +    /// +    /// Best-effort: uses `try_send` so the node never blocks on view data emission. +    /// Failures are silently ignored as view data is informational only. +    #[inline] +    pub fn emit_view_data( +        tx: &Option<mpsc::Sender<NodeViewDataUpdate>>, +        node_id: &str, +        data: serde_json::Value, +    ) { +        if let Some(tx) = tx { +            let _ = tx.try_send(NodeViewDataUpdate { +                node_id: node_id.to_string(), +                data, +                timestamp: SystemTime::now(), +            }); +        } +    } +} diff --git a/crates/engine/Cargo.toml b/crates/engine/Cargo.toml index 3805f845..d77fc33a 100644 --- a/crates/engine/Cargo.toml +++ b/crates/engine/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "streamkit-engine" -version = "0.1.0" +version = "0.2.0" edition = "2021" authors = ["Claudio Costa ", "StreamKit Contributors"] description = "Pipeline execution engines for StreamKit" @@ -54,7 +54,23 @@ script = ["streamkit-nodes/script"] [dev-dependencies] serde_json = { workspace = true } -tracing-subscriber = "0.3" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +futures = { workspace = true } +bytes = { workspace = true } +streamkit-api = { workspace = true } +indexmap = { workspace = true } + +[[bench]] +name = "compositor_pipeline" +harness = false + +[[bench]] +name = "compositor_only" +harness = false + +[[bench]] +name = 
"pixel_convert" +harness = false [lints] workspace = true diff --git a/crates/engine/benches/compositor_only.rs b/crates/engine/benches/compositor_only.rs new file mode 100644 index 00000000..a7bb0cc0 --- /dev/null +++ b/crates/engine/benches/compositor_only.rs @@ -0,0 +1,825 @@ +// SPDX-FileCopyrightText: © 2025 StreamKit Contributors +// +// SPDX-License-Identifier: MPL-2.0 + +#![allow(clippy::disallowed_macros)] // Bench binary intentionally uses eprintln!/println! for output. +#![allow(clippy::expect_used)] // Panicking on errors is fine in a benchmark binary. +#![allow(clippy::cast_possible_truncation, clippy::cast_sign_loss, clippy::cast_precision_loss)] + +//! Compositor-only microbenchmark — measures `composite_frame` in isolation +//! (no VP9 encode, no mux, no async runtime overhead). +//! +//! Exercises the following scenarios across multiple resolutions: +//! +//! - 1 layer RGBA (baseline) +//! - 2 layers RGBA (PiP) +//! - 4 layers RGBA +//! - 2 layers mixed I420 + RGBA (measures YUV→RGBA conversion overhead) +//! - 2 layers mixed NV12 + RGBA +//! - 2 layers RGBA with rotation +//! - 2 layers RGBA, static (same data each frame — for future cache-hit measurement) +//! - 1 layer RGBA + text overlay (lower-third banner) +//! - 1 layer RGBA + image overlay (logo watermark) +//! - 2 layers PiP + both overlays (realistic broadcast layout) +//! - I420 bg + PiP + both overlays (realistic codec→compositor pipeline) +//! +//! ## Usage +//! +//! Quick run (default 200 frames @ 1280×720): +//! +//! ```bash +//! cargo bench -p streamkit-engine --bench compositor_only +//! ``` +//! +//! Custom parameters: +//! +//! ```bash +//! cargo bench -p streamkit-engine --bench compositor_only -- --frames 500 --width 1920 --height 1080 +//! ``` + +use std::sync::Arc; +use std::time::Instant; + +use streamkit_core::frame_pool::PooledVideoData; +use streamkit_core::types::PixelFormat; +use streamkit_core::VideoFramePool; + +// Re-use the compositor kernel and pixel_ops directly. 
+use streamkit_nodes::video::compositor::config::Rect; +use streamkit_nodes::video::compositor::kernel::{composite_frame, ConversionCache, LayerSnapshot}; +use streamkit_nodes::video::compositor::overlay::DecodedOverlay; +use streamkit_nodes::video::compositor::pixel_ops::{rgba8_to_i420_buf, rgba8_to_nv12_buf}; + +// ── Default benchmark parameters ──────────────────────────────────────────── + +const DEFAULT_WIDTH: u32 = 1280; +const DEFAULT_HEIGHT: u32 = 720; +const DEFAULT_FRAME_COUNT: u32 = 200; + +// ── Arg parser ────────────────────────────────────────────────────────────── + +struct BenchArgs { + width: u32, + height: u32, + frame_count: u32, + iterations: u32, + /// Optional filter: only run scenarios whose label contains this substring. + filter: Option, +} + +impl BenchArgs { + fn parse() -> Self { + let args: Vec = std::env::args().collect(); + let mut cfg = Self { + width: DEFAULT_WIDTH, + height: DEFAULT_HEIGHT, + frame_count: DEFAULT_FRAME_COUNT, + iterations: 3, + filter: None, + }; + let mut i = 1; + while i < args.len() { + match args[i].as_str() { + "--width" | "-w" => { + i += 1; + if let Some(v) = args.get(i) { + cfg.width = v.parse().unwrap_or(cfg.width); + } + }, + "--height" | "-h" => { + i += 1; + if let Some(v) = args.get(i) { + cfg.height = v.parse().unwrap_or(cfg.height); + } + }, + "--frames" | "-n" => { + i += 1; + if let Some(v) = args.get(i) { + cfg.frame_count = v.parse().unwrap_or(cfg.frame_count); + } + }, + "--iterations" | "-i" => { + i += 1; + if let Some(v) = args.get(i) { + cfg.iterations = v.parse().unwrap_or(cfg.iterations); + } + }, + "--filter" | "-f" => { + i += 1; + if let Some(v) = args.get(i) { + cfg.filter = Some(v.clone()); + } + }, + _ => {}, + } + i += 1; + } + cfg + } +} + +// ── Frame generators ──────────────────────────────────────────────────────── + +/// Generate an RGBA8 color-bar frame (opaque, all alpha = 255). 
+#[allow(clippy::many_single_char_names)] +fn generate_rgba_frame(width: u32, height: u32) -> Vec { + let w = width as usize; + let h = height as usize; + let mut data = vec![0u8; w * h * 4]; + // Simple vertical gradient bars for visual distinctness. + let bar_colors: &[(u8, u8, u8)] = &[ + (191, 191, 191), // white + (191, 191, 0), // yellow + (0, 191, 191), // cyan + (0, 191, 0), // green + (191, 0, 191), // magenta + (191, 0, 0), // red + (0, 0, 191), // blue + ]; + for row in 0..h { + for col in 0..w { + let bar_idx = col * bar_colors.len() / w; + let (r, g, b) = bar_colors[bar_idx]; + let off = (row * w + col) * 4; + data[off] = r; + data[off + 1] = g; + data[off + 2] = b; + data[off + 3] = 255; + } + } + data +} + +/// Generate an I420 frame by converting an RGBA frame. +fn generate_i420_frame(width: u32, height: u32) -> Vec { + let rgba = generate_rgba_frame(width, height); + let w = width as usize; + let h = height as usize; + let chroma_w = w.div_ceil(2); + let chroma_h = h.div_ceil(2); + let i420_size = w * h + 2 * chroma_w * chroma_h; + let mut i420 = vec![0u8; i420_size]; + rgba8_to_i420_buf(&rgba, width, height, &mut i420); + i420 +} + +/// Generate an NV12 frame by converting an RGBA frame. +fn generate_nv12_frame(width: u32, height: u32) -> Vec { + let rgba = generate_rgba_frame(width, height); + let w = width as usize; + let h = height as usize; + let chroma_w = w.div_ceil(2); + let chroma_h = h.div_ceil(2); + let nv12_size = w * h + chroma_w * 2 * chroma_h; + let mut nv12 = vec![0u8; nv12_size]; + streamkit_nodes::video::compositor::pixel_ops::rgba8_to_nv12_buf( + &rgba, width, height, &mut nv12, + ); + nv12 +} + +/// Generate a semi-transparent RGBA overlay simulating rendered text. +/// +/// Produces a strip with alternating opaque "glyph" blocks and transparent +/// gaps, similar to real rasterised text bitmaps. 
+fn generate_text_overlay(width: u32, height: u32) -> Vec { + let w = width as usize; + let h = height as usize; + let mut data = vec![0u8; w * h * 4]; + for row in 0..h { + for col in 0..w { + let off = (row * w + col) * 4; + // Simulate glyph blocks: 60% of columns are "ink", rest transparent. + let in_glyph = (col * 5 / w).is_multiple_of(2); + if in_glyph { + data[off] = 255; // white text + data[off + 1] = 255; + data[off + 2] = 255; + data[off + 3] = 220; // slightly translucent + } + // else: stays rgba(0,0,0,0) — fully transparent gap + } + } + data +} + +/// Generate a semi-transparent RGBA overlay simulating an image/logo. +/// +/// Produces a filled rectangle with partial alpha, exercising the alpha-blend +/// code path in the compositor. +fn generate_image_overlay(width: u32, height: u32) -> Vec { + let w = width as usize; + let h = height as usize; + let mut data = vec![0u8; w * h * 4]; + for row in 0..h { + for col in 0..w { + let off = (row * w + col) * 4; + // Gradient alpha from top to bottom for realistic compositing. + let alpha = (row * 200 / h + 55).min(255) as u8; + data[off] = 60; + data[off + 1] = 120; + data[off + 2] = 200; + data[off + 3] = alpha; + } + } + data +} + +// ── Compositing harness ───────────────────────────────────────────────────── + +/// Call the real `composite_frame` kernel for `frame_count` iterations, +/// returning per-frame timing statistics. This exercises all kernel +/// optimizations: conversion cache, skip-canvas-clear, identity-scale +/// fast-path, precomputed x-map, SSE2 blend, etc. +/// +/// Uses a real `VideoFramePool` to match production behaviour (pooled buffer +/// reuse instead of per-frame heap allocation). 
+fn bench_composite( + _label: &str, + canvas_w: u32, + canvas_h: u32, + layers: &[Option], + image_overlays: &[Arc], + text_overlays: &[Arc], + frame_count: u32, +) -> BenchResult { + let mut conversion_cache = ConversionCache::new(); + let pool = VideoFramePool::video_default(); + + let start = Instant::now(); + + for _ in 0..frame_count { + let _result = composite_frame( + canvas_w, + canvas_h, + layers, + image_overlays, + text_overlays, + Some(&pool), + &mut conversion_cache, + ); + } + + let elapsed = start.elapsed(); + BenchResult { total_secs: elapsed.as_secs_f64(), frame_count } +} + +/// Benchmark RGBA8 → NV12 output conversion in isolation. +/// +/// Mirrors the production VP9 encoder path (`vp9.rs:1131`) where `composite_frame` +/// output feeds directly into `rgba8_to_nv12_buf`. Pre-composites a single frame, +/// then times repeated NV12 conversions from the same RGBA buffer. +fn bench_rgba_to_nv12(canvas_w: u32, canvas_h: u32, frame_count: u32) -> BenchResult { + let w = canvas_w as usize; + let h = canvas_h as usize; + let chroma_w = w.div_ceil(2); + let chroma_h = h.div_ceil(2); + + // Pre-generate a realistic RGBA canvas (colorbar pattern, all opaque). 
+ let rgba = generate_rgba_frame(canvas_w, canvas_h); + let nv12_size = w * h + chroma_w * 2 * chroma_h; + let mut nv12 = vec![0u8; nv12_size]; + + let start = Instant::now(); + + for _ in 0..frame_count { + rgba8_to_nv12_buf(&rgba, canvas_w, canvas_h, &mut nv12); + } + + let elapsed = start.elapsed(); + BenchResult { total_secs: elapsed.as_secs_f64(), frame_count } +} + +struct BenchResult { + total_secs: f64, + frame_count: u32, +} + +impl BenchResult { + fn fps(&self) -> f64 { + f64::from(self.frame_count) / self.total_secs + } + + fn ms_per_frame(&self) -> f64 { + self.total_secs * 1000.0 / f64::from(self.frame_count) + } +} + +// ── Scenario definitions ──────────────────────────────────────────────────── + +struct Scenario { + label: String, + layers: Vec>, + image_overlays: Vec>, + text_overlays: Vec>, +} + +#[allow(clippy::too_many_arguments, clippy::unnecessary_wraps)] +fn make_layer( + data: Vec, + width: u32, + height: u32, + pixel_format: PixelFormat, + rect: Option, + opacity: f32, + z_index: i32, + rotation_degrees: f32, +) -> Option { + Some(LayerSnapshot { + data: Arc::new(PooledVideoData::from_vec(data)), + width, + height, + pixel_format, + rect, + opacity, + z_index, + rotation_degrees, + mirror_horizontal: false, + mirror_vertical: false, + }) +} + +#[allow(clippy::too_many_lines)] +fn build_scenarios(canvas_w: u32, canvas_h: u32) -> Vec { + let pip_w = canvas_w / 3; + let pip_h = canvas_h / 3; + let pip_x = (canvas_w - pip_w - 20).cast_signed(); + let pip_y = (canvas_h - pip_h - 20).cast_signed(); + + // ── Overlay data (reused across scenarios) ────────────────────── + // Text overlay: a bottom-third banner (typical lower-third title). 
+ let text_ov_w = canvas_w * 2 / 3; + let text_ov_h = canvas_h / 8; + let text_overlay = Arc::new(DecodedOverlay { + rgba_data: generate_text_overlay(text_ov_w, text_ov_h), + width: text_ov_w, + height: text_ov_h, + rect: Rect { + x: ((canvas_w - text_ov_w) / 2).cast_signed(), + y: (canvas_h - text_ov_h - 40).cast_signed(), + width: text_ov_w, + height: text_ov_h, + }, + opacity: 0.95, + rotation_degrees: 0.0, + z_index: 10, + }); + + // Image overlay: a corner logo watermark. + let logo_w = canvas_w / 6; + let logo_h = canvas_h / 8; + let image_overlay = Arc::new(DecodedOverlay { + rgba_data: generate_image_overlay(logo_w, logo_h), + width: logo_w, + height: logo_h, + rect: Rect { x: 20, y: 20, width: logo_w, height: logo_h }, + opacity: 0.8, + rotation_degrees: 0.0, + z_index: 11, + }); + + vec![ + // 1 layer RGBA — baseline + Scenario { + label: "1-layer-rgba".to_string(), + layers: vec![make_layer( + generate_rgba_frame(canvas_w, canvas_h), + canvas_w, + canvas_h, + PixelFormat::Rgba8, + None, + 1.0, + 0, + 0.0, + )], + image_overlays: Vec::new(), + text_overlays: Vec::new(), + }, + // 2 layers RGBA (PiP) + Scenario { + label: "2-layer-rgba-pip".to_string(), + layers: vec![ + make_layer( + generate_rgba_frame(canvas_w, canvas_h), + canvas_w, + canvas_h, + PixelFormat::Rgba8, + None, + 1.0, + 0, + 0.0, + ), + make_layer( + generate_rgba_frame(pip_w, pip_h), + pip_w, + pip_h, + PixelFormat::Rgba8, + Some(Rect { x: pip_x, y: pip_y, width: pip_w, height: pip_h }), + 0.9, + 1, + 0.0, + ), + ], + image_overlays: Vec::new(), + text_overlays: Vec::new(), + }, + // 4 layers RGBA + Scenario { + label: "4-layer-rgba".to_string(), + layers: vec![ + make_layer( + generate_rgba_frame(canvas_w, canvas_h), + canvas_w, + canvas_h, + PixelFormat::Rgba8, + None, + 1.0, + 0, + 0.0, + ), + make_layer( + generate_rgba_frame(pip_w, pip_h), + pip_w, + pip_h, + PixelFormat::Rgba8, + Some(Rect { x: pip_x, y: pip_y, width: pip_w, height: pip_h }), + 0.9, + 1, + 0.0, + ), + make_layer( + 
generate_rgba_frame(pip_w, pip_h), + pip_w, + pip_h, + PixelFormat::Rgba8, + Some(Rect { x: 20, y: 20, width: pip_w, height: pip_h }), + 0.8, + 2, + 0.0, + ), + make_layer( + generate_rgba_frame(pip_w, pip_h), + pip_w, + pip_h, + PixelFormat::Rgba8, + Some(Rect { x: 20, y: pip_y, width: pip_w, height: pip_h }), + 0.7, + 3, + 0.0, + ), + ], + image_overlays: Vec::new(), + text_overlays: Vec::new(), + }, + // 2 layers: I420 bg + RGBA PiP (measures conversion overhead) + Scenario { + label: "2-layer-i420+rgba".to_string(), + layers: vec![ + make_layer( + generate_i420_frame(canvas_w, canvas_h), + canvas_w, + canvas_h, + PixelFormat::I420, + None, + 1.0, + 0, + 0.0, + ), + make_layer( + generate_rgba_frame(pip_w, pip_h), + pip_w, + pip_h, + PixelFormat::Rgba8, + Some(Rect { x: pip_x, y: pip_y, width: pip_w, height: pip_h }), + 0.9, + 1, + 0.0, + ), + ], + image_overlays: Vec::new(), + text_overlays: Vec::new(), + }, + // 2 layers: NV12 bg + RGBA PiP + Scenario { + label: "2-layer-nv12+rgba".to_string(), + layers: vec![ + make_layer( + generate_nv12_frame(canvas_w, canvas_h), + canvas_w, + canvas_h, + PixelFormat::Nv12, + None, + 1.0, + 0, + 0.0, + ), + make_layer( + generate_rgba_frame(pip_w, pip_h), + pip_w, + pip_h, + PixelFormat::Rgba8, + Some(Rect { x: pip_x, y: pip_y, width: pip_w, height: pip_h }), + 0.9, + 1, + 0.0, + ), + ], + image_overlays: Vec::new(), + text_overlays: Vec::new(), + }, + // 2 layers RGBA with rotation on PiP + Scenario { + label: "2-layer-rgba-rotated".to_string(), + layers: vec![ + make_layer( + generate_rgba_frame(canvas_w, canvas_h), + canvas_w, + canvas_h, + PixelFormat::Rgba8, + None, + 1.0, + 0, + 0.0, + ), + make_layer( + generate_rgba_frame(pip_w, pip_h), + pip_w, + pip_h, + PixelFormat::Rgba8, + Some(Rect { x: pip_x, y: pip_y, width: pip_w, height: pip_h }), + 0.9, + 1, + 15.0, // 15° rotation + ), + ], + image_overlays: Vec::new(), + text_overlays: Vec::new(), + }, + // 2 layers RGBA, static (same Arc — for future cache-hit 
measurement) + Scenario { + label: "2-layer-rgba-static".to_string(), + layers: { + let bg = + Arc::new(PooledVideoData::from_vec(generate_rgba_frame(canvas_w, canvas_h))); + let pip = Arc::new(PooledVideoData::from_vec(generate_rgba_frame(pip_w, pip_h))); + vec![ + Some(LayerSnapshot { + data: bg, + width: canvas_w, + height: canvas_h, + pixel_format: PixelFormat::Rgba8, + rect: None, + opacity: 1.0, + z_index: 0, + rotation_degrees: 0.0, + mirror_horizontal: false, + mirror_vertical: false, + }), + Some(LayerSnapshot { + data: pip, + width: pip_w, + height: pip_h, + pixel_format: PixelFormat::Rgba8, + rect: Some(Rect { x: pip_x, y: pip_y, width: pip_w, height: pip_h }), + opacity: 0.9, + z_index: 1, + rotation_degrees: 0.0, + mirror_horizontal: false, + mirror_vertical: false, + }), + ] + }, + image_overlays: Vec::new(), + text_overlays: Vec::new(), + }, + // ── Overlay scenarios ────────────────────────────────────────── + // 1 layer RGBA + text overlay (lower-third banner) + Scenario { + label: "1-layer+text-overlay".to_string(), + layers: vec![make_layer( + generate_rgba_frame(canvas_w, canvas_h), + canvas_w, + canvas_h, + PixelFormat::Rgba8, + None, + 1.0, + 0, + 0.0, + )], + image_overlays: Vec::new(), + text_overlays: vec![Arc::clone(&text_overlay)], + }, + // 1 layer RGBA + image overlay (logo watermark) + Scenario { + label: "1-layer+img-overlay".to_string(), + layers: vec![make_layer( + generate_rgba_frame(canvas_w, canvas_h), + canvas_w, + canvas_h, + PixelFormat::Rgba8, + None, + 1.0, + 0, + 0.0, + )], + image_overlays: vec![Arc::clone(&image_overlay)], + text_overlays: Vec::new(), + }, + // 2 layers PiP + both overlays (realistic broadcast layout) + Scenario { + label: "2-layer-pip+overlays".to_string(), + layers: vec![ + make_layer( + generate_rgba_frame(canvas_w, canvas_h), + canvas_w, + canvas_h, + PixelFormat::Rgba8, + None, + 1.0, + 0, + 0.0, + ), + make_layer( + generate_rgba_frame(pip_w, pip_h), + pip_w, + pip_h, + PixelFormat::Rgba8, + 
Some(Rect { x: pip_x, y: pip_y, width: pip_w, height: pip_h }), + 0.9, + 1, + 0.0, + ), + ], + image_overlays: vec![Arc::clone(&image_overlay)], + text_overlays: vec![Arc::clone(&text_overlay)], + }, + // I420 bg + PiP + both overlays (realistic codec→compositor pipeline) + Scenario { + label: "i420+pip+overlays".to_string(), + layers: vec![ + make_layer( + generate_i420_frame(canvas_w, canvas_h), + canvas_w, + canvas_h, + PixelFormat::I420, + None, + 1.0, + 0, + 0.0, + ), + make_layer( + generate_rgba_frame(pip_w, pip_h), + pip_w, + pip_h, + PixelFormat::Rgba8, + Some(Rect { x: pip_x, y: pip_y, width: pip_w, height: pip_h }), + 0.9, + 1, + 0.0, + ), + ], + image_overlays: vec![Arc::clone(&image_overlay)], + text_overlays: vec![Arc::clone(&text_overlay)], + }, + ] +} + +// ── Main ──────────────────────────────────────────────────────────────────── + +fn main() { + let args = BenchArgs::parse(); + + let resolutions: &[(u32, u32)] = if args.width == DEFAULT_WIDTH && args.height == DEFAULT_HEIGHT + { + // Default: run at multiple resolutions. + &[(640, 480), (1280, 720), (1920, 1080)] + } else { + // Custom: run at the specified resolution only. + // (Leak to get 'static — acceptable in a short-lived bench binary.) 
+ let res = Box::leak(Box::new([(args.width, args.height)])); + res + }; + + eprintln!("╔══════════════════════════════════════════════════════════╗"); + eprintln!("║ Compositor-Only Microbenchmark ║"); + eprintln!("╠══════════════════════════════════════════════════════════╣"); + eprintln!( + "║ Resolutions : {:<41}║", + resolutions.iter().map(|(w, h)| format!("{w}×{h}")).collect::>().join(", ") + ); + eprintln!("║ Frames : {:<41}║", args.frame_count); + eprintln!("║ Iterations : {:<41}║", args.iterations); + if let Some(ref f) = args.filter { + eprintln!("║ Filter : {f:<41}║"); + } + eprintln!("╚══════════════════════════════════════════════════════════╝"); + eprintln!(); + + let mut json_results: Vec = Vec::new(); + + for &(w, h) in resolutions { + eprintln!("── {w}×{h} ──────────────────────────────────────────────"); + + let scenarios = build_scenarios(w, h); + + for scenario in &scenarios { + if let Some(ref filter) = args.filter { + if !scenario.label.contains(filter.as_str()) { + continue; + } + } + + let mut iter_results = Vec::with_capacity(args.iterations as usize); + + for iter in 1..=args.iterations { + let result = bench_composite( + &scenario.label, + w, + h, + &scenario.layers, + &scenario.image_overlays, + &scenario.text_overlays, + args.frame_count, + ); + eprintln!( + " {:<28} iter {iter}/{}: {:>8.1} fps ({:.2} ms/frame)", + scenario.label, + args.iterations, + result.fps(), + result.ms_per_frame(), + ); + iter_results.push(result); + } + + // Summary for this scenario. 
+ let fps_values: Vec = iter_results.iter().map(BenchResult::fps).collect(); + let ms_values: Vec = iter_results.iter().map(BenchResult::ms_per_frame).collect(); + let mean_fps = fps_values.iter().sum::() / fps_values.len() as f64; + let mean_ms = ms_values.iter().sum::() / ms_values.len() as f64; + let min_ms = ms_values.iter().copied().fold(f64::INFINITY, f64::min); + let max_ms = ms_values.iter().copied().fold(f64::NEG_INFINITY, f64::max); + + eprintln!( + " {:<28} avg: {:>8.1} fps ({:.2} ms/frame, min={:.2}, max={:.2})", + "", mean_fps, mean_ms, min_ms, max_ms, + ); + + json_results.push(serde_json::json!({ + "benchmark": "compositor_only", + "scenario": scenario.label, + "width": w, + "height": h, + "frame_count": args.frame_count, + "iterations": args.iterations, + "mean_fps": mean_fps, + "mean_ms_per_frame": mean_ms, + "min_ms_per_frame": min_ms, + "max_ms_per_frame": max_ms, + })); + } + + // ── Standalone conversion benchmarks ────────────────────────── + let conversion_label = "rgba-to-nv12-output"; + if args.filter.as_ref().is_none_or(|f| conversion_label.contains(f.as_str())) { + let mut iter_results = Vec::with_capacity(args.iterations as usize); + for iter in 1..=args.iterations { + let result = bench_rgba_to_nv12(w, h, args.frame_count); + eprintln!( + " {:<28} iter {iter}/{}: {:>8.1} fps ({:.2} ms/frame)", + conversion_label, + args.iterations, + result.fps(), + result.ms_per_frame(), + ); + iter_results.push(result); + } + let fps_values: Vec = iter_results.iter().map(BenchResult::fps).collect(); + let ms_values: Vec = iter_results.iter().map(BenchResult::ms_per_frame).collect(); + let mean_fps = fps_values.iter().sum::() / fps_values.len() as f64; + let mean_ms = ms_values.iter().sum::() / ms_values.len() as f64; + let min_ms = ms_values.iter().copied().fold(f64::INFINITY, f64::min); + let max_ms = ms_values.iter().copied().fold(f64::NEG_INFINITY, f64::max); + eprintln!( + " {:<28} avg: {:>8.1} fps ({:.2} ms/frame, min={:.2}, max={:.2})", + "", 
mean_fps, mean_ms, min_ms, max_ms, + ); + json_results.push(serde_json::json!({ + "benchmark": "compositor_only", + "scenario": conversion_label, + "width": w, + "height": h, + "frame_count": args.frame_count, + "iterations": args.iterations, + "mean_fps": mean_fps, + "mean_ms_per_frame": mean_ms, + "min_ms_per_frame": min_ms, + "max_ms_per_frame": max_ms, + })); + } + + eprintln!(); + } + + // Machine-readable JSON output. + println!("{}", serde_json::to_string_pretty(&json_results).expect("JSON serialization")); +} diff --git a/crates/engine/benches/compositor_pipeline.rs b/crates/engine/benches/compositor_pipeline.rs new file mode 100644 index 00000000..5252bc48 --- /dev/null +++ b/crates/engine/benches/compositor_pipeline.rs @@ -0,0 +1,447 @@ +// SPDX-FileCopyrightText: © 2025 StreamKit Contributors +// +// SPDX-License-Identifier: MPL-2.0 + +#![allow(clippy::disallowed_macros)] // Bench binary intentionally uses eprintln!/println! for output. +#![allow(clippy::expect_used)] // Panicking on errors is fine in a benchmark binary. +//! Benchmark for the compositing oneshot pipeline. +//! +//! Runs the same graph as `samples/pipelines/oneshot/video_compositor_demo.yml`: +//! +//! colorbars_bg (RGBA8) ──┐ +//! ├─► compositor ──► vp9_encoder ──► http_output +//! colorbars_pip (RGBA8) ┘ +//! +//! The benchmark drives the pipeline through [`Engine::run_oneshot_pipeline`] +//! and reports wall-clock time, throughput (frames/s), and total output bytes. +//! +//! ## Usage +//! +//! Quick run (default 90 frames @ 640×480): +//! +//! ```bash +//! cargo bench -p streamkit-engine --bench compositor_pipeline +//! ``` +//! +//! Custom frame count / resolution for profiling: +//! +//! ```bash +//! cargo bench -p streamkit-engine --bench compositor_pipeline -- --frames 300 --width 1280 --height 720 +//! ``` +//! +//! Attach a profiler (e.g. `perf`, `samply`, `cargo flamegraph`): +//! +//! ```bash +//! cargo build --release -p streamkit-engine --bench compositor_pipeline +//! 
samply record target/release/deps/compositor_pipeline-* -- --frames 300 +//! ``` + +use std::time::Instant; +use streamkit_engine::Engine; + +/// Default benchmark parameters (matches the sample pipeline). +const DEFAULT_WIDTH: u32 = 640; +const DEFAULT_HEIGHT: u32 = 480; +const DEFAULT_FPS: u32 = 30; +const DEFAULT_FRAME_COUNT: u32 = 90; + +/// Simple arg parser — not worth pulling in clap for a bench binary. +struct BenchArgs { + width: u32, + height: u32, + fps: u32, + frame_count: u32, + iterations: u32, +} + +impl BenchArgs { + fn parse() -> Self { + let args: Vec = std::env::args().collect(); + let mut cfg = Self { + width: DEFAULT_WIDTH, + height: DEFAULT_HEIGHT, + fps: DEFAULT_FPS, + frame_count: DEFAULT_FRAME_COUNT, + iterations: 3, + }; + let mut i = 1; + while i < args.len() { + match args[i].as_str() { + "--width" | "-w" => { + i += 1; + if let Some(v) = args.get(i) { + cfg.width = v.parse().unwrap_or(cfg.width); + } + }, + "--height" | "-h" => { + i += 1; + if let Some(v) = args.get(i) { + cfg.height = v.parse().unwrap_or(cfg.height); + } + }, + "--fps" => { + i += 1; + if let Some(v) = args.get(i) { + cfg.fps = v.parse().unwrap_or(cfg.fps); + } + }, + "--frames" | "-n" => { + i += 1; + if let Some(v) = args.get(i) { + cfg.frame_count = v.parse().unwrap_or(cfg.frame_count); + } + }, + "--iterations" | "-i" => { + i += 1; + if let Some(v) = args.get(i) { + cfg.iterations = v.parse().unwrap_or(cfg.iterations); + } + }, + _ => {}, // ignore unknown (cargo bench passes extra flags) + } + i += 1; + } + cfg + } +} + +/// Build the compositor demo pipeline definition programmatically. +/// +/// Mirrors `samples/pipelines/oneshot/video_compositor_demo.yml` but with +/// configurable resolution and frame count. 
+fn build_pipeline(width: u32, height: u32, fps: u32, frame_count: u32) -> streamkit_api::Pipeline { + use streamkit_api::{Connection, EngineMode, Node, Pipeline}; + + let mut nodes = indexmap::IndexMap::new(); + + // --- colorbars_bg (NV12, full-size) --- + // Uses NV12 to exercise the NV12→RGBA8 conversion path in the compositor, + // matching real pipelines where camera inputs are typically NV12. + nodes.insert( + "colorbars_bg".to_string(), + Node { + kind: "video::colorbars".to_string(), + params: Some(serde_json::json!({ + "width": width, + "height": height, + "fps": fps, + "frame_count": frame_count, + "pixel_format": "nv12", + "draw_time": true, + "animate": true, + })), + state: None, + }, + ); + + // --- colorbars_pip (NV12, half-size PiP) --- + nodes.insert( + "colorbars_pip".to_string(), + Node { + kind: "video::colorbars".to_string(), + params: Some(serde_json::json!({ + "width": width / 2, + "height": height / 2, + "fps": fps, + "frame_count": frame_count, + "pixel_format": "nv12", + "draw_time": true, + "animate": true, + })), + state: None, + }, + ); + + // --- compositor --- + nodes.insert( + "compositor".to_string(), + Node { + kind: "video::compositor".to_string(), + params: Some(serde_json::json!({ + "width": width, + "height": height, + "num_inputs": 2, + })), + state: None, + }, + ); + + // --- pixel_convert (RGBA8 → NV12) --- + nodes.insert( + "pixel_convert".to_string(), + Node { + kind: "video::pixel_convert".to_string(), + params: Some(serde_json::json!({ "output_format": "nv12" })), + state: None, + }, + ); + + // --- VP9 encoder --- + nodes.insert( + "vp9_encoder".to_string(), + Node { kind: "video::vp9::encoder".to_string(), params: None, state: None }, + ); + + // --- WebM muxer (converts encoded video to binary bytes) --- + nodes.insert( + "webm_muxer".to_string(), + Node { + kind: "containers::webm::muxer".to_string(), + params: Some(serde_json::json!({ + "video_width": width, + "video_height": height, + "streaming_mode": "live", + 
})), + state: None, + }, + ); + + // --- http_output (bytes sink) --- + nodes.insert( + "http_output".to_string(), + Node { kind: "streamkit::http_output".to_string(), params: None, state: None }, + ); + + let connections = vec![ + Connection { + from_node: "colorbars_bg".to_string(), + from_pin: "out".to_string(), + to_node: "compositor".to_string(), + to_pin: "in_0".to_string(), + mode: streamkit_api::ConnectionMode::Reliable, + }, + Connection { + from_node: "colorbars_pip".to_string(), + from_pin: "out".to_string(), + to_node: "compositor".to_string(), + to_pin: "in_1".to_string(), + mode: streamkit_api::ConnectionMode::Reliable, + }, + Connection { + from_node: "compositor".to_string(), + from_pin: "out".to_string(), + to_node: "pixel_convert".to_string(), + to_pin: "in".to_string(), + mode: streamkit_api::ConnectionMode::Reliable, + }, + Connection { + from_node: "pixel_convert".to_string(), + from_pin: "out".to_string(), + to_node: "vp9_encoder".to_string(), + to_pin: "in".to_string(), + mode: streamkit_api::ConnectionMode::Reliable, + }, + Connection { + from_node: "vp9_encoder".to_string(), + from_pin: "out".to_string(), + to_node: "webm_muxer".to_string(), + to_pin: "in".to_string(), + mode: streamkit_api::ConnectionMode::Reliable, + }, + Connection { + from_node: "webm_muxer".to_string(), + from_pin: "out".to_string(), + to_node: "http_output".to_string(), + to_pin: "in".to_string(), + mode: streamkit_api::ConnectionMode::Reliable, + }, + ]; + + Pipeline { + name: Some("Compositor Benchmark".to_string()), + description: Some(format!("Benchmark: {width}×{height} @ {fps} fps, {frame_count} frames")), + mode: EngineMode::OneShot, + nodes, + connections, + } +} + +/// Result of a single benchmark iteration. +struct IterResult { + elapsed: std::time::Duration, + total_bytes: usize, + chunk_count: usize, + /// First few bytes of output for header validation. + header_bytes: Vec, +} + +/// WebM/EBML magic bytes: Element ID 0x1A45DFA3. 
+const EBML_MAGIC: [u8; 4] = [0x1A, 0x45, 0xDF, 0xA3]; + +/// Run one iteration of the benchmark pipeline and return detailed results. +async fn run_once( + engine: &Engine, + width: u32, + height: u32, + fps: u32, + frame_count: u32, +) -> IterResult { + let definition = build_pipeline(width, height, fps, frame_count); + + let start = Instant::now(); + + let result = engine + .run_oneshot_pipeline::>, std::io::Error>( + definition, + vec![], // no HTTP inputs — generator mode + None, // default config + None, // no cancellation + ) + .await + .expect("Pipeline should start successfully"); + + // Drain all output bytes, capturing header and counting chunks. + let mut total_bytes: usize = 0; + let mut chunk_count: usize = 0; + let mut header_bytes: Vec = Vec::new(); + let mut data_stream = result.data_stream; + while let Some(chunk) = data_stream.recv().await { + if header_bytes.len() < 4 { + let need = (4 - header_bytes.len()).min(chunk.len()); + header_bytes.extend_from_slice(&chunk[..need]); + } + total_bytes += chunk.len(); + chunk_count += 1; + } + + let elapsed = start.elapsed(); + IterResult { elapsed, total_bytes, chunk_count, header_bytes } +} + +fn main() { + // Initialise a minimal tracing subscriber so nodes don't panic on log calls. 
+ tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("warn")), + ) + .init(); + + let args = BenchArgs::parse(); + + eprintln!("╔══════════════════════════════════════════════════════════╗"); + eprintln!("║ Compositor Pipeline Benchmark ║"); + eprintln!("╠══════════════════════════════════════════════════════════╣"); + eprintln!("║ Resolution : {}×{:<36}║", args.width, format!("{}", args.height)); + eprintln!("║ FPS : {:<42}║", args.fps); + eprintln!("║ Frames : {:<42}║", args.frame_count); + eprintln!("║ Iterations : {:<42}║", args.iterations); + eprintln!("╚══════════════════════════════════════════════════════════╝"); + eprintln!(); + + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .expect("Failed to build Tokio runtime"); + + let engine = Engine::without_plugins(); + + let mut durations = Vec::with_capacity(args.iterations as usize); + let mut output_bytes_all = Vec::with_capacity(args.iterations as usize); + let mut valid_header = true; + let mut valid_size = true; + + // Minimum expected output: at least 100 bytes per frame for animated 720p VP9. + // Static colorbars compressed to ~35 bytes/frame; animated content should be + // much larger. This threshold is deliberately conservative. + let min_expected_bytes = args.frame_count as usize * 100; + + for iter in 1..=args.iterations { + let r = rt.block_on(run_once(&engine, args.width, args.height, args.fps, args.frame_count)); + + let fps_actual = f64::from(args.frame_count) / r.elapsed.as_secs_f64(); + + eprintln!( + " iter {iter}/{}: {:.3}s ({:.1} fps) output={} bytes chunks={}", + args.iterations, + r.elapsed.as_secs_f64(), + fps_actual, + r.total_bytes, + r.chunk_count, + ); + + // Validate EBML header. 
+ if r.header_bytes.len() < 4 || r.header_bytes[..4] != EBML_MAGIC { + valid_header = false; + eprintln!( + " ⚠ EBML header mismatch: got {:?}, expected {:?}", + &r.header_bytes[..r.header_bytes.len().min(4)], + EBML_MAGIC, + ); + } + + // Validate output size is reasonable. + if r.total_bytes < min_expected_bytes { + valid_size = false; + eprintln!( + " ⚠ Output too small: {} bytes < {} expected minimum ({} bytes/frame)", + r.total_bytes, + min_expected_bytes, + r.total_bytes / args.frame_count.max(1) as usize, + ); + } + + durations.push(r.elapsed); + output_bytes_all.push(r.total_bytes); + } + + // --- Summary --- + eprintln!(); + let total_secs: Vec = durations.iter().map(std::time::Duration::as_secs_f64).collect(); + #[allow(clippy::cast_precision_loss)] + let mean = total_secs.iter().sum::() / total_secs.len() as f64; + let min = total_secs.iter().copied().fold(f64::INFINITY, f64::min); + let max = total_secs.iter().copied().fold(f64::NEG_INFINITY, f64::max); + let stddev = if total_secs.len() > 1 { + #[allow(clippy::cast_precision_loss)] + let variance = total_secs.iter().map(|t| (t - mean).powi(2)).sum::() + / (total_secs.len() - 1) as f64; + variance.sqrt() + } else { + 0.0 + }; + + let mean_fps = f64::from(args.frame_count) / mean; + let mean_frame_ms = mean * 1000.0 / f64::from(args.frame_count); + let avg_output = output_bytes_all.iter().sum::() / output_bytes_all.len(); + let avg_bytes_per_frame = avg_output / args.frame_count.max(1) as usize; + + eprintln!("── Summary ({} iterations) ──────────────────────────────", args.iterations); + eprintln!(" wall-clock : {mean:.3}s (min={min:.3}s max={max:.3}s σ={stddev:.4}s)"); + eprintln!(" throughput : {mean_fps:.1} fps"); + eprintln!(" per-frame : {mean_frame_ms:.2} ms/frame"); + eprintln!(" output size : {avg_output} bytes (avg, {avg_bytes_per_frame} bytes/frame)"); + eprintln!( + " validation : header={} size={}", + if valid_header { "OK" } else { "FAIL" }, + if valid_size { "OK" } else { "FAIL" }, + ); + 
eprintln!(); + + if !valid_header || !valid_size { + eprintln!("ERROR: Output validation failed — benchmark results may be unreliable."); + eprintln!(); + std::process::exit(1); + } + + // Also print a machine-readable JSON line for CI / automated collection. + let json = serde_json::json!({ + "benchmark": "compositor_pipeline", + "width": args.width, + "height": args.height, + "fps": args.fps, + "frame_count": args.frame_count, + "iterations": args.iterations, + "mean_secs": mean, + "min_secs": min, + "max_secs": max, + "stddev_secs": stddev, + "mean_fps": mean_fps, + "mean_frame_ms": mean_frame_ms, + "avg_output_bytes": avg_output, + "avg_bytes_per_frame": avg_bytes_per_frame, + "valid_header": valid_header, + "valid_size": valid_size, + }); + println!("{json}"); +} diff --git a/crates/engine/benches/pixel_convert.rs b/crates/engine/benches/pixel_convert.rs new file mode 100644 index 00000000..151a6d21 --- /dev/null +++ b/crates/engine/benches/pixel_convert.rs @@ -0,0 +1,452 @@ +// SPDX-FileCopyrightText: © 2025 StreamKit Contributors +// +// SPDX-License-Identifier: MPL-2.0 + +#![allow(clippy::disallowed_macros)] // Bench binary intentionally uses eprintln!/println! for output. +#![allow(clippy::expect_used)] // Panicking on errors is fine in a benchmark binary. +#![allow(clippy::cast_possible_truncation, clippy::cast_sign_loss, clippy::cast_precision_loss)] + +//! Pixel-format conversion microbenchmark — measures raw conversion throughput +//! for the `video::pixel_convert` node's supported conversion paths in isolation +//! (no async runtime, no channel overhead). +//! +//! Exercises the following conversions across multiple resolutions: +//! +//! - RGBA8 → NV12 +//! - RGBA8 → I420 +//! - NV12 → RGBA8 +//! - I420 → RGBA8 +//! +//! ## Usage +//! +//! Quick run (default 200 frames @ 1280×720): +//! +//! ```bash +//! cargo bench -p streamkit-engine --bench pixel_convert +//! ``` +//! +//! Custom parameters: +//! +//! ```bash +//! 
cargo bench -p streamkit-engine --bench pixel_convert -- --frames 300 --width 1920 --height 1080 +//! ``` + +use std::time::Instant; + +use streamkit_nodes::video::compositor::pixel_ops::{ + i420_to_rgba8_buf, nv12_to_rgba8_buf, rgba8_to_i420_buf, rgba8_to_nv12_buf, +}; + +// ── Default benchmark parameters ──────────────────────────────────────────── + +const DEFAULT_WIDTH: u32 = 1280; +const DEFAULT_HEIGHT: u32 = 720; +const DEFAULT_FRAME_COUNT: u32 = 200; + +// ── Arg parser ────────────────────────────────────────────────────────────── + +struct BenchArgs { + width: u32, + height: u32, + frame_count: u32, + iterations: u32, + /// Optional filter: only run scenarios whose label contains this substring. + filter: Option, +} + +impl BenchArgs { + fn parse() -> Self { + let args: Vec = std::env::args().collect(); + let mut cfg = Self { + width: DEFAULT_WIDTH, + height: DEFAULT_HEIGHT, + frame_count: DEFAULT_FRAME_COUNT, + iterations: 3, + filter: None, + }; + let mut i = 1; + while i < args.len() { + match args[i].as_str() { + "--width" | "-w" => { + i += 1; + if let Some(v) = args.get(i) { + cfg.width = v.parse().unwrap_or(cfg.width); + } + }, + "--height" | "-h" => { + i += 1; + if let Some(v) = args.get(i) { + cfg.height = v.parse().unwrap_or(cfg.height); + } + }, + "--frames" | "-n" => { + i += 1; + if let Some(v) = args.get(i) { + cfg.frame_count = v.parse().unwrap_or(cfg.frame_count); + } + }, + "--iterations" | "-i" => { + i += 1; + if let Some(v) = args.get(i) { + cfg.iterations = v.parse().unwrap_or(cfg.iterations); + } + }, + "--filter" | "-f" => { + i += 1; + if let Some(v) = args.get(i) { + cfg.filter = Some(v.clone()); + } + }, + _ => {}, + } + i += 1; + } + cfg + } +} + +// ── Frame generators ──────────────────────────────────────────────────────── + +/// Generate an RGBA8 color-bar frame (opaque, all alpha = 255). 
+fn generate_rgba_frame(width: u32, height: u32) -> Vec { + let w = width as usize; + let h = height as usize; + let mut data = vec![0u8; w * h * 4]; + let bar_colors: &[(u8, u8, u8)] = &[ + (191, 191, 191), // white + (191, 191, 0), // yellow + (0, 191, 191), // cyan + (0, 191, 0), // green + (191, 0, 191), // magenta + (191, 0, 0), // red + (0, 0, 191), // blue + ]; + for row in 0..h { + for col in 0..w { + let bar_idx = col * bar_colors.len() / w; + let (r, g, b) = bar_colors[bar_idx]; + let off = (row * w + col) * 4; + data[off] = r; + data[off + 1] = g; + data[off + 2] = b; + data[off + 3] = 255; + } + } + data +} + +/// Generate an NV12 frame by converting an RGBA frame. +fn generate_nv12_frame(width: u32, height: u32) -> Vec { + let rgba = generate_rgba_frame(width, height); + let w = width as usize; + let h = height as usize; + let chroma_w = w.div_ceil(2); + let chroma_h = h.div_ceil(2); + let nv12_size = w * h + chroma_w * 2 * chroma_h; + let mut nv12 = vec![0u8; nv12_size]; + rgba8_to_nv12_buf(&rgba, width, height, &mut nv12); + nv12 +} + +/// Generate an I420 frame by converting an RGBA frame. 
+fn generate_i420_frame(width: u32, height: u32) -> Vec { + let rgba = generate_rgba_frame(width, height); + let w = width as usize; + let h = height as usize; + let chroma_w = w.div_ceil(2); + let chroma_h = h.div_ceil(2); + let i420_size = w * h + 2 * chroma_w * chroma_h; + let mut i420 = vec![0u8; i420_size]; + rgba8_to_i420_buf(&rgba, width, height, &mut i420); + i420 +} + +// ── Benchmark harness ─────────────────────────────────────────────────────── + +struct BenchResult { + total_secs: f64, + frame_count: u32, +} + +impl BenchResult { + fn fps(&self) -> f64 { + f64::from(self.frame_count) / self.total_secs + } + + fn ms_per_frame(&self) -> f64 { + self.total_secs * 1000.0 / f64::from(self.frame_count) + } +} + +/// Benchmark a conversion function by running it `frame_count` times on a +/// pre-allocated input/output buffer pair (warm cache — same data every frame). +fn bench_conversion( + input: &[u8], + output: &mut [u8], + width: u32, + height: u32, + frame_count: u32, + convert_fn: fn(&[u8], u32, u32, &mut [u8]), +) -> BenchResult { + // Warm-up: run once to prime caches / JIT / rayon thread pool. + convert_fn(input, width, height, output); + + let start = Instant::now(); + for _ in 0..frame_count { + convert_fn(input, width, height, output); + } + let elapsed = start.elapsed(); + + BenchResult { total_secs: elapsed.as_secs_f64(), frame_count } +} + +/// Benchmark with cold cache: pre-allocate `frame_count` unique input buffers +/// so each frame reads from a different memory region, defeating L3 caching. +/// This simulates real pipelines where every frame is fresh camera data. +fn bench_conversion_cold( + template: &[u8], + output_size: usize, + width: u32, + height: u32, + frame_count: u32, + convert_fn: fn(&[u8], u32, u32, &mut [u8]), +) -> BenchResult { + // Pre-allocate unique input buffers with slightly different data to + // prevent OS page dedup. Each buffer is ~3.5 MB at 720p RGBA8. 
+ let inputs: Vec> = (0..frame_count) + .map(|i| { + let mut buf = template.to_vec(); + // Tweak first byte so pages differ. + buf[0] = (i & 0xFF) as u8; + buf + }) + .collect(); + let mut output = vec![0u8; output_size]; + + if frame_count == 0 { + return BenchResult { total_secs: 0.0, frame_count }; + } + + // Warm-up: prime rayon thread pool only (not data cache). + convert_fn(&inputs[0], width, height, &mut output); + + // Flush the data cache by touching a large dummy allocation. + let flush_size = 32 * 1024 * 1024; // 32 MB > L3 + let flush: Vec = vec![1u8; flush_size]; + std::hint::black_box(&flush); + drop(flush); + + let start = Instant::now(); + for i in 0..frame_count as usize { + convert_fn(&inputs[i], width, height, &mut output); + } + let elapsed = start.elapsed(); + + BenchResult { total_secs: elapsed.as_secs_f64(), frame_count } +} + +// ── Conversion scenarios ──────────────────────────────────────────────────── + +struct ConversionScenario { + label: &'static str, + input: Vec, + output_size: usize, + convert_fn: fn(&[u8], u32, u32, &mut [u8]), +} + +fn build_scenarios(width: u32, height: u32) -> Vec { + let w = width as usize; + let h = height as usize; + let chroma_w = w.div_ceil(2); + let chroma_h = h.div_ceil(2); + let rgba_size = w * h * 4; + let nv12_size = w * h + chroma_w * 2 * chroma_h; + let i420_size = w * h + 2 * chroma_w * chroma_h; + + vec![ + ConversionScenario { + label: "rgba8-to-nv12", + input: generate_rgba_frame(width, height), + output_size: nv12_size, + convert_fn: rgba8_to_nv12_buf, + }, + ConversionScenario { + label: "rgba8-to-i420", + input: generate_rgba_frame(width, height), + output_size: i420_size, + convert_fn: rgba8_to_i420_buf, + }, + ConversionScenario { + label: "nv12-to-rgba8", + input: generate_nv12_frame(width, height), + output_size: rgba_size, + convert_fn: nv12_to_rgba8_buf, + }, + ConversionScenario { + label: "i420-to-rgba8", + input: generate_i420_frame(width, height), + output_size: rgba_size, + 
convert_fn: i420_to_rgba8_buf, + }, + ] +} + +// ── Main ──────────────────────────────────────────────────────────────────── + +fn main() { + let args = BenchArgs::parse(); + + let resolutions: &[(u32, u32)] = if args.width == DEFAULT_WIDTH && args.height == DEFAULT_HEIGHT + { + // Default: run at multiple resolutions. + &[(640, 480), (1280, 720), (1920, 1080)] + } else { + // Custom: run at the specified resolution only. + let res = Box::leak(Box::new([(args.width, args.height)])); + res + }; + + eprintln!("╔══════════════════════════════════════════════════════════╗"); + eprintln!("║ Pixel Convert Microbenchmark ║"); + eprintln!("╠══════════════════════════════════════════════════════════╣"); + eprintln!( + "║ Resolutions : {:<41}║", + resolutions.iter().map(|(w, h)| format!("{w}×{h}")).collect::>().join(", ") + ); + eprintln!("║ Frames : {:<41}║", args.frame_count); + eprintln!("║ Iterations : {:<41}║", args.iterations); + if let Some(ref f) = args.filter { + eprintln!("║ Filter : {f:<41}║"); + } + eprintln!("╚══════════════════════════════════════════════════════════╝"); + eprintln!(); + + let mut json_results: Vec = Vec::new(); + + for &(w, h) in resolutions { + eprintln!("── {w}×{h} ──────────────────────────────────────────────"); + + let scenarios = build_scenarios(w, h); + + for scenario in &scenarios { + if let Some(ref filter) = args.filter { + if !scenario.label.contains(filter.as_str()) { + continue; + } + } + + let mut iter_results = Vec::with_capacity(args.iterations as usize); + + for iter in 1..=args.iterations { + let mut output = vec![0u8; scenario.output_size]; + let result = bench_conversion( + &scenario.input, + &mut output, + w, + h, + args.frame_count, + scenario.convert_fn, + ); + eprintln!( + " {:<28} iter {iter}/{}: {:>8.1} fps ({:.3} ms/frame)", + scenario.label, + args.iterations, + result.fps(), + result.ms_per_frame(), + ); + iter_results.push(result); + } + + // Summary for this scenario. 
+ let fps_values: Vec = iter_results.iter().map(BenchResult::fps).collect(); + let ms_values: Vec = iter_results.iter().map(BenchResult::ms_per_frame).collect(); + let mean_fps = fps_values.iter().sum::() / fps_values.len() as f64; + let mean_ms = ms_values.iter().sum::() / ms_values.len() as f64; + let min_ms = ms_values.iter().copied().fold(f64::INFINITY, f64::min); + let max_ms = ms_values.iter().copied().fold(f64::NEG_INFINITY, f64::max); + + eprintln!( + " {:<28} avg: {:>8.1} fps ({:.3} ms/frame, min={:.3}, max={:.3})", + "", mean_fps, mean_ms, min_ms, max_ms, + ); + + json_results.push(serde_json::json!({ + "benchmark": "pixel_convert", + "scenario": scenario.label, + "width": w, + "height": h, + "frame_count": args.frame_count, + "iterations": args.iterations, + "mean_fps": mean_fps, + "mean_ms_per_frame": mean_ms, + "min_ms_per_frame": min_ms, + "max_ms_per_frame": max_ms, + })); + } + + // ── Cold-cache variant ──────────────────────────────────────── + // Use unique per-frame buffers to simulate real pipeline behaviour + // where each frame is fresh camera data not in CPU cache. 
+ eprintln!(" (cold cache)"); + for scenario in &scenarios { + if let Some(ref filter) = args.filter { + if !scenario.label.contains(filter.as_str()) { + continue; + } + } + + let cold_label = format!("{} (cold)", scenario.label); + let mut iter_results = Vec::with_capacity(args.iterations as usize); + + for iter in 1..=args.iterations { + let result = bench_conversion_cold( + &scenario.input, + scenario.output_size, + w, + h, + args.frame_count, + scenario.convert_fn, + ); + eprintln!( + " {:<28} iter {iter}/{}: {:>8.1} fps ({:.3} ms/frame)", + cold_label, + args.iterations, + result.fps(), + result.ms_per_frame(), + ); + iter_results.push(result); + } + + let fps_values: Vec = iter_results.iter().map(BenchResult::fps).collect(); + let ms_values: Vec = iter_results.iter().map(BenchResult::ms_per_frame).collect(); + let mean_fps = fps_values.iter().sum::() / fps_values.len() as f64; + let mean_ms = ms_values.iter().sum::() / ms_values.len() as f64; + let min_ms = ms_values.iter().copied().fold(f64::INFINITY, f64::min); + let max_ms = ms_values.iter().copied().fold(f64::NEG_INFINITY, f64::max); + + eprintln!( + " {:<28} avg: {:>8.1} fps ({:.3} ms/frame, min={:.3}, max={:.3})", + "", mean_fps, mean_ms, min_ms, max_ms, + ); + + json_results.push(serde_json::json!({ + "benchmark": "pixel_convert", + "scenario": cold_label, + "width": w, + "height": h, + "frame_count": args.frame_count, + "iterations": args.iterations, + "mean_fps": mean_fps, + "mean_ms_per_frame": mean_ms, + "min_ms_per_frame": min_ms, + "max_ms_per_frame": max_ms, + "cache": "cold", + })); + } + + eprintln!(); + } + + // Machine-readable JSON output. 
+ println!("{}", serde_json::to_string_pretty(&json_results).expect("JSON serialization")); +} diff --git a/crates/engine/src/dynamic_actor.rs b/crates/engine/src/dynamic_actor.rs index efff4b5c..41ddc430 100644 --- a/crates/engine/src/dynamic_actor.rs +++ b/crates/engine/src/dynamic_actor.rs @@ -20,13 +20,14 @@ use std::collections::HashMap; use std::sync::{Arc, RwLock}; use streamkit_core::control::{EngineControlMessage, NodeControlMessage}; use streamkit_core::error::StreamKitError; -use streamkit_core::frame_pool::AudioFramePool; +use streamkit_core::frame_pool::{AudioFramePool, VideoFramePool}; use streamkit_core::node::{InitContext, NodeContext, OutputRouting, OutputSender}; use streamkit_core::pins::PinUpdate; use streamkit_core::registry::NodeRegistry; use streamkit_core::state::{NodeState, NodeStateUpdate}; use streamkit_core::stats::{NodeStats, NodeStatsUpdate}; use streamkit_core::telemetry::TelemetryEvent; +use streamkit_core::view_data::NodeViewDataUpdate; use streamkit_core::PinCardinality; use tokio::sync::mpsc; use tracing::Instrument; @@ -60,6 +61,8 @@ pub struct DynamicEngine { pub(super) session_id: Option, /// Per-pipeline audio buffer pool for hot paths (e.g., Opus decode). pub(super) audio_pool: std::sync::Arc, + /// Per-pipeline video buffer pool for hot paths (e.g., video decode). 
+ pub(super) video_pool: std::sync::Arc, /// Buffer capacity for node input channels pub(super) node_input_capacity: usize, /// Buffer capacity for pin distributor channels @@ -74,6 +77,10 @@ pub struct DynamicEngine { pub(super) stats_subscribers: Vec>, /// Subscribers that want to receive telemetry events pub(super) telemetry_subscribers: Vec>, + /// Latest view data per node (e.g., compositor resolved layout) + pub(super) node_view_data: HashMap, + /// Subscribers that want to receive node view data updates + pub(super) view_data_subscribers: Vec>, // Metrics pub(super) nodes_active_gauge: opentelemetry::metrics::Gauge, pub(super) node_state_transitions_counter: opentelemetry::metrics::Counter, @@ -105,11 +112,12 @@ impl DynamicEngine { let (state_tx, mut state_rx) = mpsc::channel(DEFAULT_SUBSCRIBER_CHANNEL_CAPACITY); let (stats_tx, mut stats_rx) = mpsc::channel(DEFAULT_SUBSCRIBER_CHANNEL_CAPACITY); let (telemetry_tx, mut telemetry_rx) = mpsc::channel(DEFAULT_SUBSCRIBER_CHANNEL_CAPACITY); + let (view_data_tx, mut view_data_rx) = mpsc::channel(DEFAULT_SUBSCRIBER_CHANNEL_CAPACITY); loop { tokio::select! 
{ Some(control_msg) = self.control_rx.recv() => { - if !self.handle_engine_control(control_msg, &state_tx, &stats_tx, &telemetry_tx).await { + if !self.handle_engine_control(control_msg, &state_tx, &stats_tx, &telemetry_tx, &view_data_tx).await { break; // Shutdown requested } }, @@ -126,6 +134,9 @@ impl DynamicEngine { Some(telemetry_event) = telemetry_rx.recv() => { self.handle_telemetry_event(&telemetry_event); }, + Some(view_data_update) = view_data_rx.recv() => { + self.handle_view_data_update(&view_data_update); + }, else => break, } } @@ -156,6 +167,14 @@ impl DynamicEngine { self.telemetry_subscribers.push(tx); let _ = response_tx.send(rx).await; }, + QueryMessage::SubscribeViewData { response_tx } => { + let (tx, rx) = mpsc::channel(DEFAULT_SUBSCRIBER_CHANNEL_CAPACITY); + self.view_data_subscribers.push(tx); + let _ = response_tx.send(rx).await; + }, + QueryMessage::GetNodeViewData { response_tx } => { + let _ = response_tx.send(self.node_view_data.clone()).await; + }, } } @@ -340,6 +359,29 @@ impl DynamicEngine { }); } + /// Handles a node view data update by storing it and broadcasting to subscribers. + /// + /// View data is best-effort (like stats): dropped updates are acceptable. 
+ fn handle_view_data_update(&mut self, update: &NodeViewDataUpdate) { + // Ignore view data updates for nodes that have been removed + if !self.live_nodes.contains_key(&update.node_id) { + tracing::trace!( + node = %update.node_id, + "Ignoring view data update for removed node" + ); + return; + } + + // Store latest value + self.node_view_data.insert(update.node_id.clone(), update.data.clone()); + + // Broadcast to all subscribers + self.view_data_subscribers.retain(|subscriber| match subscriber.try_send(update.clone()) { + Ok(()) | Err(mpsc::error::TrySendError::Full(_)) => true, + Err(mpsc::error::TrySendError::Closed(_)) => false, + }); + } + /// Handles a node statistics update by storing it and broadcasting to subscribers /// /// Not async because all operations are synchronous (no .await calls) @@ -425,8 +467,8 @@ impl DynamicEngine { /// Helper function to initialize a node and its I/O actors (Pin Distributors). /// - /// Takes node_id, kind, state_tx, stats_tx, and telemetry_tx by reference since they're cloned - /// multiple times internally (for channels, metrics, etc.) + /// Takes node_id, kind, state_tx, stats_tx, telemetry_tx, and view_data_tx by reference since + /// they're cloned multiple times internally (for channels, metrics, etc.) async fn initialize_node( &mut self, node: Box, @@ -435,6 +477,7 @@ impl DynamicEngine { state_tx: &mpsc::Sender, stats_tx: &mpsc::Sender, telemetry_tx: &mpsc::Sender, + view_data_tx: &mpsc::Sender, ) -> Result<(), StreamKitError> { let mut node = node; @@ -502,6 +545,9 @@ impl DynamicEngine { // 5. Create NodeContext let context = NodeContext { inputs: node_inputs_map, + // Dynamic pipelines wire connections after nodes are spawned, so + // input types are not known at construction time. 
+ input_types: HashMap::new(), control_rx, // We use OutputRouting::Direct, pointing the node directly to its Pin Distributors output_sender: OutputSender::new( @@ -516,6 +562,8 @@ impl DynamicEngine { cancellation_token: None, // Dynamic pipelines don't use cancellation tokens pin_management_rx, audio_pool: Some(self.audio_pool.clone()), + video_pool: Some(self.video_pool.clone()), + view_data_tx: Some(view_data_tx.clone()), }; // 5. Spawn Node @@ -897,6 +945,7 @@ impl DynamicEngine { // 4. Clean up Control Plane state self.node_states.remove(node_id); self.node_stats.remove(node_id); + self.node_view_data.remove(node_id); self.node_pin_metadata.remove(node_id); self.pin_management_txs.remove(node_id); self.node_kinds.remove(node_id); @@ -912,6 +961,7 @@ impl DynamicEngine { state_tx: &mpsc::Sender, stats_tx: &mpsc::Sender, telemetry_tx: &mpsc::Sender, + view_data_tx: &mpsc::Sender, ) -> bool { match msg { EngineControlMessage::AddNode { node_id, kind, params } => { @@ -941,6 +991,7 @@ impl DynamicEngine { state_tx, stats_tx, telemetry_tx, + view_data_tx, ) .await { @@ -1071,6 +1122,7 @@ impl DynamicEngine { } self.node_states.clear(); self.node_stats.clear(); + self.node_view_data.clear(); self.nodes_active_gauge.record(0, &[]); tracing::info!("All nodes shut down successfully"); diff --git a/crates/engine/src/dynamic_handle.rs b/crates/engine/src/dynamic_handle.rs index ab4bdeba..f62d4748 100644 --- a/crates/engine/src/dynamic_handle.rs +++ b/crates/engine/src/dynamic_handle.rs @@ -11,6 +11,7 @@ use streamkit_core::control::EngineControlMessage; use streamkit_core::state::{NodeState, NodeStateUpdate}; use streamkit_core::stats::{NodeStats, NodeStatsUpdate}; use streamkit_core::telemetry::TelemetryEvent; +use streamkit_core::view_data::NodeViewDataUpdate; use tokio::sync::mpsc; /// A handle to communicate with a running dynamic engine actor. 
@@ -121,6 +122,37 @@ impl DynamicEngineHandle { response_rx.recv().await.ok_or_else(|| "Failed to receive response from engine".to_string()) } + /// Subscribes to node view data updates. + /// Returns a receiver that will receive all subsequent view data updates. + /// + /// # Errors + /// + /// Returns an error if the engine actor has shut down or fails to respond. + pub async fn subscribe_view_data(&self) -> Result, String> { + let (response_tx, mut response_rx) = mpsc::channel(1); + self.query_tx + .send(QueryMessage::SubscribeViewData { response_tx }) + .await + .map_err(|_| "Engine actor has shut down".to_string())?; + + response_rx.recv().await.ok_or_else(|| "Failed to receive response from engine".to_string()) + } + + /// Gets the current view data for all nodes in the pipeline. + /// + /// # Errors + /// + /// Returns an error if the engine actor has shut down or fails to respond. + pub async fn get_node_view_data(&self) -> Result, String> { + let (response_tx, mut response_rx) = mpsc::channel(1); + self.query_tx + .send(QueryMessage::GetNodeViewData { response_tx }) + .await + .map_err(|_| "Engine actor has shut down".to_string())?; + + response_rx.recv().await.ok_or_else(|| "Failed to receive response from engine".to_string()) + } + /// Sends a shutdown signal to the engine and waits for it to complete. /// This ensures all nodes are properly stopped before returning. /// Can only be called once - subsequent calls will return an error. 
diff --git a/crates/engine/src/dynamic_messages.rs b/crates/engine/src/dynamic_messages.rs index 88666721..99ce078e 100644 --- a/crates/engine/src/dynamic_messages.rs +++ b/crates/engine/src/dynamic_messages.rs @@ -9,6 +9,7 @@ use std::sync::Arc; use streamkit_core::state::{NodeState, NodeStateUpdate}; use streamkit_core::stats::{NodeStats, NodeStatsUpdate}; use streamkit_core::telemetry::TelemetryEvent; +use streamkit_core::view_data::NodeViewDataUpdate; use tokio::sync::mpsc; /// Unique identifier for a connection (FromNode, FromPin, ToNode, ToPin). @@ -51,6 +52,8 @@ pub enum QueryMessage { SubscribeState { response_tx: mpsc::Sender> }, SubscribeStats { response_tx: mpsc::Sender> }, SubscribeTelemetry { response_tx: mpsc::Sender> }, + SubscribeViewData { response_tx: mpsc::Sender> }, + GetNodeViewData { response_tx: mpsc::Sender> }, } // Re-export ConnectionMode from core for use by pin distributor diff --git a/crates/engine/src/dynamic_pin_distributor.rs b/crates/engine/src/dynamic_pin_distributor.rs index 1dd777f6..9ff690c3 100644 --- a/crates/engine/src/dynamic_pin_distributor.rs +++ b/crates/engine/src/dynamic_pin_distributor.rs @@ -435,6 +435,15 @@ impl PinDistributorActor { let dur_s = frame.duration_us().map(|us| us as f64 / 1_000_000.0); (bytes, dur_s) }, + Packet::Video(frame) => { + let bytes = frame.data.len() as f64; + let dur_s = frame + .metadata + .as_ref() + .and_then(|m| m.duration_us) + .map(|us| us as f64 / 1_000_000.0); + (bytes, dur_s) + }, Packet::Binary { data, metadata, .. 
} => { let bytes = data.len() as f64; let dur_s = diff --git a/crates/engine/src/graph_builder.rs b/crates/engine/src/graph_builder.rs index 75c4cf0a..9ab983f5 100644 --- a/crates/engine/src/graph_builder.rs +++ b/crates/engine/src/graph_builder.rs @@ -8,7 +8,7 @@ use std::sync::Arc; use std::time::{Instant, SystemTime}; use streamkit_core::control::NodeControlMessage; use streamkit_core::error::StreamKitError; -use streamkit_core::frame_pool::AudioFramePool; +use streamkit_core::frame_pool::{AudioFramePool, VideoFramePool}; use streamkit_core::node::{InitContext, NodeContext, OutputRouting, OutputSender, ProcessorNode}; use streamkit_core::packet_meta::{can_connect, packet_type_registry}; use streamkit_core::pins::PinUpdate; @@ -66,6 +66,7 @@ pub async fn wire_and_spawn_graph( stats_tx: Option>, cancellation_token: Option, audio_pool: Option>, + video_pool: Option>, ) -> Result, StreamKitError> { tracing::info!( "Graph builder starting with {} nodes and {} connections", @@ -212,6 +213,9 @@ pub async fn wire_and_spawn_graph( tracing::warn!("Type inference reached maximum iterations (100), some Passthrough types may remain unresolved"); } + // Map from (to_node, to_pin) -> PacketType of the connected upstream output. + let mut input_types: HashMap<(String, String), PacketType> = HashMap::new(); + for conn in connections { tracing::debug!( "Creating connection: {}.{} -> {}.{}", @@ -300,6 +304,10 @@ pub async fn wire_and_spawn_graph( } } + // Record the upstream output type so we can provide it to the + // receiving node via `NodeContext::input_types`. 
+ input_types.insert(to_key.clone(), out_ty.clone()); + output_txs.insert(from_key, tx); input_rxs.insert(to_key, rx); } @@ -320,12 +328,17 @@ pub async fn wire_and_spawn_graph( #[allow(clippy::unwrap_used)] let node = nodes.remove(&name).unwrap(); let mut node_inputs = HashMap::new(); + let mut node_input_types = HashMap::new(); let input_pins = node.input_pins(); tracing::debug!("Node '{}' has {} input pins", name, input_pins.len()); for pin in input_pins { - if let Some(rx) = input_rxs.remove(&(name.clone(), pin.name.clone())) { + let key = (name.clone(), pin.name.clone()); + if let Some(rx) = input_rxs.remove(&key) { tracing::debug!("Connected input pin '{}.{}'", name, pin.name); + if let Some(ty) = input_types.remove(&key) { + node_input_types.insert(pin.name.clone(), ty); + } node_inputs.insert(pin.name, rx); } else { tracing::debug!("Input pin '{}.{}' not connected", name, pin.name); @@ -354,6 +367,7 @@ pub async fn wire_and_spawn_graph( let context = NodeContext { inputs: node_inputs, + input_types: node_input_types, control_rx, output_sender: OutputSender::new(name.clone(), OutputRouting::Direct(direct_outputs)), batch_size, @@ -364,6 +378,8 @@ pub async fn wire_and_spawn_graph( cancellation_token: cancellation_token.clone(), pin_management_rx: None, // Stateless pipelines don't support dynamic pins audio_pool: audio_pool.clone(), + video_pool: video_pool.clone(), + view_data_tx: None, // Stateless pipelines don't emit view data }; tracing::debug!("Starting task for node '{}'", name); diff --git a/crates/engine/src/lib.rs b/crates/engine/src/lib.rs index e5a6fabb..b8f14ddc 100644 --- a/crates/engine/src/lib.rs +++ b/crates/engine/src/lib.rs @@ -53,6 +53,7 @@ use dynamic_actor::DynamicEngine; pub struct Engine { pub registry: Arc>, pub audio_pool: Arc, + pub video_pool: Arc, } impl Default for Engine { fn default() -> Self { @@ -156,6 +157,7 @@ impl Engine { Self { registry: Arc::new(RwLock::new(registry)), audio_pool: 
Arc::new(streamkit_core::AudioFramePool::audio_default()), + video_pool: Arc::new(streamkit_core::VideoFramePool::video_default()), } } @@ -229,6 +231,7 @@ impl Engine { batch_size: config.packet_batch_size, session_id: config.session_id, audio_pool: self.audio_pool.clone(), + video_pool: self.video_pool.clone(), node_input_capacity, pin_distributor_capacity, node_states: HashMap::new(), @@ -236,6 +239,8 @@ impl Engine { node_stats: HashMap::new(), stats_subscribers: Vec::new(), telemetry_subscribers: Vec::new(), + node_view_data: HashMap::new(), + view_data_subscribers: Vec::new(), nodes_active_gauge: meter .u64_gauge("engine.nodes.active") .with_description("Number of active nodes in the pipeline") diff --git a/crates/engine/src/oneshot.rs b/crates/engine/src/oneshot.rs index bc7f077d..9731393f 100644 --- a/crates/engine/src/oneshot.rs +++ b/crates/engine/src/oneshot.rs @@ -104,7 +104,7 @@ impl Engine { /// /// Panics if the engine's registry lock is poisoned (only possible if a thread panicked /// while holding the lock). - #[allow(clippy::cognitive_complexity)] + #[allow(clippy::cognitive_complexity, clippy::too_many_lines)] pub async fn run_oneshot_pipeline( &self, definition: Pipeline, @@ -167,14 +167,24 @@ impl Engine { http_input_nodes.len(), output_node_id.as_deref().unwrap_or("unknown") ); - } else { - if source_node_ids.is_empty() { - tracing::error!("Pipeline validation failed: no file_reader nodes found"); + } else if !source_node_ids.is_empty() { + if !inputs.is_empty() { + tracing::error!( + "Pipeline validation failed: streams provided but no http_input nodes present" + ); return Err(StreamKitError::Configuration( - "File-based pipelines must contain at least one 'core::file_reader' node." + "Multipart streams were provided but the pipeline has no 'streamkit::http_input' nodes." 
.to_string(), )); } + tracing::info!( + "File-based mode: {} source node(s), output='{}'", + source_node_ids.len(), + output_node_id.as_deref().unwrap_or("unknown") + ); + } else { + // Generator mode: pipeline produces its own data (e.g. video::colorbars) + // No http_input or file_reader required — just needs http_output. if !inputs.is_empty() { tracing::error!( "Pipeline validation failed: streams provided but no http_input nodes present" @@ -185,8 +195,7 @@ impl Engine { )); } tracing::info!( - "File-based mode: {} source node(s), output='{}'", - source_node_ids.len(), + "Generator mode: no input nodes, output='{}'", output_node_id.as_deref().unwrap_or("unknown") ); } @@ -325,11 +334,29 @@ impl Engine { })?; tracing::debug!("Creating final node instance of type '{}'", final_node_def.kind); - // Get the static content-type from the final node before we move it + // Walk backwards from the output node through the connection graph to find + // the first node that declares a content_type. This allows passthrough-style + // nodes (pacer, passthrough, telemetry_tap, etc.) to be inserted before + // http_output without losing the upstream content type. let static_content_type = { - let temp_instance = - registry.create_node(&final_node_def.kind, final_node_def.params.as_ref())?; - temp_instance.content_type() + let mut cursor = final_node_id.as_str(); + let mut found: Option = None; + // Limit iterations to prevent infinite loops in malformed graphs. + for _ in 0..definition.nodes.len() { + if let Some(def) = definition.nodes.get(cursor) { + let temp = registry.create_node(&def.kind, def.params.as_ref())?; + if let Some(ct) = temp.content_type() { + found = Some(ct); + break; + } + } + // Move to the upstream node that feeds `cursor`. + match definition.connections.iter().find(|c| c.to_node == cursor) { + Some(conn) => cursor = conn.from_node.as_str(), + None => break, + } + } + found }; // --- 4. 
Instantiate all nodes for the pipeline --- @@ -378,6 +405,7 @@ impl Engine { let node_kinds_for_metrics = node_kinds.clone(); let audio_pool = self.audio_pool.clone(); + let video_pool = self.video_pool.clone(); let (stats_tx, stats_rx) = mpsc::channel(DEFAULT_STATE_CHANNEL_CAPACITY); @@ -391,19 +419,37 @@ impl Engine { Some(stats_tx), Some(cancellation_token.clone()), Some(audio_pool), + Some(video_pool), ) .await?; tracing::info!("Pipeline graph successfully spawned"); spawn_oneshot_metrics_recorder(stats_rx, node_kinds_for_metrics); - // --- 5.5. Start file readers (if any) --- - if !source_node_ids.is_empty() { + // --- 5.5. Start source / generator nodes --- + // File readers need an explicit Start signal, and so do generator nodes + // (e.g. video::colorbars) that follow the Ready → Start lifecycle. + // In generator mode we find root nodes (never a to_node in any connection) + // and send them Start as well. + let mut start_node_ids: Vec = source_node_ids.clone(); + + if !has_http_input && source_node_ids.is_empty() { + // Generator mode — find root nodes that need a Start signal. 
+ let downstream_nodes: std::collections::HashSet<&str> = + definition.connections.iter().map(|c| c.to_node.as_str()).collect(); + for name in definition.nodes.keys() { + if name != &output_node_id && !downstream_nodes.contains(name.as_str()) { + start_node_ids.push(name.clone()); + } + } + } + + if !start_node_ids.is_empty() { tracing::info!( - "Sending Start signals to {} file_reader node(s)", - source_node_ids.len() + "Sending Start signals to {} source/generator node(s)", + start_node_ids.len() ); - for source_id in &source_node_ids { + for source_id in &start_node_ids { if let Some(node_handle) = live_nodes.get(source_id) { tracing::debug!("Sending Start signal to source node '{}'", source_id); if let Err(e) = node_handle.control_tx.send(NodeControlMessage::Start).await { diff --git a/crates/engine/src/tests/connection_types.rs b/crates/engine/src/tests/connection_types.rs index 72076f52..ce5d4da5 100644 --- a/crates/engine/src/tests/connection_types.rs +++ b/crates/engine/src/tests/connection_types.rs @@ -7,7 +7,9 @@ use super::super::*; use crate::dynamic_actor::{DynamicEngine, NodePinMetadata}; use streamkit_core::registry::NodeRegistry; -use streamkit_core::types::{AudioFormat, PacketType, SampleFormat}; +use streamkit_core::types::{ + AudioCodec, AudioFormat, EncodedAudioFormat, PacketType, SampleFormat, +}; use streamkit_core::{InputPin, OutputPin, PinCardinality}; use tokio::sync::mpsc; @@ -33,6 +35,7 @@ fn create_test_engine() -> DynamicEngine { batch_size: 32, session_id: None, audio_pool: std::sync::Arc::new(streamkit_core::FramePool::::audio_default()), + video_pool: std::sync::Arc::new(streamkit_core::FramePool::::video_default()), node_input_capacity: 128, pin_distributor_capacity: 64, node_states: HashMap::new(), @@ -40,6 +43,8 @@ fn create_test_engine() -> DynamicEngine { node_stats: HashMap::new(), stats_subscribers: Vec::new(), telemetry_subscribers: Vec::new(), + node_view_data: HashMap::new(), + view_data_subscribers: Vec::new(), 
nodes_active_gauge: meter.u64_gauge("test.nodes").build(), node_state_transitions_counter: meter.u64_counter("test.transitions").build(), engine_operations_counter: meter.u64_counter("test.operations").build(), @@ -97,14 +102,17 @@ fn test_validate_connection_types_incompatible() { let audio_format = AudioFormat { sample_rate: 48000, channels: 2, sample_format: SampleFormat::F32 }; - // Create source node with OpusAudio output + // Create source node with encoded Opus output engine.node_pin_metadata.insert( "source".to_string(), NodePinMetadata { input_pins: vec![], output_pins: vec![OutputPin { name: "out".to_string(), - produces_type: PacketType::OpusAudio, + produces_type: PacketType::EncodedAudio(EncodedAudioFormat { + codec: AudioCodec::Opus, + codec_private: None, + }), cardinality: PinCardinality::Broadcast, }], }, @@ -175,14 +183,17 @@ fn test_validate_connection_types_passthrough_source() { fn test_validate_connection_types_any_destination() { let mut engine = create_test_engine(); - // Create source node with OpusAudio output + // Create source node with encoded Opus output engine.node_pin_metadata.insert( "source".to_string(), NodePinMetadata { input_pins: vec![], output_pins: vec![OutputPin { name: "out".to_string(), - produces_type: PacketType::OpusAudio, + produces_type: PacketType::EncodedAudio(EncodedAudioFormat { + codec: AudioCodec::Opus, + codec_private: None, + }), cardinality: PinCardinality::Broadcast, }], }, @@ -229,7 +240,10 @@ fn test_validate_connection_types_pin_not_found() { input_pins: vec![], output_pins: vec![OutputPin { name: "out".to_string(), - produces_type: PacketType::OpusAudio, + produces_type: PacketType::EncodedAudio(EncodedAudioFormat { + codec: AudioCodec::Opus, + codec_private: None, + }), cardinality: PinCardinality::Broadcast, }], }, @@ -241,7 +255,10 @@ fn test_validate_connection_types_pin_not_found() { NodePinMetadata { input_pins: vec![InputPin { name: "in".to_string(), - accepts_types: vec![PacketType::OpusAudio], + 
accepts_types: vec![PacketType::EncodedAudio(EncodedAudioFormat { + codec: AudioCodec::Opus, + codec_private: None, + })], cardinality: PinCardinality::One, }], output_pins: vec![], diff --git a/crates/engine/src/tests/dynamic_initialize.rs b/crates/engine/src/tests/dynamic_initialize.rs index 1ae5ca4c..426b0a6a 100644 --- a/crates/engine/src/tests/dynamic_initialize.rs +++ b/crates/engine/src/tests/dynamic_initialize.rs @@ -67,6 +67,7 @@ async fn test_dynamic_engine_calls_initialize() { let engine = Engine { registry: Arc::new(std::sync::RwLock::new(registry)), audio_pool: Arc::new(streamkit_core::AudioFramePool::audio_default()), + video_pool: Arc::new(streamkit_core::VideoFramePool::video_default()), }; let handle = engine.start_dynamic_actor(DynamicEngineConfig::default()); diff --git a/crates/engine/src/tests/oneshot_linear.rs b/crates/engine/src/tests/oneshot_linear.rs index cf2028f3..02472c53 100644 --- a/crates/engine/src/tests/oneshot_linear.rs +++ b/crates/engine/src/tests/oneshot_linear.rs @@ -73,6 +73,7 @@ async fn test_oneshot_rejects_fanout() { None, None, None, + None, ) .await else { diff --git a/crates/nodes/Cargo.toml b/crates/nodes/Cargo.toml index c6eb5944..e27a8502 100644 --- a/crates/nodes/Cargo.toml +++ b/crates/nodes/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "streamkit-nodes" -version = "0.1.0" +version = "0.2.0" edition = "2021" authors = ["Claudio Costa ", "StreamKit Contributors"] description = "Built-in processing nodes for StreamKit" @@ -46,13 +46,11 @@ url = { version = "2.5.8", optional = true, features = ["serde"] } rquickjs = { version = "0.11", features = ["array-buffer", "futures", "loader", "parallel"], optional = true } wildmatch = { version = "2.6", optional = true } -moq-transport = { version = "0.12.2", optional = true } -moq-native = { version = "0.12.1", optional = true } -moq-lite = { version = "0.13.0", optional = true } -hang = { version = "0.13.0", optional = true } +moq-native = { version = "0.13.2", optional = true 
} +moq-lite = { version = "0.15.0", optional = true } +hang = { version = "0.15.1", optional = true } # For local dev, debugging moq stuff -# moq-transport = { version = "0.11.0", optional = true } # moq-native = { path = "../../moq/rs/moq-native", optional = true } # moq-lite = { path = "../../moq/rs/moq", optional = true } # hang = { path = "../../moq/rs/hang", optional = true } @@ -74,6 +72,13 @@ symphonia = { version = "0.5.5", optional = true, default-features = false, feat "ogg", ] } webm = { version = "2.2.0", optional = true } +env-libvpx-sys = { version = "5.1.3", optional = true } +image = { version = "0.25.9", optional = true, default-features = false, features = ["png", "jpeg"] } +tiny-skia = { version = "0.12.0", optional = true } +base64 = { version = "0.22", optional = true } +rayon = { version = "1.10", optional = true } +fontdue = { version = "0.9", optional = true } +smallvec = { version = "1.13", optional = true, features = ["serde"] } futures-util = "0.3" @@ -93,6 +98,7 @@ default = [ "http", "symphonia", "script", + "video", ] # Individual features for each node. @@ -107,7 +113,6 @@ http = ["dep:schemars", "dep:reqwest", "dep:tempfile"] script = ["dep:rquickjs", "dep:reqwest", "dep:wildmatch", "dep:schemars", "dep:serde_json", "dep:url", "dep:uuid"] moq = [ "dep:schemars", - "dep:moq-transport", "dep:url", "dep:moq-native", "dep:moq-lite", @@ -119,13 +124,37 @@ moq = [ # The `dep:` syntax enables the optional dependency when the feature is active. 
opus = ["dep:opus", "dep:schemars"] ogg = ["dep:ogg", "dep:schemars"] -webm = ["dep:webm", "dep:schemars"] +webm = ["dep:webm", "dep:schemars", "dep:tempfile"] symphonia = ["dep:symphonia", "dep:schemars"] +vp9 = ["dep:env-libvpx-sys", "dep:schemars"] +colorbars = ["dep:schemars", "dep:serde_json", "dep:fontdue"] +compositor = ["dep:schemars", "dep:serde_json", "dep:image", "dep:tiny-skia", "dep:base64", "dep:rayon", "dep:fontdue", "dep:smallvec"] +video = ["vp9", "colorbars", "compositor"] [dev-dependencies] tempfile = "3.24" axum = "0.8" tower = "0.5.3" -[lints] -workspace = true +[lints.rust] +# Override workspace lint: VP9 nodes require unsafe for libvpx FFI. +# All unsafe blocks are documented with safety comments. +unsafe_code = "allow" + +[lints.clippy] +# Categories (from workspace) +pedantic = { level = "warn", priority = -1 } +nursery = { level = "warn", priority = -1 } +# Safety +unwrap_used = "warn" +expect_used = "warn" +# Complexity +cognitive_complexity = "warn" +# Math +cast_possible_truncation = "warn" +cast_precision_loss = "warn" +cast_sign_loss = "warn" +# Allow-list (Noise reduction) +module_name_repetitions = "allow" +must_use_candidate = "allow" +doc_markdown = "allow" diff --git a/crates/nodes/src/audio/codecs/mp3.rs b/crates/nodes/src/audio/codecs/mp3.rs index 346fe4ea..5c1e491d 100644 --- a/crates/nodes/src/audio/codecs/mp3.rs +++ b/crates/nodes/src/audio/codecs/mp3.rs @@ -346,6 +346,7 @@ fn decode_mp3_streaming_incremental( timestamp_us: Some(cumulative_timestamp_us), duration_us: Some(duration_us), sequence: Some(frame_count), + keyframe: None, }; // Use blocking_send - more efficient than Handle::block_on @@ -382,6 +383,7 @@ fn decode_mp3_streaming_incremental( timestamp_us: Some(cumulative_timestamp_us), duration_us: Some(duration_us), sequence: Some(frame_count), + keyframe: None, }; let final_chunk: Vec = rechunk_buffer.into_iter().collect(); @@ -502,6 +504,7 @@ fn decode_mp3_streaming(data: &[u8], result_tx: &mpsc::Sender) -> 
timestamp_us: Some(cumulative_timestamp_us), duration_us: Some(duration_us), sequence: Some(packet_count), + keyframe: None, }; if frame_tx.send((chunk, sample_rate, channels, metadata)).is_err() { @@ -532,6 +535,7 @@ fn decode_mp3_streaming(data: &[u8], result_tx: &mpsc::Sender) -> timestamp_us: Some(cumulative_timestamp_us), duration_us: Some(duration_us), sequence: Some(packet_count), + keyframe: None, }; let final_chunk: Vec = rechunk_buffer.into_iter().collect(); diff --git a/crates/nodes/src/audio/codecs/opus.rs b/crates/nodes/src/audio/codecs/opus.rs index fed47acd..fe53712d 100644 --- a/crates/nodes/src/audio/codecs/opus.rs +++ b/crates/nodes/src/audio/codecs/opus.rs @@ -10,7 +10,9 @@ use serde::Deserialize; use std::sync::Arc; use std::time::Instant; use streamkit_core::stats::NodeStatsTracker; -use streamkit_core::types::{AudioFormat, AudioFrame, Packet, PacketType, SampleFormat}; +use streamkit_core::types::{ + AudioCodec, AudioFormat, AudioFrame, EncodedAudioFormat, Packet, PacketType, SampleFormat, +}; use streamkit_core::{ get_codec_channel_capacity, packet_helpers, state_helpers, AudioFramePool, InputPin, NodeContext, NodeRegistry, OutputPin, PinCardinality, PooledSamples, ProcessorNode, @@ -57,7 +59,10 @@ impl ProcessorNode for OpusDecoderNode { fn input_pins(&self) -> Vec { vec![InputPin { name: "in".to_string(), - accepts_types: vec![PacketType::OpusAudio], + accepts_types: vec![PacketType::EncodedAudio(EncodedAudioFormat { + codec: AudioCodec::Opus, + codec_private: None, + })], cardinality: PinCardinality::One, }] } @@ -419,7 +424,10 @@ impl ProcessorNode for OpusEncoderNode { fn output_pins(&self) -> Vec { vec![OutputPin { name: "out".to_string(), - produces_type: PacketType::OpusAudio, + produces_type: PacketType::EncodedAudio(EncodedAudioFormat { + codec: AudioCodec::Opus, + codec_private: None, + }), cardinality: PinCardinality::Broadcast, }] } @@ -598,6 +606,7 @@ impl ProcessorNode for OpusEncoderNode { timestamp_us: None, // No absolute 
timestamp duration_us: Some(duration_us), sequence: None, // No sequence tracking yet + keyframe: None, }), }; if context diff --git a/crates/nodes/src/audio/filters/mixer.rs b/crates/nodes/src/audio/filters/mixer.rs index 355fc02e..a8895afe 100644 --- a/crates/nodes/src/audio/filters/mixer.rs +++ b/crates/nodes/src/audio/filters/mixer.rs @@ -951,7 +951,7 @@ impl AudioMixerNode { mix_frames.iter().filter_map(|f| f.metadata.as_ref().and_then(|m| m.sequence)).max(); let combined_metadata = if timestamp_us.is_some() || duration_us.is_some() || sequence.is_some() { - Some(PacketMetadata { timestamp_us, duration_us, sequence }) + Some(PacketMetadata { timestamp_us, duration_us, sequence, keyframe: None }) } else { None }; @@ -1431,6 +1431,7 @@ fn run_clocked_audio_thread(config: &ClockedThreadConfig) { timestamp_us: None, duration_us: Some(tick_us), sequence: None, + keyframe: None, })); let output_frame = mix_clocked_frames( @@ -2137,6 +2138,7 @@ mod tests { timestamp_us: Some(1), duration_us: Some(20_000), sequence: Some(1), + keyframe: None, }), ); let frame_b = AudioFrame::with_metadata( @@ -2147,6 +2149,7 @@ mod tests { timestamp_us: Some(2), duration_us: Some(40_000), sequence: Some(2), + keyframe: None, }), ); diff --git a/crates/nodes/src/audio/filters/resampler.rs b/crates/nodes/src/audio/filters/resampler.rs index 477d5030..d6f92684 100644 --- a/crates/nodes/src/audio/filters/resampler.rs +++ b/crates/nodes/src/audio/filters/resampler.rs @@ -283,6 +283,7 @@ impl ProcessorNode for AudioResamplerNode { timestamp_us: output_timestamp_us, duration_us: Some(duration_us), sequence: Some(output_sequence), + keyframe: None, }; output_sequence += 1; if let Some(ts) = output_timestamp_us.as_mut() { @@ -588,6 +589,7 @@ impl ProcessorNode for AudioResamplerNode { timestamp_us: output_timestamp_us, duration_us: Some(duration_us), sequence: Some(output_sequence), + keyframe: None, }; output_sequence += 1; if let Some(ts) = output_timestamp_us.as_mut() { @@ -634,6 +636,7 @@ 
impl ProcessorNode for AudioResamplerNode { timestamp_us: output_timestamp_us, duration_us: Some(duration_us), sequence: Some(output_sequence), + keyframe: None, }; output_sequence += 1; if let Some(ts) = output_timestamp_us.as_mut() { @@ -678,6 +681,7 @@ impl ProcessorNode for AudioResamplerNode { timestamp_us: output_timestamp_us, duration_us: Some(duration_us), sequence: Some(output_sequence), + keyframe: None, }; if let Some(ts) = output_timestamp_us.as_mut() { *ts += duration_us; @@ -757,6 +761,7 @@ mod tests { let context = NodeContext { inputs, + input_types: HashMap::new(), control_rx, output_sender, batch_size: 32, @@ -767,6 +772,8 @@ mod tests { cancellation_token: None, pin_management_rx: None, // Test contexts don't support dynamic pins audio_pool: None, + video_pool: None, + view_data_tx: None, }; // Create node that downsamples from 48kHz to 24kHz @@ -833,6 +840,7 @@ mod tests { let context = NodeContext { inputs, + input_types: HashMap::new(), control_rx, output_sender, batch_size: 32, @@ -843,6 +851,8 @@ mod tests { cancellation_token: None, pin_management_rx: None, // Test contexts don't support dynamic pins audio_pool: None, + video_pool: None, + view_data_tx: None, }; let config = AudioResamplerConfig { diff --git a/crates/nodes/src/containers/ogg.rs b/crates/nodes/src/containers/ogg.rs index 8c3ed2d9..5364c97d 100644 --- a/crates/nodes/src/containers/ogg.rs +++ b/crates/nodes/src/containers/ogg.rs @@ -13,7 +13,7 @@ use std::borrow::Cow; use std::io::Write; use std::sync::{Arc, Mutex}; use streamkit_core::stats::NodeStatsTracker; -use streamkit_core::types::{Packet, PacketType}; +use streamkit_core::types::{AudioCodec, EncodedAudioFormat, Packet, PacketType}; use streamkit_core::{ get_demuxer_buffer_size, get_stream_channel_capacity, state_helpers, InputPin, NodeContext, NodeRegistry, OutputPin, PinCardinality, ProcessorNode, StreamKitError, @@ -101,7 +101,10 @@ impl ProcessorNode for OggMuxerNode { fn input_pins(&self) -> Vec { vec![InputPin { 
name: "in".to_string(), - accepts_types: vec![PacketType::OpusAudio], // Accepts Opus for now + accepts_types: vec![PacketType::EncodedAudio(EncodedAudioFormat { + codec: AudioCodec::Opus, + codec_private: None, + })], // Accepts Opus for now cardinality: PinCardinality::One, }] } @@ -384,7 +387,10 @@ impl ProcessorNode for OggDemuxerNode { fn output_pins(&self) -> Vec { vec![OutputPin { name: "out".to_string(), - produces_type: PacketType::OpusAudio, + produces_type: PacketType::EncodedAudio(EncodedAudioFormat { + codec: AudioCodec::Opus, + codec_private: None, + }), cardinality: PinCardinality::Broadcast, }] } @@ -539,6 +545,7 @@ impl ProcessorNode for OggDemuxerNode { timestamp_us: Some(timestamp_us), duration_us, sequence: Some(packets_extracted), + keyframe: None, }) } else { // No valid granule position (header packets) @@ -633,7 +640,10 @@ impl ProcessorNode for SymphoniaOggDemuxerNode { fn output_pins(&self) -> Vec { vec![OutputPin { name: "out".to_string(), - produces_type: PacketType::OpusAudio, + produces_type: PacketType::EncodedAudio(EncodedAudioFormat { + codec: AudioCodec::Opus, + codec_private: None, + }), cardinality: PinCardinality::Broadcast, }] } @@ -723,6 +733,7 @@ impl ProcessorNode for SymphoniaOggDemuxerNode { timestamp_us: Some(timestamp_us), duration_us: Some(duration_us), sequence: Some(packets_extracted), + keyframe: None, }) } else { None diff --git a/crates/nodes/src/containers/tests.rs b/crates/nodes/src/containers/tests.rs index 9ef86ade..aa397683 100644 --- a/crates/nodes/src/containers/tests.rs +++ b/crates/nodes/src/containers/tests.rs @@ -16,7 +16,7 @@ use bytes::Bytes; use std::collections::HashMap; use std::path::Path; use streamkit_core::node::ProcessorNode; -use streamkit_core::types::Packet; +use streamkit_core::types::{AudioCodec, EncodedAudioFormat, Packet, PacketType}; use tokio::sync::mpsc; /// Helper to read test audio files @@ -290,7 +290,14 @@ async fn test_webm_muxer_basic() { let mut inputs = HashMap::new(); 
inputs.insert("in".to_string(), input_rx); - let (context, mock_sender, mut state_rx) = create_test_context(inputs, 10); + let (mut context, mock_sender, mut state_rx) = create_test_context(inputs, 10); + context.input_types.insert( + "in".to_string(), + PacketType::EncodedAudio(EncodedAudioFormat { + codec: AudioCodec::Opus, + codec_private: None, + }), + ); // Create WebM muxer node let config = WebMMuxerConfig::default(); @@ -348,7 +355,14 @@ async fn test_webm_muxer_multiple_packets() { let mut inputs = HashMap::new(); inputs.insert("in".to_string(), input_rx); - let (context, mock_sender, mut state_rx) = create_test_context(inputs, 10); + let (mut context, mock_sender, mut state_rx) = create_test_context(inputs, 10); + context.input_types.insert( + "in".to_string(), + PacketType::EncodedAudio(EncodedAudioFormat { + codec: AudioCodec::Opus, + codec_private: None, + }), + ); let config = WebMMuxerConfig::default(); let node = WebMMuxerNode::new(config); @@ -396,7 +410,14 @@ async fn test_webm_sliding_window() { let mut inputs = HashMap::new(); inputs.insert("in".to_string(), input_rx); - let (context, mock_sender, mut state_rx) = create_test_context(inputs, 10); + let (mut context, mock_sender, mut state_rx) = create_test_context(inputs, 10); + context.input_types.insert( + "in".to_string(), + PacketType::EncodedAudio(EncodedAudioFormat { + codec: AudioCodec::Opus, + codec_private: None, + }), + ); // Create config with smaller chunk size for testing let config = WebMMuxerConfig { @@ -430,3 +451,134 @@ async fn test_webm_sliding_window() { output_packets.len() ); } + +/// Smoke test: video-only VP9 frames muxed into WebM produce non-empty, parseable output. 
+#[cfg(feature = "vp9")] +#[tokio::test] +async fn test_webm_mux_vp9_video_only() { + use crate::test_utils::create_test_video_frame; + use crate::video::vp9::{Vp9EncoderConfig, Vp9EncoderNode}; + use streamkit_core::types::{PacketMetadata, PixelFormat}; + + // ---- Step 1: Encode some raw I420 frames to VP9 ---- + + let (enc_input_tx, enc_input_rx) = mpsc::channel(10); + let mut enc_inputs = HashMap::new(); + enc_inputs.insert("in".to_string(), enc_input_rx); + + let (enc_context, enc_sender, mut enc_state_rx) = create_test_context(enc_inputs, 10); + let encoder_config = Vp9EncoderConfig { + keyframe_interval: 1, + bitrate_kbps: 800, + threads: 1, + ..Default::default() + }; + let encoder = match Vp9EncoderNode::new(encoder_config) { + Ok(enc) => enc, + Err(e) => { + eprintln!("Skipping VP9 video-only mux test: encoder not available ({e})"); + return; + }, + }; + let enc_handle = tokio::spawn(async move { Box::new(encoder).run(enc_context).await }); + + assert_state_initializing(&mut enc_state_rx).await; + assert_state_running(&mut enc_state_rx).await; + + let frame_count = 5u64; + for i in 0..frame_count { + let mut frame = create_test_video_frame(64, 64, PixelFormat::I420, 16); + frame.metadata = Some(PacketMetadata { + timestamp_us: Some(i * 33_333), + duration_us: Some(33_333), + sequence: Some(i), + keyframe: Some(true), + }); + enc_input_tx.send(Packet::Video(frame)).await.unwrap(); + } + drop(enc_input_tx); + + assert_state_stopped(&mut enc_state_rx).await; + enc_handle.await.unwrap().unwrap(); + + let encoded_packets = enc_sender.get_packets_for_pin("out").await; + assert!(!encoded_packets.is_empty(), "VP9 encoder produced no packets"); + + // ---- Step 2: Mux the encoded VP9 packets into WebM ---- + + let (mux_video_tx, mux_video_rx) = mpsc::channel(10); + let mut mux_inputs = HashMap::new(); + // Only video, no audio + mux_inputs.insert("in".to_string(), mux_video_rx); + + let (mut mux_context, mux_sender, mut mux_state_rx) = 
create_test_context(mux_inputs, 10); + mux_context.input_types.insert( + "in".to_string(), + PacketType::EncodedVideo(streamkit_core::types::EncodedVideoFormat { + codec: streamkit_core::types::VideoCodec::Vp9, + bitstream_format: None, + codec_private: None, + profile: None, + level: None, + }), + ); + let mux_config = + WebMMuxerConfig { video_width: 64, video_height: 64, ..WebMMuxerConfig::default() }; + let muxer = WebMMuxerNode::new(mux_config); + let mux_handle = tokio::spawn(async move { Box::new(muxer).run(mux_context).await }); + + assert_state_initializing(&mut mux_state_rx).await; + assert_state_running(&mut mux_state_rx).await; + + for packet in encoded_packets { + mux_video_tx.send(packet).await.unwrap(); + } + drop(mux_video_tx); + + assert_state_stopped(&mut mux_state_rx).await; + mux_handle.await.unwrap().unwrap(); + + // ---- Step 3: Validate output ---- + + let output_packets = mux_sender.get_packets_for_pin("out").await; + assert!(!output_packets.is_empty(), "WebM muxer produced no output"); + + // Collect all output bytes + let mut webm_bytes = Vec::new(); + for packet in &output_packets { + if let Packet::Binary { data, .. } = packet { + webm_bytes.extend_from_slice(data); + } + } + + assert!(!webm_bytes.is_empty(), "WebM output is empty"); + // WebM/EBML files start with 0x1A45DFA3 (EBML header element ID) + assert!(webm_bytes.len() >= 4, "WebM output too small: {} bytes", webm_bytes.len()); + assert_eq!( + &webm_bytes[..4], + &[0x1A, 0x45, 0xDF, 0xA3], + "WebM output does not start with EBML header" + ); + + // Verify content type + if let Packet::Binary { content_type, .. } = &output_packets[0] { + let ct = content_type.as_ref().expect("content_type should be set"); + assert_eq!(ct.as_ref(), "video/webm; codecs=\"vp9\""); + } + + println!( + "✅ WebM video-only mux test passed: {} output packets, {} total bytes", + output_packets.len(), + webm_bytes.len() + ); +} + +/// Test that muxer returns an error if no inputs are connected. 
+#[tokio::test] +async fn test_webm_mux_no_inputs_fails() { + let mux_inputs = HashMap::new(); + let (mux_context, _mux_sender, _mux_state_rx) = create_test_context(mux_inputs, 10); + let muxer = WebMMuxerNode::new(WebMMuxerConfig::default()); + let result = Box::new(muxer).run(mux_context).await; + assert!(result.is_err(), "Expected error when no inputs are connected"); +} diff --git a/crates/nodes/src/containers/webm.rs b/crates/nodes/src/containers/webm.rs index 2f2d1f3b..d627282a 100644 --- a/crates/nodes/src/containers/webm.rs +++ b/crates/nodes/src/containers/webm.rs @@ -7,22 +7,129 @@ use bytes::Bytes; use schemars::JsonSchema; use serde::Deserialize; use std::borrow::Cow; -use std::io::{Cursor, Seek, SeekFrom, Write}; +use std::io::{BufWriter, Cursor, Read as _, Seek, SeekFrom, Write}; use std::sync::{Arc, Mutex}; use streamkit_core::stats::NodeStatsTracker; -use streamkit_core::types::{Packet, PacketMetadata, PacketType}; +use streamkit_core::types::{ + AudioCodec, EncodedAudioFormat, EncodedVideoFormat, Packet, PacketMetadata, PacketType, + VideoCodec, +}; use streamkit_core::{ state_helpers, timing::MediaClock, InputPin, NodeContext, NodeRegistry, OutputPin, PinCardinality, ProcessorNode, StreamKitError, }; -use webm::mux::{AudioCodecId, SegmentBuilder, SegmentMode, Writer}; +use webm::mux::{ + AudioCodecId, AudioTrack, SegmentBuilder, SegmentMode, VideoCodecId, VideoTrack, Writer, +}; // --- WebM Constants --- /// Default chunk size for flushing buffers const DEFAULT_CHUNK_SIZE: usize = 65536; -/// Default frame duration when metadata is missing (20ms Opus frame). +/// Default audio frame duration when metadata is missing (20ms Opus frame). 
const DEFAULT_FRAME_DURATION_US: u64 = 20_000; +use crate::video::DEFAULT_VIDEO_FRAME_DURATION_US; + +// --------------------------------------------------------------------------- +// VP9 keyframe dimension parser +// --------------------------------------------------------------------------- + +/// Minimal bit reader for parsing VP9 uncompressed headers (MSB-first). +struct BitReader<'a> { + data: &'a [u8], + byte_offset: usize, + bit_offset: u8, +} + +impl<'a> BitReader<'a> { + const fn new(data: &'a [u8]) -> Self { + Self { data, byte_offset: 0, bit_offset: 0 } + } + + /// Read `n` bits (1..=16) as a `u32`, MSB first. + fn read(&mut self, n: u8) -> Option { + let mut value: u32 = 0; + for _ in 0..n { + if self.byte_offset >= self.data.len() { + return None; + } + let bit = (self.data[self.byte_offset] >> (7 - self.bit_offset)) & 1; + value = (value << 1) | u32::from(bit); + self.bit_offset += 1; + if self.bit_offset == 8 { + self.bit_offset = 0; + self.byte_offset += 1; + } + } + Some(value) + } +} + +/// Parse the display dimensions from a VP9 keyframe's uncompressed header. +/// +/// Returns `Some((width, height))` when the data starts with a valid VP9 +/// keyframe (profile 0–3). Returns `None` for non-keyframes, truncated +/// data, or invalid sync codes. +fn parse_vp9_keyframe_dimensions(data: &[u8]) -> Option<(u32, u32)> { + if data.len() < 10 { + return None; + } + + let mut r = BitReader::new(data); + + // frame_marker (2 bits) – must be 0b10 + if r.read(2)? != 2 { + return None; + } + + let profile_low = r.read(1)?; + let profile_high = r.read(1)?; + let profile = (profile_high << 1) | profile_low; + + if profile > 2 { + r.read(1)?; // reserved_zero + } + + // show_existing_frame + if r.read(1)? != 0 { + return None; + } + + // frame_type: 0 = KEY_FRAME + if r.read(1)? != 0 { + return None; + } + + r.read(1)?; // show_frame + r.read(1)?; // error_resilient_mode + + // frame_sync_code must be 0x49_83_42 + if r.read(8)? != 0x49 || r.read(8)? 
!= 0x83 || r.read(8)? != 0x42 { + return None; + } + + // color_config + if profile >= 2 { + r.read(1)?; // ten_or_twelve_bit + } + let color_space = r.read(3)?; + if color_space != 7 { + // not CS_RGB + r.read(1)?; // color_range + if profile == 1 || profile == 3 { + r.read(1)?; // subsampling_x + r.read(1)?; // subsampling_y + r.read(1)?; // reserved + } + } else if profile == 1 || profile == 3 { + r.read(1)?; // reserved + } + + // frame_size: width_minus_1 (16 bits), height_minus_1 (16 bits) + let w = r.read(16)? + 1; + let h = r.read(16)? + 1; + Some((w, h)) +} /// Opus codec lookahead at 48kHz in samples (typical libopus default). /// /// This is written to the OpusHead `pre_skip` field so decoders can trim encoder delay. @@ -59,120 +166,68 @@ fn opus_head_codec_private(sample_rate: u32, channels: u32) -> Result<[u8; 19], // --- WebM Muxer --- +/// Internal state for [`SharedPacketBuffer`], protected by a single mutex +/// to eliminate lock-ordering concerns between cursor, position tracking, +/// and offset bookkeeping. +struct BufferState { + cursor: Cursor>, + last_sent_pos: usize, + base_offset: usize, +} + /// A shared, thread-safe buffer that wraps a Cursor for WebM writing. /// This allows us to stream out data as it's written while still supporting Seek. /// -/// Supports two buffering modes: -/// -/// - **Streaming (non-seek)**: Bytes are drained on every `take_data()` call. -/// This mode is intended for `Writer::new_non_seek` and avoids copying. -/// - **Seek window**: Keeps a configurable window of recent data for WebM library seeks -/// and trims old data that has already been sent. -/// -/// The node selects the appropriate mode based on `WebMStreamingMode`. +/// Used for **Live** (streaming) mode only. Bytes are drained on every +/// `take_data()` call so that memory stays bounded. 
#[derive(Clone)] struct SharedPacketBuffer { - cursor: Arc>>>, - last_sent_pos: Arc>, - base_offset: Arc>, - window_size: usize, + state: Arc>, } impl SharedPacketBuffer { - /// Create a new buffer with a sliding window size. - /// window_size: Maximum bytes to keep in memory (default 1MB for ~6 seconds at 128kbps) - fn new_with_window(window_size: usize) -> Self { - Self { - cursor: Arc::new(Mutex::new(Cursor::new(Vec::new()))), - last_sent_pos: Arc::new(Mutex::new(0)), - base_offset: Arc::new(Mutex::new(0)), - window_size, - } - } - /// Create a non-seek streaming buffer. /// /// This is designed for `Writer::new_non_seek` in live streaming mode. Since the writer /// does not seek/backpatch, we can drain bytes out by moving the underlying `Vec` /// (no copy) and reset the cursor to keep memory bounded. fn new_streaming() -> Self { - // window_size=0 is treated as "drain everything on take_data" - Self::new_with_window(0) + Self { + state: Arc::new(Mutex::new(BufferState { + cursor: Cursor::new(Vec::new()), + last_sent_pos: 0, + base_offset: 0, + })), + } } - /// Takes any new data written since the last call, and trims old data beyond the window. - /// This allows the WebM library to seek backwards within the window while preventing - /// unbounded memory growth for long streams. + /// Takes any new data written since the last call. + /// + /// Streaming mode (non-seek): drain everything written so far without copying. 
fn take_data(&self) -> Option { // Mutex poisoning is a fatal error - allows expect() for this common pattern #[allow(clippy::expect_used)] - let mut buffer_guard = self.cursor.lock().expect("SharedPacketBuffer mutex poisoned"); - let vec = buffer_guard.get_mut(); + let mut state = self.state.lock().expect("SharedPacketBuffer mutex poisoned"); - #[allow(clippy::expect_used)] - let mut last_sent_guard = self.last_sent_pos.lock().expect("last_sent_pos mutex poisoned"); - #[allow(clippy::expect_used)] - let mut base_offset_guard = self.base_offset.lock().expect("base_offset mutex poisoned"); - - let last_sent = *last_sent_guard; - let current_len = vec.len(); - let base = *base_offset_guard; - - let result = if current_len > last_sent { - if self.window_size == 0 { - // Streaming mode (non-seek): drain everything written so far without copying. - // - // This avoids two major sources of allocation churn in DHAT profiles: - // - copying out incremental slices on every flush - // - repeatedly trimming a sliding window with `split_off` (copies the window) - let data_vec = std::mem::take(vec); - // Advance base_offset so Seek::Start can clamp consistently if it ever happens. - *base_offset_guard = base + current_len; - *last_sent_guard = 0; - buffer_guard.set_position(0); - Some(Bytes::from(data_vec)) - } else if self.window_size == usize::MAX && last_sent == 0 { - // File mode: nothing has been sent yet, so move the entire buffer out. - // The segment is finalized before this is called, so no more writes/seeks occur. - let data_vec = std::mem::take(vec); - *base_offset_guard = base + current_len; - *last_sent_guard = 0; - buffer_guard.set_position(0); - Some(Bytes::from(data_vec)) - } else { - // Seek-window mode: copy incremental bytes while retaining a backwards-seek window. - let new_data = Bytes::copy_from_slice(&vec[last_sent..current_len]); - *last_sent_guard = current_len; - - // Trim old data if buffer exceeds window size. 
- if current_len > self.window_size { - let trim_amount = current_len - self.window_size; - // Keep the last window_size bytes. - let remaining = vec.split_off(trim_amount); - *vec = remaining; - // Update base offset to reflect discarded data. - *base_offset_guard = base + trim_amount; - // Adjust last_sent and cursor position. - *last_sent_guard = self.window_size; - buffer_guard.set_position(self.window_size as u64); - - tracing::debug!( - "Trimmed {} bytes from WebM buffer, new base_offset: {}", - trim_amount, - *base_offset_guard - ); - } + let base = state.base_offset; + let current_len = state.cursor.get_ref().len(); - Some(new_data) - } - } else { - None - }; + if current_len == 0 { + return None; + } - drop(base_offset_guard); - drop(last_sent_guard); - drop(buffer_guard); - result + // Drain everything written so far without copying. + // + // This avoids two major sources of allocation churn in DHAT profiles: + // - copying out incremental slices on every flush + // - repeatedly trimming a sliding window with `split_off` (copies the window) + let data_vec = std::mem::take(state.cursor.get_mut()); + // Advance base_offset so Seek::Start can clamp consistently if it ever happens. 
+ state.base_offset = base + current_len; + state.last_sent_pos = 0; + state.cursor.set_position(0); + drop(state); + Some(Bytes::from(data_vec)) } } @@ -180,54 +235,146 @@ impl Write for SharedPacketBuffer { fn write(&mut self, buf: &[u8]) -> std::io::Result { // Mutex poisoning is a fatal error - allows expect() for this common pattern #[allow(clippy::expect_used)] - self.cursor.lock().expect("SharedPacketBuffer mutex poisoned").write(buf) + self.state.lock().expect("SharedPacketBuffer mutex poisoned").cursor.write(buf) } fn flush(&mut self) -> std::io::Result<()> { // Mutex poisoning is a fatal error - allows expect() for this common pattern #[allow(clippy::expect_used)] - self.cursor.lock().expect("SharedPacketBuffer mutex poisoned").flush() + self.state.lock().expect("SharedPacketBuffer mutex poisoned").cursor.flush() + } +} + +/// A file-backed buffer for **File** mode WebM muxing. +/// +/// Instead of accumulating the entire muxed output in memory (which violates +/// the "never keep entire files in memory" principle), all writes go to an +/// anonymous temporary file on disk. The temp file supports full seek so +/// libwebm can back-patch segment sizes and cues as needed. +/// +/// At finalization the file contents are read back into a `Bytes` for the +/// single downstream send. This is a one-time, bounded operation — the +/// file is deleted automatically when the struct is dropped. +struct FileBackedBuffer { + inner: BufWriter, +} + +impl FileBackedBuffer { + /// Create a new file-backed buffer using an anonymous temp file. + fn new() -> std::io::Result { + let file = tempfile::tempfile()?; + Ok(Self { inner: BufWriter::new(file) }) + } + + /// Read the entire temp file contents as `Bytes`. + /// + /// This should only be called **once** after `segment.finalize()` — all + /// writes and seeks are complete at that point. 
+ fn take_data(&mut self) -> std::io::Result> { + self.inner.flush()?; + let file = self.inner.get_mut(); + let len = file.stream_position()?; + if len == 0 { + return Ok(None); + } + file.seek(SeekFrom::Start(0))?; + let len_usize = usize::try_from(len).map_err(std::io::Error::other)?; + let mut buf = vec![0u8; len_usize]; + file.read_exact(&mut buf)?; + Ok(Some(Bytes::from(buf))) } } -impl Seek for SharedPacketBuffer { +impl Write for FileBackedBuffer { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.inner.write(buf) + } + + fn flush(&mut self) -> std::io::Result<()> { + self.inner.flush() + } +} + +impl Seek for FileBackedBuffer { fn seek(&mut self, pos: SeekFrom) -> std::io::Result { - // When seeking, we need to adjust for the base_offset since we may have - // trimmed old data from the beginning of the buffer - #[allow(clippy::expect_used)] - let base_guard = self.base_offset.lock().expect("base_offset mutex poisoned"); - let base = *base_guard; - drop(base_guard); + self.inner.seek(pos) + } +} - #[allow(clippy::expect_used)] - let mut cursor_guard = self.cursor.lock().expect("SharedPacketBuffer mutex poisoned"); - - // Adjust seek position by base_offset for absolute seeks - let adjusted_pos = match pos { - SeekFrom::Start(offset) => { - // Absolute position from start - subtract base_offset - if offset >= base as u64 { - SeekFrom::Start(offset - base as u64) - } else { - // Seeking before our window - this is an error but we'll seek to start - tracing::warn!( - "WebM seek to {} before base_offset {}, clamping to start", - offset, - base - ); - SeekFrom::Start(0) - } +/// Unified buffer type used by the WebM muxer. +/// +/// - **Live** mode: wraps a [`SharedPacketBuffer`] (in-memory streaming, non-seek writer). +/// - **File** mode: wraps a [`FileBackedBuffer`] (temp file on disk, seekable writer). +/// +/// This enum allows the muxer's `run()` method to use a single generic code +/// path regardless of the streaming mode. 
+enum MuxBuffer { + Live(SharedPacketBuffer), + File(FileBackedBuffer), +} + +impl MuxBuffer { + fn take_data(&mut self) -> Option { + match self { + Self::Live(buf) => buf.take_data(), + Self::File(buf) => match buf.take_data() { + Ok(data) => data, + Err(e) => { + tracing::error!("Failed to read temp file data: {e}"); + None + }, }, - // Current and End are relative, no adjustment needed - SeekFrom::Current(offset) => SeekFrom::Current(offset), - SeekFrom::End(offset) => SeekFrom::End(offset), - }; + } + } +} - let result = cursor_guard.seek(adjusted_pos)?; - drop(cursor_guard); +impl Write for MuxBuffer { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + match self { + Self::Live(b) => b.write(buf), + Self::File(b) => b.write(buf), + } + } - // Return the absolute position (including base_offset) - Ok(result + base as u64) + fn flush(&mut self) -> std::io::Result<()> { + match self { + Self::Live(b) => b.flush(), + Self::File(b) => b.flush(), + } + } +} + +impl Seek for MuxBuffer { + fn seek(&mut self, pos: SeekFrom) -> std::io::Result { + match self { + Self::Live(b) => { + // Live mode uses non-seek writer; this should not be called. + // Provide a no-op implementation that returns the current position. 
+ #[allow(clippy::expect_used)] + let mut state = b.state.lock().expect("SharedPacketBuffer mutex poisoned"); + let base = state.base_offset; + let adjusted_pos = match pos { + SeekFrom::Start(offset) => { + if offset >= base as u64 { + SeekFrom::Start(offset - base as u64) + } else { + tracing::warn!( + "WebM seek to {} before base_offset {}, clamping to start", + offset, + base + ); + SeekFrom::Start(0) + } + }, + SeekFrom::Current(offset) => SeekFrom::Current(offset), + SeekFrom::End(offset) => SeekFrom::End(offset), + }; + let result = state.cursor.seek(adjusted_pos)?; + drop(state); + Ok(result + base as u64) + }, + Self::File(b) => b.seek(pos), + } } } @@ -253,13 +400,18 @@ impl WebMStreamingMode { #[derive(Deserialize, Debug, JsonSchema)] #[serde(default)] pub struct WebMMuxerConfig { - /// Audio sample rate in Hz + /// Audio sample rate in Hz (used when an audio input is connected) pub sample_rate: u32, /// Number of audio channels (1 for mono, 2 for stereo) pub channels: u32, + /// Video width in pixels (required when a video input is connected) + pub video_width: u32, + /// Video height in pixels (required when a video input is connected) + pub video_height: u32, /// The number of bytes to buffer before flushing to the output. Defaults to 65536. pub chunk_size: usize, - /// Streaming mode: "live" for real-time streaming (no duration), "file" for complete files with duration (default) + /// Streaming mode: "live" for real-time streaming (no duration), "file" for complete files + /// with duration (default) pub streaming_mode: WebMStreamingMode, } @@ -268,13 +420,45 @@ impl Default for WebMMuxerConfig { Self { sample_rate: 48000, channels: 2, + video_width: 0, + video_height: 0, chunk_size: DEFAULT_CHUNK_SIZE, streaming_mode: WebMStreamingMode::default(), } } } -/// A node that muxes compressed Opus audio packets into a WebM container stream. +/// Track handles resolved during segment setup. 
+struct MuxTracks { + audio: Option, + video: Option, +} + +/// Builds the MIME content-type string based on which tracks are present. +const fn webm_content_type(has_audio: bool, has_video: bool) -> &'static str { + match (has_audio, has_video) { + (true, true) => "video/webm; codecs=\"vp9,opus\"", + (false, true) => "video/webm; codecs=\"vp9\"", + (true, false) => "audio/webm; codecs=\"opus\"", + // Shouldn't happen - at least one track is required - but provide a safe fallback. + (false, false) => "video/webm", + } +} + +/// A node that muxes encoded Opus audio and/or VP9 video packets into a WebM container stream. +/// +/// Input pins use generic names (`"in"`, `"in_1"`, …) — the media type carried by each +/// input is detected at runtime from the packet's `content_type` field, **not** from the +/// pin name. This keeps the node future-proof for additional track types (subtitles, +/// data channels, etc.) without requiring pin-name changes. +/// +/// Pin layout (determined by config): +/// - Default (no video dimensions): single pin `"in"` accepting audio **or** video. +/// - With `video_width`/`video_height` > 0: two pins `"in"` + `"in_1"`, each accepting +/// audio or video. The muxer will auto-detect which track type each pin carries. +/// +/// At least one input must be connected. When both are connected, audio and video frames +/// are interleaved by arrival order as required by the WebM/Matroska container. pub struct WebMMuxerNode { config: WebMMuxerConfig, } @@ -286,13 +470,48 @@ impl WebMMuxerNode { } #[async_trait] +#[allow(clippy::too_many_lines)] impl ProcessorNode for WebMMuxerNode { fn input_pins(&self) -> Vec { - vec![InputPin { - name: "in".to_string(), - accepts_types: vec![PacketType::OpusAudio], // Accepts Opus audio - cardinality: PinCardinality::One, - }] + // Each pin accepts both audio and video — the actual media type is detected + // at runtime from the packet content_type, not from the pin name. 
+ let media_types = vec![ + PacketType::EncodedAudio(EncodedAudioFormat { + codec: AudioCodec::Opus, + codec_private: None, + }), + PacketType::EncodedVideo(EncodedVideoFormat { + codec: VideoCodec::Vp9, + bitstream_format: None, + codec_private: None, + profile: None, + level: None, + }), + ]; + + let has_video = self.config.video_width > 0 && self.config.video_height > 0; + if has_video { + // Two generic inputs for audio + video (order is determined at runtime). + vec![ + InputPin { + name: "in".to_string(), + accepts_types: media_types.clone(), + cardinality: PinCardinality::One, + }, + InputPin { + name: "in_1".to_string(), + accepts_types: media_types, + cardinality: PinCardinality::One, + }, + ] + } else { + // Single generic input — backward compatible with `needs: encoder_node`. + vec![InputPin { + name: "in".to_string(), + accepts_types: media_types, + cardinality: PinCardinality::One, + }] + } } fn output_pins(&self) -> Vec { @@ -304,38 +523,129 @@ impl ProcessorNode for WebMMuxerNode { } fn content_type(&self) -> Option { - // MSE requires codec information in the MIME type - Some("audio/webm; codecs=\"opus\"".to_string()) + // This static hint is used before the node runs. + // We can only infer from config: if video dimensions are set, video is + // present. Audio presence is unknown at this stage, so we conservatively + // report only what we can confirm — the runtime `run()` method uses the + // actual connected tracks for the real content-type. + let has_video = self.config.video_width > 0 && self.config.video_height > 0; + // Without a way to know if audio will be connected, assume audio-only + // when no video dimensions are configured, and video-only when they are. + // Mixed audio+video pipelines will get the correct type at runtime. 
+ let has_audio = !has_video; + Some(webm_content_type(has_audio, has_video).to_string()) } async fn run(self: Box, mut context: NodeContext) -> Result<(), StreamKitError> { let node_name = context.output_sender.node_name().to_string(); state_helpers::emit_initializing(&context.state_tx, &node_name); tracing::info!("WebMMuxerNode starting"); + + // --- Classify generic inputs using connection-time type metadata --- + // + // Inputs use generic pin names ("in", "in_1", …). The graph builder + // populates `context.input_types` with the upstream output's + // [`PacketType`] for each connected pin, so we can determine whether a + // channel carries audio or video without inspecting any packets. + + if context.inputs.is_empty() { + let err_msg = "WebMMuxerNode requires at least one input (audio or video)".to_string(); + state_helpers::emit_failed(&context.state_tx, &node_name, &err_msg); + return Err(StreamKitError::Runtime(err_msg)); + } + + let mut audio_rx: Option> = None; + let mut video_rx: Option> = None; + + for (pin_name, rx) in context.inputs.drain() { + let is_video = context.input_types.get(&pin_name).is_some_and(|ty| { + matches!(ty, PacketType::EncodedVideo(_) | PacketType::RawVideo(_)) + }); + + if is_video { + if video_rx.is_some() { + let err_msg = format!( + "WebMMuxerNode: multiple video inputs detected (pin '{pin_name}'). \ + Only one video track is supported." + ); + state_helpers::emit_failed(&context.state_tx, &node_name, &err_msg); + return Err(StreamKitError::Runtime(err_msg)); + } + tracing::info!( + "WebMMuxerNode: pin '{pin_name}' classified as VIDEO (from connection type)" + ); + video_rx = Some(rx); + } else { + if audio_rx.is_some() { + let err_msg = format!( + "WebMMuxerNode: multiple audio inputs detected (pin '{pin_name}'). \ + Only one audio track is supported." 
+ ); + state_helpers::emit_failed(&context.state_tx, &node_name, &err_msg); + return Err(StreamKitError::Runtime(err_msg)); + } + tracing::info!( + "WebMMuxerNode: pin '{pin_name}' classified as AUDIO (from connection type)" + ); + audio_rx = Some(rx); + } + } + + let has_audio = audio_rx.is_some(); + let has_video = video_rx.is_some(); + + if !has_audio && !has_video { + let err_msg = + "WebMMuxerNode: no connected inputs could be classified as audio or video" + .to_string(); + state_helpers::emit_failed(&context.state_tx, &node_name, &err_msg); + return Err(StreamKitError::Runtime(err_msg)); + } + state_helpers::emit_running(&context.state_tx, &node_name); - let mut input_rx = context.take_input("in")?; - let mut packet_count = 0u64; - // Stats tracking + tracing::info!("WebMMuxerNode tracks: audio={}, video={}", has_audio, has_video); + + let content_type_str: Cow<'static, str> = + Cow::Borrowed(webm_content_type(has_audio, has_video)); + + let mut packet_count = 0u64; let mut stats_tracker = NodeStatsTracker::new(node_name.clone(), context.stats_tx.clone()); - // In Live mode we use a non-seek writer, so we can drain bytes out without keeping - // any history (zero-copy streaming). In File mode we must keep the whole buffer - // because we only emit bytes once the segment is finalized. - let shared_buffer = match self.config.streaming_mode { - WebMStreamingMode::Live => SharedPacketBuffer::new_streaming(), - WebMStreamingMode::File => SharedPacketBuffer::new_with_window(usize::MAX), + // In Live mode we use a non-seek, in-memory streaming buffer so bytes + // can be drained incrementally without keeping history. In File mode + // we use a temp file on disk so the muxer can seek/backpatch without + // accumulating the entire output in memory. + // + // For Live mode we keep a cloned handle (`live_flush_handle`) so the + // receive loop can drain bytes while the Writer owns the buffer. 
+ let (mux_buffer, live_flush_handle) = match self.config.streaming_mode { + WebMStreamingMode::Live => { + let spb = SharedPacketBuffer::new_streaming(); + let flush_handle = spb.clone(); + (MuxBuffer::Live(spb), Some(flush_handle)) + }, + WebMStreamingMode::File => { + let fb = FileBackedBuffer::new().map_err(|e| { + let err_msg = format!("Failed to create temp file for WebM file mode: {e}"); + state_helpers::emit_failed(&context.state_tx, &node_name, &err_msg); + StreamKitError::Runtime(err_msg) + })?; + (MuxBuffer::File(fb), None) + }, }; - // Create writer with shared buffer. + // Create writer with the unified buffer. // - // Important: In `Live` mode we must avoid any backwards seeking/backpatching while bytes - // are being streamed to the client. Using a non-seek writer forces libwebm to produce a - // forward-only stream (unknown sizes/no cues), which is required for MSE consumers like - // Firefox that are less tolerant of inconsistent metadata during progressive append. + // Important: In `Live` mode we must avoid any backwards seeking/backpatching while + // bytes are being streamed to the client. Using a non-seek writer forces libwebm to + // produce a forward-only stream (unknown sizes/no cues), which is required for MSE + // consumers like Firefox that are less tolerant of inconsistent metadata during + // progressive append. In File mode we use a seekable writer so libwebm can + // back-patch segment sizes and write cues. 
let writer = match self.config.streaming_mode { - WebMStreamingMode::Live => Writer::new_non_seek(shared_buffer.clone()), - WebMStreamingMode::File => Writer::new(shared_buffer.clone()), + WebMStreamingMode::Live => Writer::new_non_seek(mux_buffer), + WebMStreamingMode::File => Writer::new(mux_buffer), }; // Create WebM segment builder @@ -345,7 +655,6 @@ impl ProcessorNode for WebMMuxerNode { StreamKitError::Runtime(err_msg) })?; - // Set streaming mode based on configuration let builder = builder.set_mode(self.config.streaming_mode.as_segment_mode()).map_err(|e| { let err_msg = format!("Failed to set streaming mode: {e}"); @@ -353,188 +662,321 @@ impl ProcessorNode for WebMMuxerNode { StreamKitError::Runtime(err_msg) })?; - // Add audio track for Opus - let opus_private = opus_head_codec_private(self.config.sample_rate, self.config.channels) - .map_err(|e| { - let err_msg = format!("Failed to build OpusHead codec private: {e}"); - state_helpers::emit_failed(&context.state_tx, &node_name, &err_msg); - StreamKitError::Runtime(err_msg) - })?; + // -- Add tracks conditionally -- - let (builder, audio_track) = builder - .add_audio_track( - self.config.sample_rate, - self.config.channels, - AudioCodecId::Opus, - None, // Let the library assign track number - ) - .map_err(|e| { - let err_msg = format!("Failed to add audio track: {e}"); - state_helpers::emit_failed(&context.state_tx, &node_name, &err_msg); - StreamKitError::Runtime(err_msg) - })?; + let mut tracks = MuxTracks { audio: None, video: None }; - let builder = builder.set_codec_private(audio_track, &opus_private).map_err(|e| { - let err_msg = format!("Failed to set Opus codec private: {e}"); - state_helpers::emit_failed(&context.state_tx, &node_name, &err_msg); - StreamKitError::Runtime(err_msg) - })?; + // --- Resolve video dimensions ----------------------------------------- + // + // When `video_width` / `video_height` are both 0 (the default) and a + // video input is connected, we auto-detect the 
dimensions from the + // first VP9 keyframe. This avoids requiring the user to manually + // keep the muxer config in sync with the upstream encoder / compositor. + // + // The first video packet is buffered so it can be replayed through the + // normal receive loop after the segment is built. - // Build the segment - // Note: The WebM header is not written until the first frame is added, - // so we flush it after adding the first frame below - let mut segment = builder.build(); + let mut first_video_packet: Option<(Bytes, Option)> = None; - let mut clock = MediaClock::new(0); - let mut header_sent = false; + let (video_width, video_height) = if has_video { + let mut w = self.config.video_width; + let mut h = self.config.video_height; - tracing::info!("WebM segment built, entering receive loop to process incoming packets"); - while let Some(packet) = context.recv_with_cancellation(&mut input_rx).await { - if let Packet::Binary { data, metadata, .. } = packet { - packet_count += 1; - stats_tracker.received(); - - // tracing::debug!( - // "WebMMuxer received packet #{}, {} bytes", - // packet_count, - // data.len() - // ); - - // Calculate timestamp from metadata (microseconds). - let incoming_ts_us = metadata.as_ref().and_then(|m| m.timestamp_us); - let incoming_duration_us = metadata - .as_ref() - .and_then(|m| m.duration_us) - .or(Some(DEFAULT_FRAME_DURATION_US)); - - if let Some(ts) = incoming_ts_us { - clock.seed_from_timestamp_us(ts); - } else if clock.timestamp_us() == 0 { - clock.seed_from_timestamp_us(0); + if w == 0 || h == 0 { + // Auto-detect: wait for the first video packet and parse its VP9 header. + tracing::info!( + "WebMMuxerNode: video_width/video_height not configured, \ + auto-detecting from first VP9 keyframe" + ); + + let first = match video_rx.as_mut() { + Some(rx) => rx.recv().await, + None => None, + }; + + if let Some(Packet::Binary { data, metadata, .. 
}) = first { + if let Some((detected_w, detected_h)) = parse_vp9_keyframe_dimensions(&data) { + tracing::info!( + "Auto-detected video dimensions: {}x{}", + detected_w, + detected_h + ); + w = detected_w; + h = detected_h; + } else { + let err_msg = "WebMMuxerNode: failed to parse VP9 keyframe dimensions \ + from first video packet (is the upstream encoder VP9?)" + .to_string(); + state_helpers::emit_failed(&context.state_tx, &node_name, &err_msg); + return Err(StreamKitError::Runtime(err_msg)); + } + first_video_packet = Some((data, metadata)); + } else { + let err_msg = + "WebMMuxerNode: video input closed before sending any packets".to_string(); + state_helpers::emit_failed(&context.state_tx, &node_name, &err_msg); + return Err(StreamKitError::Runtime(err_msg)); } + } - let presentation_ts_us = incoming_ts_us.unwrap_or_else(|| clock.timestamp_us()); + if w == 0 || h == 0 { + let err_msg = "WebMMuxerNode: video dimensions could not be determined".to_string(); + state_helpers::emit_failed(&context.state_tx, &node_name, &err_msg); + return Err(StreamKitError::Runtime(err_msg)); + } + (w, h) + } else { + (0, 0) + }; - // Advance clock for next frame - clock.advance_by_duration_us(incoming_duration_us, DEFAULT_FRAME_DURATION_US); + // Video track is added first so that the segment header lists it prominently + // for players that inspect the first track. 
+ let builder = if has_video { + let (builder, vt) = builder + .add_video_track(video_width, video_height, VideoCodecId::VP9, None) + .map_err(|e| { + let err_msg = format!("Failed to add video track: {e}"); + state_helpers::emit_failed(&context.state_tx, &node_name, &err_msg); + StreamKitError::Runtime(err_msg) + })?; + tracks.video = Some(vt); + tracing::info!("Added VP9 video track ({}x{})", video_width, video_height); + builder + } else { + builder + }; - let current_timestamp_ns = presentation_ts_us.saturating_mul(1000); + let builder = if has_audio { + let opus_private = + opus_head_codec_private(self.config.sample_rate, self.config.channels).map_err( + |e| { + let err_msg = format!("Failed to build OpusHead codec private: {e}"); + state_helpers::emit_failed(&context.state_tx, &node_name, &err_msg); + StreamKitError::Runtime(err_msg) + }, + )?; + + let (builder, at) = builder + .add_audio_track( + self.config.sample_rate, + self.config.channels, + AudioCodecId::Opus, + None, + ) + .map_err(|e| { + let err_msg = format!("Failed to add audio track: {e}"); + state_helpers::emit_failed(&context.state_tx, &node_name, &err_msg); + StreamKitError::Runtime(err_msg) + })?; - let output_metadata = Some(PacketMetadata { - timestamp_us: Some(presentation_ts_us), - duration_us: incoming_duration_us, - sequence: metadata.as_ref().and_then(|m| m.sequence), - }); + let builder = builder.set_codec_private(at, &opus_private).map_err(|e| { + let err_msg = format!("Failed to set Opus codec private: {e}"); + state_helpers::emit_failed(&context.state_tx, &node_name, &err_msg); + StreamKitError::Runtime(err_msg) + })?; - // For audio, all frames are effectively "keyframes" (can start playback from any point) - let is_keyframe = true; + tracks.audio = Some(at); + tracing::info!( + "Added Opus audio track ({}Hz, {} ch)", + self.config.sample_rate, + self.config.channels + ); + builder + } else { + builder + }; - // tracing::debug!( - // "Adding packet #{} to WebM segment (timestamp: 
{}ns)", - // packet_count, - // current_timestamp_ns - // ); + let mut segment = builder.build(); - // Add frame to segment - if let Err(e) = - segment.add_frame(audio_track, &data, current_timestamp_ns, is_keyframe) - { - stats_tracker.errored(); - stats_tracker.maybe_send(); - let err_msg = format!("Failed to add frame to segment: {e}"); - state_helpers::emit_failed(&context.state_tx, &node_name, &err_msg); - return Err(StreamKitError::Runtime(err_msg)); - } + let mut audio_clock = MediaClock::new(0); + let mut video_clock = MediaClock::new(0); + let mut header_sent = false; - // tracing::debug!( - // "Packet #{} added to WebM segment successfully", - // packet_count - // ); + // Monotonic timestamp guard: libwebm requires that timestamps across all tracks + // are non-decreasing. We track the last written timestamp and clamp if needed. + let mut last_written_ns: u64 = 0; - // After adding the first frame, the WebM header has been written - flush it immediately - if !header_sent && matches!(self.config.streaming_mode, WebMStreamingMode::Live) { - let header_data = shared_buffer.take_data(); + tracing::info!("WebM segment built, entering receive loop to process incoming packets"); - if let Some(data) = header_data { - tracing::info!( - "Sending WebM header + first frame ({} bytes), first 20 bytes: {:?}", - data.len(), - &data[..data.len().min(20)] - ); - if context - .output_sender - .send( - "out", - Packet::Binary { - data, - content_type: Some(Cow::Borrowed( - "audio/webm; codecs=\"opus\"", - )), - metadata: None, - }, - ) - .await - .is_err() - { - tracing::debug!("Output channel closed, stopping node"); - state_helpers::emit_stopped( - &context.state_tx, - &node_name, - "output_closed", - ); - return Ok(()); - } - stats_tracker.sent(); - header_sent = true; - } + // -- Receive loop: multiplex audio + video inputs -- + + let mut audio_done = !has_audio; + let mut video_done = !has_video; + + // If we buffered the first video packet for dimension detection, 
replay + // it through the normal mux path before entering the receive loop. + if let Some((data, metadata)) = first_video_packet.take() { + if let Some(video_track) = tracks.video { + let is_keyframe = metadata.as_ref().and_then(|m| m.keyframe).unwrap_or(true); + if mux_frame( + &data, + metadata.as_ref(), + video_track, + is_keyframe, + DEFAULT_VIDEO_FRAME_DURATION_US, + &mut video_clock, + &mut last_written_ns, + &mut segment, + &mut context, + live_flush_handle.as_ref(), + &content_type_str, + &mut header_sent, + &mut stats_tracker, + &node_name, + &mut packet_count, + ) + .await? + { + video_done = true; } + } + } - // In Live mode, flush after every frame for true streaming - // In File mode, keep everything for proper duration/seeking - if header_sent && matches!(self.config.streaming_mode, WebMStreamingMode::Live) { - // Flush any buffered data immediately for low-latency streaming - if let Some(data) = shared_buffer.take_data() { - tracing::trace!("Flushing {} bytes to output", data.len()); - if context - .output_sender - .send( - "out", - Packet::Binary { - data, - content_type: Some(Cow::Borrowed( - "audio/webm; codecs=\"opus\"", - )), - metadata: output_metadata.clone(), - }, - ) - .await - .is_err() - { - tracing::debug!("Output channel closed, stopping node"); - break; + while !audio_done || !video_done { + enum MuxFrame { + Audio(Bytes, Option), + Video(Bytes, Option), + AudioClosed, + VideoClosed, + } + + let frame = if audio_done { + // Only video remains + match video_rx.as_mut() { + Some(rx) => match rx.recv().await { + Some(Packet::Binary { data, metadata, .. }) => { + MuxFrame::Video(data, metadata) + }, + Some(_) => continue, + None => MuxFrame::VideoClosed, + }, + None => break, + } + } else if video_done { + // Only audio remains + match audio_rx.as_mut() { + Some(rx) => match rx.recv().await { + Some(Packet::Binary { data, metadata, .. 
}) => { + MuxFrame::Audio(data, metadata) + }, + Some(_) => continue, + None => MuxFrame::AudioClosed, + }, + None => break, + } + } else { + // Both active - use select to receive from whichever is ready first + let audio_rx_ref = audio_rx.as_mut(); + let video_rx_ref = video_rx.as_mut(); + match (audio_rx_ref, video_rx_ref) { + (Some(a_rx), Some(v_rx)) => { + tokio::select! { + biased; // prefer audio first for stable ordering + maybe_audio = a_rx.recv() => { + match maybe_audio { + Some(Packet::Binary { data, metadata, .. }) => { + MuxFrame::Audio(data, metadata) + }, + Some(_) => continue, + None => MuxFrame::AudioClosed, + } + } + maybe_video = v_rx.recv() => { + match maybe_video { + Some(Packet::Binary { data, metadata, .. }) => { + MuxFrame::Video(data, metadata) + }, + Some(_) => continue, + None => MuxFrame::VideoClosed, + } + } } - stats_tracker.sent(); - } + }, + _ => break, } - - stats_tracker.maybe_send(); + }; + + match frame { + MuxFrame::AudioClosed => { + tracing::info!("WebMMuxerNode audio input closed"); + audio_done = true; + }, + MuxFrame::VideoClosed => { + tracing::info!("WebMMuxerNode video input closed"); + video_done = true; + }, + MuxFrame::Audio(data, metadata) => { + let Some(audio_track) = tracks.audio else { + continue; + }; + // Audio frames are always keyframes. + if mux_frame( + &data, + metadata.as_ref(), + audio_track, + true, + DEFAULT_FRAME_DURATION_US, + &mut audio_clock, + &mut last_written_ns, + &mut segment, + &mut context, + live_flush_handle.as_ref(), + &content_type_str, + &mut header_sent, + &mut stats_tracker, + &node_name, + &mut packet_count, + ) + .await? 
+ { + break; + } + }, + MuxFrame::Video(data, metadata) => { + let Some(video_track) = tracks.video else { + continue; + }; + let is_keyframe = metadata.as_ref().and_then(|m| m.keyframe).unwrap_or(false); + if mux_frame( + &data, + metadata.as_ref(), + video_track, + is_keyframe, + DEFAULT_VIDEO_FRAME_DURATION_US, + &mut video_clock, + &mut last_written_ns, + &mut segment, + &mut context, + live_flush_handle.as_ref(), + &content_type_str, + &mut header_sent, + &mut stats_tracker, + &node_name, + &mut packet_count, + ) + .await? + { + break; + } + }, } } tracing::info!( - "WebMMuxerNode input stream closed, processed {} packets total", + "WebMMuxerNode input streams closed, processed {} packets total", packet_count ); - // Finalize the segment - let _writer = segment.finalize(None).map_err(|_e| { + // Finalize the segment and recover the buffer. + let writer = segment.finalize(None).map_err(|_e| { let err_msg = "Failed to finalize WebM segment".to_string(); state_helpers::emit_failed(&context.state_tx, &node_name, &err_msg); StreamKitError::Runtime(err_msg) })?; + let mut mux_buffer = writer.into_inner(); // Flush any remaining data from the buffer - if let Some(data) = shared_buffer.take_data() { + if let Some(data) = mux_buffer.take_data() { tracing::debug!("Writing final data, buffer size: {} bytes", data.len()); if context .output_sender @@ -542,7 +984,7 @@ impl ProcessorNode for WebMMuxerNode { "out", Packet::Binary { data, - content_type: Some(Cow::Borrowed("audio/webm; codecs=\"opus\"")), + content_type: Some(content_type_str.clone()), metadata: None, }, ) @@ -550,7 +992,6 @@ impl ProcessorNode for WebMMuxerNode { .is_err() { tracing::debug!("Output channel closed during final flush"); - // Don't return error, we're already shutting down } else { stats_tracker.sent(); } @@ -564,6 +1005,163 @@ impl ProcessorNode for WebMMuxerNode { } } +/// Timestamps, clocks, and writes a single frame (audio or video) to the WebM +/// segment, then flushes any buffered 
output. +/// +/// Returns `Ok(true)` if the output channel is closed (caller should stop), +/// `Ok(false)` to continue, or `Err` on fatal errors. +#[allow(clippy::too_many_arguments)] +#[allow(clippy::ptr_arg)] // content_type is cloned as Cow<'static, str> for Packet; &str would force allocation +async fn mux_frame( + data: &[u8], + metadata: Option<&PacketMetadata>, + track: impl Into, + is_keyframe: bool, + default_duration_us: u64, + clock: &mut streamkit_core::timing::MediaClock, + last_written_ns: &mut u64, + segment: &mut webm::mux::Segment, + context: &mut NodeContext, + live_buffer: Option<&SharedPacketBuffer>, + content_type: &Cow<'static, str>, + header_sent: &mut bool, + stats_tracker: &mut NodeStatsTracker, + node_name: &str, + packet_count: &mut u64, +) -> Result { + *packet_count += 1; + stats_tracker.received(); + + let incoming_ts_us = metadata.and_then(|m| m.timestamp_us); + let incoming_duration_us = metadata.and_then(|m| m.duration_us).or(Some(default_duration_us)); + + if let Some(ts) = incoming_ts_us { + clock.seed_from_timestamp_us(ts); + } else if clock.timestamp_us() == 0 { + clock.seed_from_timestamp_us(0); + } + + let presentation_ts_us = incoming_ts_us.unwrap_or_else(|| clock.timestamp_us()); + clock.advance_by_duration_us(incoming_duration_us, default_duration_us); + + let mut timestamp_ns = presentation_ts_us.saturating_mul(1000); + if timestamp_ns < *last_written_ns { + timestamp_ns = *last_written_ns; + } + + if let Err(e) = segment.add_frame(track, data, timestamp_ns, is_keyframe) { + stats_tracker.errored(); + stats_tracker.maybe_send(); + let err_msg = format!("Failed to add frame to segment: {e}"); + state_helpers::emit_failed(&context.state_tx, node_name, &err_msg); + return Err(StreamKitError::Runtime(err_msg)); + } + + *last_written_ns = timestamp_ns; + + let output_metadata = Some(PacketMetadata { + timestamp_us: Some(presentation_ts_us), + duration_us: incoming_duration_us, + sequence: metadata.and_then(|m| m.sequence), + 
keyframe: Some(is_keyframe), + }); + + let stopped = flush_output( + context, + live_buffer, + content_type, + output_metadata, + header_sent, + stats_tracker, + node_name, + ) + .await?; + + stats_tracker.maybe_send(); + Ok(stopped) +} + +/// Flushes buffered WebM data to the output sender. +/// +/// In **Live** mode, bytes are drained incrementally after every frame to keep +/// memory bounded and enable real-time streaming. In **File** mode the data +/// lives on disk in a temp file and is only read back once after finalization +/// (handled by the caller), so this function is a no-op. +/// +/// `live_buffer` is `Some` only in Live mode — it is the cloned +/// `SharedPacketBuffer` handle that shares the same `Arc>` backing +/// store as the `MuxBuffer::Live` variant owned by the Writer. +/// +/// Returns `Ok(true)` if the output channel is closed (node should stop), +/// `Ok(false)` to continue, or `Err` on fatal errors. +#[allow(clippy::ptr_arg)] +async fn flush_output( + context: &mut NodeContext, + live_buffer: Option<&SharedPacketBuffer>, + content_type: &Cow<'static, str>, + output_metadata: Option, + header_sent: &mut bool, + stats_tracker: &mut NodeStatsTracker, + node_name: &str, +) -> Result { + // In File mode there is no live buffer — skip all intermediate flushes. + // The data will be read from the temp file after `segment.finalize()`. 
+ let Some(shared_buffer) = live_buffer else { + return Ok(false); + }; + + if !*header_sent { + if let Some(data) = shared_buffer.take_data() { + tracing::info!("Sending WebM header + first frame ({} bytes)", data.len(),); + if context + .output_sender + .send( + "out", + Packet::Binary { + data, + content_type: Some(content_type.clone()), + metadata: None, + }, + ) + .await + .is_err() + { + tracing::debug!("Output channel closed, stopping node"); + state_helpers::emit_stopped(&context.state_tx, node_name, "output_closed"); + return Ok(true); + } + stats_tracker.sent(); + *header_sent = true; + } + } + + // Flush any accumulated bytes after the header has been sent. + if *header_sent { + if let Some(data) = shared_buffer.take_data() { + tracing::trace!("Flushing {} bytes to output", data.len()); + if context + .output_sender + .send( + "out", + Packet::Binary { + data, + content_type: Some(content_type.clone()), + metadata: output_metadata, + }, + ) + .await + .is_err() + { + tracing::debug!("Output channel closed, stopping node"); + return Ok(true); + } + stats_tracker.sent(); + } + } + + Ok(false) +} + use schemars::schema_for; use streamkit_core::{config_helpers, registry::StaticPins}; @@ -588,8 +1186,9 @@ pub fn register_webm_nodes(registry: &mut NodeRegistry) { StaticPins { inputs: default_muxer.input_pins(), outputs: default_muxer.output_pins() }, vec!["containers".to_string(), "webm".to_string()], false, - "Muxes Opus audio into a WebM container. \ - Produces streamable WebM/Opus output compatible with web browsers.", + "Muxes Opus audio and/or VP9 video into a WebM container. \ + Produces streamable WebM output compatible with web browsers. 
\ + Supports audio-only, video-only, or combined audio+video muxing.", ); } } diff --git a/crates/nodes/src/core/file_read.rs b/crates/nodes/src/core/file_read.rs index 2aa7f085..5841bacf 100644 --- a/crates/nodes/src/core/file_read.rs +++ b/crates/nodes/src/core/file_read.rs @@ -231,6 +231,7 @@ mod tests { let context = NodeContext { inputs: HashMap::new(), + input_types: HashMap::new(), control_rx, output_sender, batch_size: 32, @@ -241,6 +242,8 @@ mod tests { cancellation_token: None, pin_management_rx: None, // Test contexts don't support dynamic pins audio_pool: None, + video_pool: None, + view_data_tx: None, }; // Create and run node diff --git a/crates/nodes/src/core/file_write.rs b/crates/nodes/src/core/file_write.rs index 847a5c1a..359b8e4d 100644 --- a/crates/nodes/src/core/file_write.rs +++ b/crates/nodes/src/core/file_write.rs @@ -193,6 +193,7 @@ mod tests { let context = NodeContext { inputs, + input_types: HashMap::new(), control_rx, output_sender, batch_size: 32, @@ -203,6 +204,8 @@ mod tests { cancellation_token: None, pin_management_rx: None, // Test contexts don't support dynamic pins audio_pool: None, + video_pool: None, + view_data_tx: None, }; // Create and run node @@ -273,6 +276,7 @@ mod tests { let context = NodeContext { inputs, + input_types: HashMap::new(), control_rx, output_sender, batch_size: 32, @@ -283,6 +287,8 @@ mod tests { cancellation_token: None, pin_management_rx: None, // Test contexts don't support dynamic pins audio_pool: None, + video_pool: None, + view_data_tx: None, }; // Create and run node with small chunk size for testing diff --git a/crates/nodes/src/core/pacer.rs b/crates/nodes/src/core/pacer.rs index db6bcbe0..d836e59b 100644 --- a/crates/nodes/src/core/pacer.rs +++ b/crates/nodes/src/core/pacer.rs @@ -111,6 +111,11 @@ impl PacerNode { // Fallback: calculate from AudioFrame Self::calculate_audio_duration(frame) }, + Packet::Video(frame) => frame + .metadata + .as_ref() + .and_then(|m| m.duration_us) + 
.map_or(Duration::ZERO, Duration::from_micros), Packet::Binary { metadata, .. } => { // Use metadata if available metadata @@ -139,6 +144,7 @@ impl PacerNode { fn packet_metadata(packet: &Packet) -> Option<&PacketMetadata> { match packet { Packet::Audio(frame) => frame.metadata.as_ref(), + Packet::Video(frame) => frame.metadata.as_ref(), Packet::Binary { metadata, .. } => metadata.as_ref(), Packet::Custom(custom) => custom.metadata.as_ref(), Packet::Transcription(transcription) => transcription.metadata.as_ref(), @@ -487,6 +493,7 @@ mod tests { let context = NodeContext { inputs, + input_types: HashMap::new(), control_rx, output_sender, batch_size: 32, @@ -497,6 +504,8 @@ mod tests { cancellation_token: None, pin_management_rx: None, // Test contexts don't support dynamic pins audio_pool: None, + video_pool: None, + view_data_tx: None, }; // Create node with very fast speed to minimize test time @@ -517,6 +526,7 @@ mod tests { timestamp_us: None, duration_us: Some(1_000), // 1ms sequence: Some(i), + keyframe: None, }), }) .await diff --git a/crates/nodes/src/core/script.rs b/crates/nodes/src/core/script.rs index 74126617..b76b30d9 100644 --- a/crates/nodes/src/core/script.rs +++ b/crates/nodes/src/core/script.rs @@ -494,6 +494,44 @@ impl ScriptNode { obj.set("metadata", metadata) .map_err(|e| StreamKitError::Runtime(format!("Failed to set metadata: {e}")))?; }, + Packet::Video(frame) => { + obj.set("type", "Video") + .map_err(|e| StreamKitError::Runtime(format!("Failed to set type: {e}")))?; + + let metadata = rquickjs::Object::new(ctx.clone()).map_err(|e| { + StreamKitError::Runtime(format!("Failed to create metadata: {e}")) + })?; + metadata + .set("width", frame.width) + .map_err(|e| StreamKitError::Runtime(format!("Failed to set width: {e}")))?; + metadata + .set("height", frame.height) + .map_err(|e| StreamKitError::Runtime(format!("Failed to set height: {e}")))?; + metadata.set("pixel_format", format!("{:?}", frame.pixel_format)).map_err(|e| { + 
StreamKitError::Runtime(format!("Failed to set pixel_format: {e}")) + })?; + metadata + .set("bytes", frame.data.len()) + .map_err(|e| StreamKitError::Runtime(format!("Failed to set bytes: {e}")))?; + if let Some(timestamp_us) = frame.metadata.as_ref().and_then(|m| m.timestamp_us) { + metadata.set("timestamp_us", timestamp_us).map_err(|e| { + StreamKitError::Runtime(format!("Failed to set timestamp_us: {e}")) + })?; + } + if let Some(duration_us) = frame.metadata.as_ref().and_then(|m| m.duration_us) { + metadata.set("duration_us", duration_us).map_err(|e| { + StreamKitError::Runtime(format!("Failed to set duration_us: {e}")) + })?; + } + if let Some(keyframe) = frame.metadata.as_ref().and_then(|m| m.keyframe) { + metadata.set("keyframe", keyframe).map_err(|e| { + StreamKitError::Runtime(format!("Failed to set keyframe: {e}")) + })?; + } + + obj.set("metadata", metadata) + .map_err(|e| StreamKitError::Runtime(format!("Failed to set metadata: {e}")))?; + }, Packet::Transcription(transcription) => { obj.set("type", "Transcription") @@ -2369,6 +2407,7 @@ mod tests { timestamp_us: Some(1_000_000), duration_us: None, sequence: None, + keyframe: None, }), }))) .await @@ -2388,6 +2427,7 @@ mod tests { timestamp_us: Some(3_000_000), duration_us: None, sequence: None, + keyframe: None, }), }))) .await diff --git a/crates/nodes/src/core/telemetry_out.rs b/crates/nodes/src/core/telemetry_out.rs index 8ad7e611..8afede77 100644 --- a/crates/nodes/src/core/telemetry_out.rs +++ b/crates/nodes/src/core/telemetry_out.rs @@ -79,6 +79,7 @@ impl TelemetryOutNode { fn should_tap_packet_type(&self, packet: &Packet) -> bool { let type_name = match packet { Packet::Audio(_) => "Audio", + Packet::Video(_) => "Video", Packet::Transcription(_) => "Transcription", Packet::Custom(_) => "Custom", Packet::Binary { .. 
} => "Binary", @@ -211,7 +212,7 @@ impl ProcessorNode for TelemetryOutNode { serde_json::json!({ "size_bytes": data.len(), "has_metadata": metadata.is_some() }), ); }, - Packet::Audio(_) => { + Packet::Video(_) | Packet::Audio(_) => { // Intentionally no audio-level telemetry here to avoid noise; use `core::telemetry_tap` if needed. }, } diff --git a/crates/nodes/src/core/telemetry_tap.rs b/crates/nodes/src/core/telemetry_tap.rs index 63d9d9bf..2bf59303 100644 --- a/crates/nodes/src/core/telemetry_tap.rs +++ b/crates/nodes/src/core/telemetry_tap.rs @@ -131,6 +131,7 @@ impl TelemetryTapNode { fn should_tap_packet_type(&self, packet: &Packet) -> bool { let type_name = match packet { Packet::Audio(_) => "Audio", + Packet::Video(_) => "Video", Packet::Transcription(_) => "Transcription", Packet::Custom(_) => "Custom", Packet::Binary { .. } => "Binary", @@ -334,6 +335,18 @@ impl ProcessorNode for TelemetryTapNode { }), ); }, + Packet::Video(frame) => { + telemetry.emit( + "video.received", + serde_json::json!({ + "width": frame.width, + "height": frame.height, + "pixel_format": format!("{:?}", frame.pixel_format), + "size_bytes": frame.data.len(), + "has_metadata": frame.metadata.is_some(), + }), + ); + }, } } diff --git a/crates/nodes/src/lib.rs b/crates/nodes/src/lib.rs index 22461277..724f70b4 100644 --- a/crates/nodes/src/lib.rs +++ b/crates/nodes/src/lib.rs @@ -6,10 +6,10 @@ use streamkit_core::NodeRegistry; // Declare the top-level feature modules directly. 
pub mod audio; -pub mod core; -// pub mod video; pub mod containers; +pub mod core; pub mod transport; +pub mod video; // Shared utilities pub mod streaming_utils; @@ -32,7 +32,7 @@ pub fn register_nodes( audio::register_audio_nodes(registry); containers::register_container_nodes(registry); transport::register_transport_nodes(registry); - // video::register_video_nodes(registry); + video::register_video_nodes(registry); tracing::info!("Finished registering built-in nodes."); } @@ -45,7 +45,7 @@ pub fn register_nodes(registry: &mut NodeRegistry) { audio::register_audio_nodes(registry); containers::register_container_nodes(registry); transport::register_transport_nodes(registry); - // video::register_video_nodes(registry); + video::register_video_nodes(registry); tracing::info!("Finished registering built-in nodes."); } diff --git a/crates/nodes/src/test_utils.rs b/crates/nodes/src/test_utils.rs index fd25aeb5..385abe37 100644 --- a/crates/nodes/src/test_utils.rs +++ b/crates/nodes/src/test_utils.rs @@ -7,7 +7,7 @@ use std::collections::HashMap; use streamkit_core::node::{NodeContext, OutputRouting, OutputSender, RoutedPacketMessage}; use streamkit_core::state::NodeStateUpdate; -use streamkit_core::types::Packet; +use streamkit_core::types::{Packet, PixelFormat, VideoFrame, VideoLayout}; use tokio::sync::mpsc; /// Creates a test NodeContext with mock channels @@ -28,6 +28,7 @@ pub fn create_test_context( let context = NodeContext { inputs, + input_types: HashMap::new(), control_rx, output_sender, batch_size, @@ -38,6 +39,8 @@ pub fn create_test_context( cancellation_token: None, pin_management_rx: Some(pin_mgmt_rx), // Provide channel for dynamic pins support audio_pool: None, + video_pool: None, + view_data_tx: None, }; (context, mock_sender, state_rx) @@ -133,6 +136,49 @@ pub fn create_test_binary_packet(data: Vec) -> Packet { Packet::Binary { data: bytes::Bytes::from(data), content_type: None, metadata: None } } +/// Helper to create a simple video frame for 
testing. +/// +/// # Panics +/// +/// Panics if the width/height/pixel-format combination is not accepted by +/// [`VideoFrame::new`] (e.g. zero dimensions or a mismatch between the computed +/// layout size and the allocated buffer). Callers in tests pick these values +/// deliberately, so a panic indicates a bug in the test itself. +#[allow(clippy::expect_used)] +pub fn create_test_video_frame( + width: u32, + height: u32, + pixel_format: PixelFormat, + fill_value: u8, +) -> VideoFrame { + let layout = VideoLayout::packed(width, height, pixel_format); + let mut data = vec![fill_value; layout.total_bytes()]; + + if pixel_format == PixelFormat::I420 || pixel_format == PixelFormat::Nv12 { + // Neutral chroma for predictable decoder output. + // Works for both I420 (separate U/V planes) and NV12 (interleaved UV plane): + // filling with 128 produces neutral grey regardless of interleaving. + for plane in layout.planes().iter().skip(1) { + let start = plane.offset; + let end = start + plane.stride * plane.height as usize; + data[start..end].fill(128); + } + } + + VideoFrame::new(width, height, pixel_format, data) + .expect("test video frame dimensions/format should be valid") +} + +/// Helper to create a simple video packet for testing. 
+pub fn create_test_video_packet( + width: u32, + height: u32, + pixel_format: PixelFormat, + fill_value: u8, +) -> Packet { + Packet::Video(create_test_video_frame(width, height, pixel_format, fill_value)) +} + /// Helper to extract audio data from a packet pub fn extract_audio_data(packet: &Packet) -> Option<&[f32]> { match packet { diff --git a/crates/nodes/src/transport/http.rs b/crates/nodes/src/transport/http.rs index 4cb2e90d..ce484016 100644 --- a/crates/nodes/src/transport/http.rs +++ b/crates/nodes/src/transport/http.rs @@ -378,6 +378,7 @@ mod tests { let context = NodeContext { inputs: HashMap::new(), + input_types: HashMap::new(), control_rx, output_sender, batch_size: 32, @@ -388,6 +389,8 @@ mod tests { cancellation_token: None, pin_management_rx: None, // Test contexts don't support dynamic pins audio_pool: None, + video_pool: None, + view_data_tx: None, }; // Create and run node with small chunk size for testing diff --git a/crates/nodes/src/transport/moq/mod.rs b/crates/nodes/src/transport/moq/mod.rs index e187743a..0b246517 100644 --- a/crates/nodes/src/transport/moq/mod.rs +++ b/crates/nodes/src/transport/moq/mod.rs @@ -51,27 +51,6 @@ fn shared_insecure_client() -> Result { } } -/// Serialize a catalog to JSON with `priority` fields injected into `video` and `audio`. -/// -/// The published `@moq/hang` JS client (0.1.2) still requires `priority` in the catalog -/// schema, but the Rust `hang` 0.13.0 crate removed it from the structs. -/// The upstream JS source has already dropped the requirement, but a new npm release -/// hasn't been published yet. This shim keeps the two sides compatible. 
-pub(super) fn catalog_to_json(catalog: &hang::catalog::Catalog) -> Result { - let mut value = serde_json::to_value(catalog) - .map_err(|e| StreamKitError::Runtime(format!("Failed to serialize catalog: {e}")))?; - - if let Some(video) = value.get_mut("video").and_then(|v| v.as_object_mut()) { - video.entry("priority").or_insert(serde_json::json!(60)); - } - if let Some(audio) = value.get_mut("audio").and_then(|v| v.as_object_mut()) { - audio.entry("priority").or_insert(serde_json::json!(80)); - } - - serde_json::to_string(&value) - .map_err(|e| StreamKitError::Runtime(format!("Failed to serialize catalog: {e}"))) -} - pub(super) fn redact_url_str_for_logs(raw: &str) -> String { raw.parse::().map_or_else( |_| raw.split(['?', '#']).next().unwrap_or(raw).to_string(), @@ -144,7 +123,7 @@ pub fn register_moq_nodes(registry: &mut NodeRegistry) { vec!["transport".to_string(), "moq".to_string(), "dynamic".to_string()], false, "Subscribes to a Media over QUIC (MoQ) broadcast. \ - Receives Opus audio from a remote publisher over WebTransport.", + Receives encoded Opus audio from a remote publisher over WebTransport.", ); let default_moq_push = MoqPushNode::new(MoqPushConfig::default()); @@ -163,7 +142,7 @@ pub fn register_moq_nodes(registry: &mut NodeRegistry) { vec!["transport".to_string(), "moq".to_string(), "dynamic".to_string()], false, "Publishes audio to a Media over QUIC (MoQ) broadcast. \ - Sends Opus audio to subscribers over WebTransport.", + Sends encoded Opus audio to subscribers over WebTransport.", ); let default_moq_peer = MoqPeerNode::new(MoqPeerConfig::default()); diff --git a/crates/nodes/src/transport/moq/peer.rs b/crates/nodes/src/transport/moq/peer.rs index e28965fe..750c6714 100644 --- a/crates/nodes/src/transport/moq/peer.rs +++ b/crates/nodes/src/transport/moq/peer.rs @@ -5,24 +5,29 @@ //! MoQ Peer Node - bidirectional server that accepts WebTransport connections //! //! This node supports a publish/subscribe architecture: -//! 
- One publisher connects to `{gateway_path}/input` to send audio -//! - Multiple subscribers connect to `{gateway_path}/output` to receive processed audio +//! - One publisher connects to `{gateway_path}/input` to send media +//! - Multiple subscribers connect to `{gateway_path}/output` to receive processed media +//! +//! Input and output pins are type-agnostic: both `in` and `in_1` accept any +//! supported encoded media type (Opus audio, VP9 video). The actual media kind +//! flowing through each pin is determined at runtime from `NodeContext::input_types`. use async_trait::async_trait; use bytes::Buf; -use moq_lite::coding::Decode; use schemars::JsonSchema; use serde::Deserialize; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use std::time::Duration; use streamkit_core::timing::MediaClock; -use streamkit_core::types::{Packet, PacketType}; +use streamkit_core::types::{ + AudioCodec, EncodedAudioFormat, EncodedVideoFormat, Packet, PacketType, VideoCodec, +}; use streamkit_core::{ state_helpers, stats::NodeStatsTracker, InputPin, NodeContext, OutputPin, PinCardinality, ProcessorNode, StreamKitError, }; -use tokio::sync::{broadcast, mpsc, OwnedSemaphorePermit, Semaphore}; +use tokio::sync::{broadcast, mpsc, watch, OwnedSemaphorePermit, Semaphore}; /// Capacity for the broadcast channel (subscribers) const SUBSCRIBER_BROADCAST_CAPACITY: usize = 256; @@ -35,10 +40,32 @@ struct NodeStatsDelta { errored: u64, } +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum MediaKind { + Audio, + Video, +} + #[derive(Clone, Debug)] struct BroadcastFrame { data: bytes::Bytes, duration_us: Option, + kind: MediaKind, + keyframe: bool, +} + +/// Media type state shared from the main select loop to subscriber tasks via a +/// [`watch`] channel. Subscribers wait for `resolved == true` before building +/// the MoQ catalog so that dynamic pipelines (where `input_types` is empty) +/// don't advertise an empty catalog. 
+#[derive(Clone, Debug)] +struct MediaTypeState { + has_audio: bool, + has_video: bool, + /// `true` once the media kind of every connected input pin has been + /// determined — either from `NodeContext::input_types` (static pipelines) + /// or from the first packet on each pin (dynamic pipelines). + resolved: bool, } /// Result of processing a single frame @@ -51,12 +78,43 @@ enum FrameResult { Shutdown, } +/// Outcome of a publisher track processing loop. +/// +/// Distinguishes a transient publisher-side cancellation (which the caller may +/// recover from by re-subscribing) from clean completion and fatal errors. +enum TrackExit { + /// Track closed cleanly, stream ended, or shutdown was requested. + /// The caller should not retry. + Finished, + /// The publisher cancelled the subscription (`moq_lite::Error::Cancel`). + /// + /// This typically happens when the browser's `@moq/hang` publish pipeline + /// transiently tears down the track producer — e.g. when `camera.source` + /// flaps to `undefined` during the permission-grant → device-enumeration + /// cascade. The catalog still advertises the track and the browser's + /// `Broadcast#runBroadcast` loop will happily accept a fresh subscription, + /// so the caller may re-subscribe after a short backoff. + Cancelled, + /// A non-recoverable error occurred. Propagate up. + Error(StreamKitError), +} + #[derive(Debug)] enum PublisherEvent { Connected { path: String }, Disconnected { path: String, error: Option }, } +/// Media and output configuration shared across subscriber-related functions. 
+struct SubscriberMediaConfig { + has_video: bool, + has_audio: bool, + video_width: u32, + video_height: u32, + output_group_duration_ms: u64, + output_initial_delay_ms: u64, +} + struct BidirectionalTaskConfig { input_broadcast: String, output_broadcast: String, @@ -67,9 +125,11 @@ struct BidirectionalTaskConfig { publisher_slot: Arc, publisher_events: mpsc::UnboundedSender, subscriber_count: Arc, - output_group_duration_ms: u64, - output_initial_delay_ms: u64, stats_delta_tx: mpsc::Sender, + media: SubscriberMediaConfig, + media_state_rx: watch::Receiver, + audio_output_pin: &'static str, + video_output_pin: &'static str, } struct PublisherReceiveLoopWithSlotConfig { @@ -80,6 +140,8 @@ struct PublisherReceiveLoopWithSlotConfig { publisher_events: mpsc::UnboundedSender, publisher_path: String, stats_delta_tx: mpsc::Sender, + audio_output_pin: &'static str, + video_output_pin: &'static str, } fn normalize_gateway_path(path: &str) -> String { @@ -122,6 +184,14 @@ pub struct MoqPeerConfig { /// /// Default: 0 (no added delay). pub output_initial_delay_ms: u64, + /// Video width in pixels for the MoQ catalog. + /// Used to advertise the video resolution to subscribers. + /// Default: 640. + pub video_width: u32, + /// Video height in pixels for the MoQ catalog. + /// Used to advertise the video resolution to subscribers. + /// Default: 480. + pub video_height: u32, } impl Default for MoqPeerConfig { @@ -133,13 +203,71 @@ impl Default for MoqPeerConfig { allow_reconnect: false, output_group_duration_ms: 40, output_initial_delay_ms: 0, + video_width: 640, + video_height: 480, + } + } +} + +/// Infer the [`MediaKind`] from a [`PacketType`]. +/// +/// Returns `Some(Audio)` for encoded audio, `Some(Video)` for encoded video, +/// and `None` for anything else. 
+const fn media_kind_for_packet_type(pt: &PacketType) -> Option { + match pt { + PacketType::EncodedAudio(_) => Some(MediaKind::Audio), + PacketType::EncodedVideo(_) => Some(MediaKind::Video), + _ => None, + } +} + +/// Infer [`MediaKind`] from a packet's `content_type` field. +/// +/// VP9-encoded packets carry `content_type: Some("video/vp9")`, so any +/// content type starting with `"video/"` is classified as video. +/// Audio packets (Opus) typically have `content_type: None`. +/// +/// # Panics (debug only) +/// +/// Debug-asserts that the `content_type` is either `None` (audio) or starts +/// with `"audio/"` or `"video/"`. This catches future encoders that forget +/// to set the field. +fn infer_kind_from_packet(packet: &Packet) -> MediaKind { + if let Packet::Binary { content_type, .. } = packet { + if let Some(ct) = content_type.as_deref() { + if ct.starts_with("video/") { + return MediaKind::Video; + } + debug_assert!( + ct.starts_with("audio/"), + "unexpected content_type {ct:?} — expected \"audio/…\" or \"video/…\"" + ); + } else { + tracing::trace!("packet has no content_type, assuming audio"); } } + MediaKind::Audio +} + +/// Build a [`BroadcastFrame`] from a pipeline [`Packet`] and the inferred +/// [`MediaKind`] for the pin it arrived on. +fn make_broadcast_frame(packet: Packet, kind: MediaKind) -> Option { + if let Packet::Binary { data, metadata, .. } = packet { + let duration_us = super::constants::packet_duration_us(metadata.as_ref()); + let keyframe = if kind == MediaKind::Video { + metadata.as_ref().and_then(|m| m.keyframe).unwrap_or(false) + } else { + false + }; + Some(BroadcastFrame { data, duration_us, kind, keyframe }) + } else { + None + } } /// A MoQ server node that supports one publisher and multiple subscribers. 
-/// - Publisher connects to `{gateway_path}/input` and sends audio to the pipeline -/// - Subscribers connect to `{gateway_path}/output` and receive processed audio +/// - Publisher connects to `{gateway_path}/input` and sends media to the pipeline +/// - Subscribers connect to `{gateway_path}/output` and receive processed media pub struct MoqPeerNode { config: MoqPeerConfig, } @@ -151,21 +279,49 @@ impl MoqPeerNode { } #[async_trait] +#[allow(clippy::too_many_lines)] impl ProcessorNode for MoqPeerNode { fn input_pins(&self) -> Vec { - vec![InputPin { - name: "in".to_string(), - accepts_types: vec![PacketType::OpusAudio], - cardinality: PinCardinality::One, - }] + let accepted_types = vec![ + PacketType::EncodedAudio(EncodedAudioFormat { + codec: AudioCodec::Opus, + codec_private: None, + }), + PacketType::EncodedVideo(EncodedVideoFormat { + codec: VideoCodec::Vp9, + bitstream_format: None, + codec_private: None, + profile: None, + level: None, + }), + ]; + vec![ + InputPin { + name: "in".to_string(), + accepts_types: accepted_types.clone(), + cardinality: PinCardinality::One, + }, + InputPin { + name: "in_1".to_string(), + accepts_types: accepted_types, + cardinality: PinCardinality::One, + }, + ] } fn output_pins(&self) -> Vec { - vec![OutputPin { - name: "out".to_string(), - produces_type: PacketType::OpusAudio, - cardinality: PinCardinality::Broadcast, - }] + vec![ + OutputPin { + name: "out".to_string(), + produces_type: PacketType::Any, + cardinality: PinCardinality::Broadcast, + }, + OutputPin { + name: "out_1".to_string(), + produces_type: PacketType::Any, + cardinality: PinCardinality::Broadcast, + }, + ] } async fn run(self: Box, mut context: NodeContext) -> Result<(), StreamKitError> { @@ -235,8 +391,79 @@ impl ProcessorNode for MoqPeerNode { StreamKitError::Runtime(err) })?; - // Take ownership of pipeline input channel - let mut pipeline_input_rx = context.take_input("in")?; + // Take ownership of pipeline input channels. 
+ // Both pins accept any encoded media type — the actual kind is + // determined at runtime from `input_types`. + let mut pin_0_rx = context.take_input("in").ok(); + let mut pin_1_rx = context.take_input("in_1").ok(); + + // Try to get type info from the graph builder (static pipelines). + // For dynamic pipelines `input_types` is empty — we handle that below. + let mut pin_0_kind = context.input_types.get("in").and_then(media_kind_for_packet_type); + let mut pin_1_kind = context.input_types.get("in_1").and_then(media_kind_for_packet_type); + + if pin_0_rx.is_none() && pin_1_rx.is_none() { + return Err(StreamKitError::Configuration( + "MoQ peer requires at least one input pin (\"in\" or \"in_1\")".to_string(), + )); + } + + let mut has_audio = + pin_0_kind == Some(MediaKind::Audio) || pin_1_kind == Some(MediaKind::Audio); + let mut has_video = + pin_0_kind == Some(MediaKind::Video) || pin_1_kind == Some(MediaKind::Video); + + // NOTE: In dynamic pipelines the engine creates receivers for ALL + // declared input pins, even those that are never wired to an upstream + // node. We therefore cannot use `pin_rx.is_some()` to decide + // connectivity, and `input_types` is empty so pin kinds are unknown. + // + // We optimistically advertise both audio and video so that the MoQ + // catalog is published immediately when a subscriber connects. Tracks + // that never receive data are harmless — the subscriber simply gets no + // frames on them. This avoids a race where the browser subscribes to + // `catalog.json` before the catalog track has been created (which would + // return "not found" and prevent the watch path from going live). + let dynamic_mode = + context.input_types.is_empty() && pin_0_kind.is_none() && pin_1_kind.is_none(); + if dynamic_mode { + has_audio = true; + has_video = true; + } + let types_resolved = if dynamic_mode { + // Optimistically resolved — both types are advertised. 
+ true + } else { + let pin_0_connected = pin_0_rx.is_some(); + let pin_1_connected = pin_1_rx.is_some(); + (pin_0_kind.is_some() || !pin_0_connected) && (pin_1_kind.is_some() || !pin_1_connected) + }; + + let (media_state_tx, media_state_rx) = + watch::channel(MediaTypeState { has_audio, has_video, resolved: types_resolved }); + + // Symmetric output pin mapping: in ↔ out, in_1 ↔ out_1. + // When pin types are known the mapping is exact; otherwise we fall + // back to the convention audio → "out", video → "out_1". + let audio_output_pin: &str = match (pin_0_kind, pin_1_kind) { + (Some(MediaKind::Audio), _) => "out", + (_, Some(MediaKind::Audio)) | (Some(MediaKind::Video), _) => "out_1", + _ => "out", + }; + let video_output_pin: &str = match (pin_0_kind, pin_1_kind) { + (Some(MediaKind::Video), _) => "out", + _ => "out_1", + }; + + tracing::info!( + has_audio, + has_video, + ?pin_0_kind, + ?pin_1_kind, + audio_output_pin, + video_output_pin, + "MoQ peer input pins (types inferred at runtime)" + ); // Create broadcast channel for fanning out to subscribers let (subscriber_broadcast_tx, _) = @@ -304,14 +531,23 @@ impl ProcessorNode for MoqPeerNode { output_sender: context.output_sender.clone(), broadcast_rx, shutdown_rx: shutdown_tx.subscribe(), - publisher_slot: publisher_slot.clone(), - publisher_events: publisher_events_tx.clone(), - subscriber_count: sub_count, - output_group_duration_ms: self.config.output_group_duration_ms, - output_initial_delay_ms: self.config.output_initial_delay_ms, - stats_delta_tx: stats_delta_tx.clone(), - }, - ).await { + publisher_slot: publisher_slot.clone(), + publisher_events: publisher_events_tx.clone(), + subscriber_count: sub_count, + stats_delta_tx: stats_delta_tx.clone(), + media: SubscriberMediaConfig { + has_video, + has_audio, + video_width: self.config.video_width, + video_height: self.config.video_height, + output_group_duration_ms: self.config.output_group_duration_ms, + output_initial_delay_ms: 
self.config.output_initial_delay_ms, + }, + media_state_rx: media_state_rx.clone(), + audio_output_pin, + video_output_pin, + }, + ).await { Ok(_handle) => { let count = subscriber_count.fetch_add(1, Ordering::SeqCst) + 1; tracing::info!("Peer connected (total: {})", count); @@ -363,6 +599,8 @@ impl ProcessorNode for MoqPeerNode { shutdown_tx.subscribe(), publisher_events_tx.clone(), stats_delta_tx.clone(), + audio_output_pin, + video_output_pin, ).await { Ok(_handle) => { tracing::info!("Publisher connected and streaming"); @@ -406,9 +644,16 @@ impl ProcessorNode for MoqPeerNode { broadcast_rx, shutdown_tx.subscribe(), sub_count, - self.config.output_group_duration_ms, - self.config.output_initial_delay_ms, stats_delta_tx.clone(), + SubscriberMediaConfig { + has_video, + has_audio, + video_width: self.config.video_width, + video_height: self.config.video_height, + output_group_duration_ms: self.config.output_group_duration_ms, + output_initial_delay_ms: self.config.output_initial_delay_ms, + }, + media_state_rx.clone(), ).await { Ok(_handle) => { let count = subscriber_count.fetch_add(1, Ordering::SeqCst) + 1; @@ -420,20 +665,91 @@ impl ProcessorNode for MoqPeerNode { } } - // Forward packets from pipeline to broadcast channel - packet = pipeline_input_rx.recv() => { - if let Some(packet) = packet { - if let Packet::Binary { data, metadata, .. } = packet { + // Forward packets from pin "in" to broadcast channel + result = async { + if let Some(ref mut rx) = pin_0_rx { rx.recv().await } else { std::future::pending().await } + } => { + if let Some(packet) = result { + // Lazily determine kind on first packet when input_types + // is unavailable (dynamic pipelines). 
+ let kind = pin_0_kind.unwrap_or_else(|| { + let k = infer_kind_from_packet(&packet); + pin_0_kind = Some(k); + match k { + MediaKind::Audio => has_audio = true, + MediaKind::Video => has_video = true, + } + // In dynamic mode, resolve immediately on first + // packet — the subscriber applies a grace period. + let resolved = if dynamic_mode { + true + } else { + let pin_1_connected = pin_1_rx.is_some(); + pin_1_kind.is_some() || !pin_1_connected + }; + let _ = media_state_tx.send(MediaTypeState { + has_audio, + has_video, + resolved, + }); + tracing::info!(?k, resolved, "pin \"in\": media kind inferred from first packet"); + k + }); + if let Some(frame) = make_broadcast_frame(packet, kind) { stats_tracker.received(); - // Broadcast to all subscribers (ignore if no receivers) - let duration_us = super::constants::packet_duration_us(metadata.as_ref()); - let _ = subscriber_broadcast_tx.send(BroadcastFrame { data, duration_us }); + let _ = subscriber_broadcast_tx.send(frame); stats_tracker.sent(); stats_tracker.maybe_send(); } } else { - tracing::info!("Pipeline input closed"); - break Ok(()); + tracing::info!("Pipeline input pin \"in\" closed"); + pin_0_rx = None; + if pin_1_rx.is_none() { + tracing::info!("All pipeline inputs closed, shutting down"); + break Ok(()); + } + } + } + + // Forward packets from pin "in_1" to broadcast channel + result = async { + if let Some(ref mut rx) = pin_1_rx { rx.recv().await } else { std::future::pending().await } + } => { + if let Some(packet) = result { + let kind = pin_1_kind.unwrap_or_else(|| { + let k = infer_kind_from_packet(&packet); + pin_1_kind = Some(k); + match k { + MediaKind::Audio => has_audio = true, + MediaKind::Video => has_video = true, + } + let resolved = if dynamic_mode { + true + } else { + let pin_0_connected = pin_0_rx.is_some(); + pin_0_kind.is_some() || !pin_0_connected + }; + let _ = media_state_tx.send(MediaTypeState { + has_audio, + has_video, + resolved, + }); + tracing::info!(?k, resolved, "pin 
\"in_1\": media kind inferred from first packet"); + k + }); + if let Some(frame) = make_broadcast_frame(packet, kind) { + stats_tracker.received(); + let _ = subscriber_broadcast_tx.send(frame); + stats_tracker.sent(); + stats_tracker.maybe_send(); + } + } else { + tracing::info!("Pipeline input pin \"in_1\" closed"); + pin_1_rx = None; + if pin_0_rx.is_none() { + tracing::info!("All pipeline inputs closed, shutting down"); + break Ok(()); + } } } @@ -520,7 +836,9 @@ impl ProcessorNode for MoqPeerNode { } impl MoqPeerNode { - /// Start a task to handle publisher connection (receives audio from client) + /// Start a task to handle publisher connection (receives media from client) + // Pin-specific output routing requires per-pin parameters; bundling into a config struct is a future cleanup. + #[allow(clippy::too_many_arguments)] async fn start_publisher_task_with_permit( moq_connection: streamkit_core::moq_gateway::MoqConnection, permit: OwnedSemaphorePermit, @@ -529,6 +847,8 @@ impl MoqPeerNode { mut shutdown_rx: broadcast::Receiver<()>, publisher_events: mpsc::UnboundedSender, stats_delta_tx: mpsc::Sender, + audio_output_pin: &'static str, + video_output_pin: &'static str, ) -> Result>, StreamKitError> { let path = moq_connection.path.clone(); @@ -550,7 +870,7 @@ impl MoqPeerNode { // Accept MoQ session (publisher only sends, no server publish needed) let session = request .with_consume(client_publish_origin) - .accept() + .ok() .await .map_err(|e| StreamKitError::Runtime(format!("Failed to accept session: {e}")))?; @@ -564,6 +884,8 @@ impl MoqPeerNode { output_sender, &mut shutdown_rx, stats_delta_tx, + audio_output_pin, + video_output_pin, ) .await; @@ -607,7 +929,7 @@ impl MoqPeerNode { let session = request .with_publish(server_publish_origin.consume()) .with_consume(client_publish_origin) - .accept() + .ok() .await .map_err(|e| StreamKitError::Runtime(format!("Failed to accept session: {e}")))?; @@ -629,6 +951,8 @@ impl MoqPeerNode { publisher_events: 
config.publisher_events, publisher_path: path.clone(), stats_delta_tx: publisher_stats_delta_tx, + audio_output_pin: config.audio_output_pin, + video_output_pin: config.video_output_pin, }, &mut publisher_shutdown_rx, ) @@ -642,9 +966,9 @@ impl MoqPeerNode { config.node_id.clone(), config.broadcast_rx, &mut subscriber_shutdown_rx, - config.output_group_duration_ms, - config.output_initial_delay_ms, subscriber_stats_delta_tx, + config.media, + config.media_state_rx, ) .await }; @@ -658,7 +982,7 @@ impl MoqPeerNode { tracing::warn!(path = %path, error = %e, "Peer subscriber task error"); } - let count = config.subscriber_count.fetch_sub(1, Ordering::SeqCst) - 1; + let count = config.subscriber_count.fetch_sub(1, Ordering::SeqCst).saturating_sub(1); tracing::info!(path = %path, "Peer disconnected (remaining: {})", count); drop(session); @@ -699,32 +1023,14 @@ impl MoqPeerNode { .publisher_events .send(PublisherEvent::Connected { path: config.publisher_path.clone() }); - let result = async { - let Some((audio_track_name, audio_priority)) = - Self::wait_for_catalog_with_audio(&broadcast_consumer, shutdown_rx).await? 
- else { - return Ok(()); - }; - - tracing::info!( - path = %config.publisher_path, - "Subscribing to peer publisher audio track: {}", - audio_track_name - ); - - let track_consumer = broadcast_consumer.subscribe_track(&moq_lite::Track { - name: audio_track_name, - priority: audio_priority, - }); - - Self::process_publisher_frames( - track_consumer, - config.output_sender, - shutdown_rx, - &config.stats_delta_tx, - ) - .await - } + let result = Self::watch_catalog_and_process( + &broadcast_consumer, + config.output_sender, + shutdown_rx, + &config.stats_delta_tx, + config.audio_output_pin, + config.video_output_pin, + ) .await; drop(permit); @@ -736,13 +1042,15 @@ impl MoqPeerNode { result } - /// Publisher receive loop - receives audio from client and sends to pipeline + /// Publisher receive loop - receives audio/video from client and sends to pipeline async fn publisher_receive_loop( subscribe: moq_lite::OriginConsumer, broadcast_name: String, output_sender: streamkit_core::OutputSender, shutdown_rx: &mut broadcast::Receiver<()>, stats_delta_tx: mpsc::Sender, + audio_output_pin: &'static str, + video_output_pin: &'static str, ) -> Result<(), StreamKitError> { tracing::info!("Waiting for publisher to announce broadcast: {}", broadcast_name); @@ -753,21 +1061,17 @@ impl MoqPeerNode { return Ok(()); // Shutdown requested }; - // Wait for catalog with audio track info - let Some((audio_track_name, audio_priority)) = - Self::wait_for_catalog_with_audio(&broadcast_consumer, shutdown_rx).await? 
- else { - return Ok(()); // Shutdown requested - }; - - tracing::info!("Subscribing to publisher audio track: {}", audio_track_name); - - let track_consumer = broadcast_consumer - .subscribe_track(&moq_lite::Track { name: audio_track_name, priority: audio_priority }); - - // Process incoming frames - Self::process_publisher_frames(track_consumer, output_sender, shutdown_rx, &stats_delta_tx) - .await + // Watch catalog and process tracks as they appear (handles incremental + // permission grants where mic/camera become available at different times) + Self::watch_catalog_and_process( + &broadcast_consumer, + output_sender, + shutdown_rx, + &stats_delta_tx, + audio_output_pin, + video_output_pin, + ) + .await } /// Wait for the publisher to announce the expected broadcast @@ -802,48 +1106,321 @@ impl MoqPeerNode { } } - /// Wait for the catalog to contain audio track information - async fn wait_for_catalog_with_audio( + /// Unwrap a catalog read result, returning `Some(catalog)` on success, + /// or `None` when the caller should break (closed / timeout / error). + fn unwrap_catalog_result( + result: Result, E>, tokio::time::error::Elapsed>, + ) -> Option { + match result { + Ok(Ok(Some(catalog))) => Some(catalog), + Ok(Ok(None)) => { + tracing::info!("Catalog track closed"); + None + }, + Ok(Err(e)) => { + tracing::warn!("Error reading catalog: {}", e); + None + }, + Err(_) => { + tracing::info!("Catalog timeout — proceeding with discovered tracks"); + None + }, + } + } + + /// Watch the catalog continuously and process publisher tracks as they appear. + /// + /// Instead of waiting for all tracks upfront, this subscribes to and starts + /// processing each track as soon as it appears in the catalog. This handles + /// the common case where the browser grants mic and camera permissions at + /// different times, causing the hang library to publish incremental catalog + /// updates (e.g., audio-only first, then audio+video). 
+ async fn watch_catalog_and_process( broadcast_consumer: &moq_lite::BroadcastConsumer, + output_sender: streamkit_core::OutputSender, shutdown_rx: &mut broadcast::Receiver<()>, - ) -> Result, StreamKitError> { + stats_delta_tx: &mpsc::Sender, + audio_output_pin: &'static str, + video_output_pin: &'static str, + ) -> Result<(), StreamKitError> { let catalog_track = - broadcast_consumer.subscribe_track(&hang::catalog::Catalog::default_track()); + broadcast_consumer.subscribe_track(&hang::catalog::Catalog::default_track()).map_err( + |e| StreamKitError::Runtime(format!("Failed to subscribe to catalog track: {e}")), + )?; let mut catalog_consumer = hang::catalog::CatalogConsumer::new(catalog_track); + let mut audio_handle: Option>> = None; + let mut video_handle: Option>> = None; + + // Monitor the catalog for new tracks, subscribing to each as it appears loop { tokio::select! { - catalog_result = tokio::time::timeout(Duration::from_secs(10), catalog_consumer.next()) => { - let catalog = catalog_result - .map_err(|_| StreamKitError::Runtime("Timeout waiting for catalog".to_string()))? - .map_err(|e| StreamKitError::Runtime(format!("Failed to read catalog: {e}")))? 
- .ok_or_else(|| StreamKitError::Runtime("Catalog track closed".to_string()))?; + biased; + _ = shutdown_rx.recv() => { + tracing::info!("Catalog watch shutting down"); + break; + } + catalog_result = tokio::time::timeout(Duration::from_secs(30), catalog_consumer.next()) => { + let Some(catalog) = Self::unwrap_catalog_result(catalog_result) else { + break; + }; - tracing::info!("Received catalog from publisher: audio={:?}", catalog.audio); + tracing::info!( + "Received catalog from publisher: audio={:?}, video renditions={}", + catalog.audio, catalog.video.renditions.len() + ); - { - let audio = &catalog.audio; - if let Some(track_name) = audio.renditions.keys().next() { + // Start audio processing if a new audio track appeared + if audio_handle.is_none() { + if let Some(track_name) = catalog.audio.renditions.keys().next() { tracing::info!("Found audio track in catalog: {}", track_name); - return Ok(Some((track_name.clone(), 2))); + audio_handle = Some(Self::spawn_track_processor( + broadcast_consumer, track_name, audio_output_pin, + &output_sender, shutdown_rx, stats_delta_tx, + )); + } + } + + // Start video processing if a new video track appeared + if video_handle.is_none() { + if let Some(track_name) = catalog.video.renditions.keys().next() { + tracing::info!("Found video track in catalog: {}", track_name); + video_handle = Some(Self::spawn_track_processor( + broadcast_consumer, track_name, video_output_pin, + &output_sender, shutdown_rx, stats_delta_tx, + )); } } - tracing::debug!("Catalog has no audio yet, waiting for update..."); + + // Stop watching catalog once both tracks are subscribed + if audio_handle.is_some() && video_handle.is_some() { + tracing::info!("All tracks discovered, stopping catalog watch"); + break; + } } - _ = shutdown_rx.recv() => { - return Ok(None); + } + } + + // Wait for all active processing tasks to finish + Self::await_track_tasks(audio_handle, video_handle).await + } + + /// Spawn a task that processes frames from a single 
publisher track. + /// + /// The subscription is created *inside* the spawned task and wrapped in a + /// bounded retry loop. If the publisher transiently cancels the track + /// (see [`TrackExit::Cancelled`]), we back off briefly and re-subscribe + /// via `BroadcastConsumer::subscribe_track`, which creates a fresh + /// `TrackProducer`/`TrackConsumer` pair (the old producer is evicted from + /// moq-lite's dedup map once unused). This makes the pipeline resilient + /// to brief client-side track flaps regardless of `@moq/hang` version. + fn spawn_track_processor( + broadcast_consumer: &moq_lite::BroadcastConsumer, + track_name: &str, + output_pin: &'static str, + output_sender: &streamkit_core::OutputSender, + shutdown_rx: &broadcast::Receiver<()>, + stats_delta_tx: &mpsc::Sender, + ) -> tokio::task::JoinHandle> { + // How many times to re-subscribe after a publisher-side cancellation + // before giving up. The browser's camera-source flap during the + // permission-grant → device-enumeration cascade can span ~300ms+, + // so we use exponential backoff (100, 200, 400, 800…) to cover it. + const MAX_RESUBSCRIBE_ATTEMPTS: u32 = 10; + const RESUBSCRIBE_INITIAL_BACKOFF: Duration = Duration::from_millis(100); + + // BroadcastConsumer is Clone (Arc-like state + watch::Receiver + + // async_channel::Sender). Cloning into the task lets us re-subscribe + // after a cancellation without holding a reference across the spawn. 
+ let broadcast = broadcast_consumer.clone(); + let track = moq_lite::Track { name: track_name.to_string(), priority: 2 }; + let sender = output_sender.clone(); + let mut task_shutdown = shutdown_rx.resubscribe(); + let stats = stats_delta_tx.clone(); + let pin_name = output_pin; + + tokio::spawn(async move { + tracing::info!(output_pin = pin_name, track = %track.name, "Track processor task started"); + + let mut attempt: u32 = 0; + loop { + let consumer = broadcast.subscribe_track(&track).map_err(|e| { + StreamKitError::Runtime(format!( + "Failed to subscribe to track '{}': {e}", + track.name + )) + })?; + let exit = Self::process_publisher_frames( + consumer, + sender.clone(), + pin_name, + &mut task_shutdown, + &stats, + ) + .await; + + match exit { + TrackExit::Finished => { + tracing::info!( + output_pin = pin_name, + "Track processor task finished normally" + ); + return Ok(()); + }, + TrackExit::Error(e) => { + tracing::warn!( + output_pin = pin_name, + error = %e, + "Track processor task finished with error" + ); + return Err(e); + }, + TrackExit::Cancelled => { + attempt += 1; + if attempt > MAX_RESUBSCRIBE_ATTEMPTS { + tracing::warn!( + output_pin = pin_name, + attempts = attempt, + "Publisher track cancelled; retry budget exhausted" + ); + return Err(StreamKitError::Runtime(format!( + "publisher track '{}' cancelled {} times; giving up", + track.name, attempt + ))); + } + let backoff = + RESUBSCRIBE_INITIAL_BACKOFF * 2u32.saturating_pow(attempt - 1); + tracing::info!( + output_pin = pin_name, + attempt, + max = MAX_RESUBSCRIBE_ATTEMPTS, + backoff_ms = backoff.as_millis(), + "Publisher track cancelled; re-subscribing after backoff" + ); + // Yield to shutdown during backoff so we don't delay + // teardown if the pipeline is stopping. + tokio::select! 
{ + biased; + _ = task_shutdown.recv() => { + tracing::info!(output_pin = pin_name, "Shutdown during re-subscribe backoff"); + return Ok(()); + } + () = tokio::time::sleep(backoff) => {} + } + }, } } + }) + } + + /// Wait for spawned track processing tasks to complete. + /// + /// Uses `tokio::select!` so that if either task exits early (e.g. the video + /// track is closed by the publisher), the remaining task can continue + /// running independently while the overall function waits for it. + #[allow(clippy::cognitive_complexity)] + async fn await_track_tasks( + audio_handle: Option>>, + video_handle: Option>>, + ) -> Result<(), StreamKitError> { + if audio_handle.is_none() && video_handle.is_none() { + tracing::warn!("Publisher catalog had no audio or video tracks"); + return Ok(()); } + + let mut audio_done = audio_handle.is_none(); + let mut video_done = video_handle.is_none(); + + // Wrap the handles so we can select! over them concurrently + let audio_fut = async { + match audio_handle { + Some(h) => Some(h.await), + None => { + // No audio task — pend forever so only video is selected + std::future::pending::< + Option, tokio::task::JoinError>>, + >() + .await + }, + } + }; + let video_fut = async { + match video_handle { + Some(h) => Some(h.await), + None => { + std::future::pending::< + Option, tokio::task::JoinError>>, + >() + .await + }, + } + }; + + tokio::pin!(audio_fut); + tokio::pin!(video_fut); + + let mut first_error: Option = None; + + while !audio_done || !video_done { + tokio::select! 
{ + result = &mut audio_fut, if !audio_done => { + audio_done = true; + if let Some(join_result) = result { + match join_result { + Err(e) => { + tracing::warn!("Audio task panicked: {e}"); + if first_error.is_none() { + first_error = Some(StreamKitError::Runtime(format!("Audio task panicked: {e}"))); + } + } + Ok(Err(e)) => { + tracing::warn!("Audio task error: {e}"); + if first_error.is_none() { + first_error = Some(e); + } + } + Ok(Ok(())) => tracing::info!("Audio processing task completed"), + } + } + } + result = &mut video_fut, if !video_done => { + video_done = true; + if let Some(join_result) = result { + match join_result { + Err(e) => { + tracing::warn!("Video task panicked: {e}"); + if first_error.is_none() { + first_error = Some(StreamKitError::Runtime(format!("Video task panicked: {e}"))); + } + } + Ok(Err(e)) => { + tracing::warn!("Video task error: {e}"); + if first_error.is_none() { + first_error = Some(e); + } + } + Ok(Ok(())) => tracing::info!("Video processing task completed"), + } + } + } + } + } + + first_error.map_or(Ok(()), Err) } - /// Process incoming frames from the publisher and forward to the pipeline + /// Process incoming frames from the publisher and forward to the pipeline. + /// + /// Returns [`TrackExit`] so the caller can distinguish a transient + /// publisher-side cancellation (retryable via re-subscribe) from clean + /// completion and fatal errors. async fn process_publisher_frames( mut track_consumer: moq_lite::TrackConsumer, mut output_sender: streamkit_core::OutputSender, + output_pin: &str, shutdown_rx: &mut broadcast::Receiver<()>, stats_delta_tx: &mpsc::Sender, - ) -> Result<(), StreamKitError> { + ) -> TrackExit { let mut frame_count = 0u64; let mut last_log = std::time::Instant::now(); let mut current_group: Option = None; @@ -851,9 +1428,15 @@ impl MoqPeerNode { loop { // Get a group if we don't have one if current_group.is_none() { - match Self::get_next_group(&mut track_consumer, shutdown_rx).await? 
{ - Some(group) => current_group = Some(group), - None => return Ok(()), // Stream ended or shutdown + match Self::get_next_group(&mut track_consumer, shutdown_rx, output_pin).await { + Ok(Some(group)) => current_group = Some(group), + Ok(None) => return TrackExit::Finished, // Stream ended or shutdown + Err(moq_lite::Error::Cancel) => return TrackExit::Cancelled, + Err(e) => { + return TrackExit::Error(StreamKitError::Runtime(format!( + "Error getting group: {e}" + ))); + }, } } @@ -862,40 +1445,55 @@ impl MoqPeerNode { match Self::process_frame_from_group( group, &mut output_sender, + output_pin, &mut frame_count, &mut last_log, shutdown_rx, stats_delta_tx, ) - .await? + .await { - FrameResult::Continue => {}, - FrameResult::GroupExhausted => current_group = None, - FrameResult::Shutdown => return Ok(()), + Ok(FrameResult::Continue) => {}, + Ok(FrameResult::GroupExhausted) => current_group = None, + Ok(FrameResult::Shutdown) => return TrackExit::Finished, + Err(e) => return TrackExit::Error(e), } } } } - /// Get the next group from the track consumer + /// Get the next group from the track consumer. + /// + /// Surfaces the raw [`moq_lite::Error`] so the caller can distinguish + /// [`moq_lite::Error::Cancel`] (publisher dropped the track producer — + /// retryable) from other failures. The `tracing::warn!` here will still + /// fire for cancellations; that's intentional since they're unexpected + /// in steady state even if we recover. async fn get_next_group( track_consumer: &mut moq_lite::TrackConsumer, shutdown_rx: &mut broadcast::Receiver<()>, - ) -> Result, StreamKitError> { + output_pin: &str, + ) -> Result, moq_lite::Error> { tokio::select! 
{ biased; group_result = track_consumer.next_group() => { match group_result { - Ok(Some(group)) => Ok(Some(group)), + Ok(Some(group)) => { + tracing::debug!(output_pin, "Got next group from publisher"); + Ok(Some(group)) + } Ok(None) => { - tracing::info!("Publisher stream ended"); + tracing::info!(output_pin, "Publisher stream ended (next_group returned None)"); Ok(None) } - Err(e) => Err(StreamKitError::Runtime(format!("Error getting group: {e}"))), + Err(e) => { + tracing::warn!(output_pin, error = %e, "Error getting group from publisher"); + Err(e) + } } } _ = shutdown_rx.recv() => { - tracing::info!("Publisher receive loop shutting down"); + tracing::info!(output_pin, "Publisher receive loop shutting down (shutdown signal)"); Ok(None) } } @@ -905,6 +1503,7 @@ impl MoqPeerNode { async fn process_frame_from_group( group: &mut moq_lite::GroupConsumer, output_sender: &mut streamkit_core::OutputSender, + output_pin: &str, frame_count: &mut u64, last_log: &mut std::time::Instant, shutdown_rx: &mut broadcast::Receiver<()>, @@ -923,9 +1522,9 @@ impl MoqPeerNode { *last_log = std::time::Instant::now(); } - // Skip timestamp header (varint encoded u64 microseconds) + // Skip timestamp header (varint encoded timestamp in microseconds) // The hang protocol encodes timestamp at the start of each frame - if let Err(e) = u64::decode(&mut payload, moq_lite::lite::Version::Draft02) { + if let Err(e) = hang::container::Timestamp::decode(&mut payload) { tracing::warn!("Failed to decode frame timestamp: {e}"); let _ = stats_delta_tx .try_send(NodeStatsDelta { received: 1, discarded: 1, ..Default::default() }); @@ -939,8 +1538,8 @@ impl MoqPeerNode { metadata: None, }; - if output_sender.send("out", packet).await.is_err() { - tracing::debug!("Output channel closed"); + if output_sender.send(output_pin, packet).await.is_err() { + tracing::info!(output_pin, "Output channel closed for pin"); let _ = stats_delta_tx .try_send(NodeStatsDelta { received: 1, ..Default::default() }); 
return Ok(FrameResult::Shutdown); @@ -950,14 +1549,14 @@ impl MoqPeerNode { } Ok(None) => Ok(FrameResult::GroupExhausted), Err(e) => { - tracing::warn!("Error reading frame: {e}"); + tracing::warn!(output_pin, "Error reading frame: {e}"); let _ = stats_delta_tx.try_send(NodeStatsDelta { errored: 1, ..Default::default() }); Ok(FrameResult::GroupExhausted) } } } _ = shutdown_rx.recv() => { - tracing::info!("Publisher receive loop shutting down"); + tracing::info!(output_pin, "Publisher receive loop shutting down (frame read)"); Ok(FrameResult::Shutdown) } } @@ -972,9 +1571,9 @@ impl MoqPeerNode { broadcast_rx: broadcast::Receiver, mut shutdown_rx: broadcast::Receiver<()>, subscriber_count: Arc, - output_group_duration_ms: u64, - output_initial_delay_ms: u64, stats_delta_tx: mpsc::Sender, + media: SubscriberMediaConfig, + media_state_rx: watch::Receiver, ) -> Result, StreamKitError> { // Extract the moq-native Request let request = *moq_connection @@ -994,7 +1593,7 @@ impl MoqPeerNode { // Accept MoQ session (subscriber only receives, no client publish needed) let session = request .with_publish(server_publish_origin.consume()) - .accept() + .ok() .await .map_err(|e| StreamKitError::Runtime(format!("Failed to accept session: {e}")))?; @@ -1005,14 +1604,14 @@ impl MoqPeerNode { node_id, broadcast_rx, &mut shutdown_rx, - output_group_duration_ms, - output_initial_delay_ms, stats_delta_tx, + media, + media_state_rx, ) .await; // Decrement subscriber count - let count = subscriber_count.fetch_sub(1, Ordering::SeqCst) - 1; + let count = subscriber_count.fetch_sub(1, Ordering::SeqCst).saturating_sub(1); tracing::info!("Subscriber disconnected (remaining: {})", count); // Keep session alive until task ends @@ -1027,6 +1626,7 @@ impl MoqPeerNode { } /// Subscriber send loop - receives from broadcast channel and sends to client + // media_state_rx adds a necessary parameter for dynamic media-type resolution. 
#[allow(clippy::too_many_arguments)] async fn subscriber_send_loop( publish: moq_lite::OriginProducer, @@ -1034,40 +1634,121 @@ impl MoqPeerNode { node_id: String, broadcast_rx: broadcast::Receiver, shutdown_rx: &mut broadcast::Receiver<()>, - output_group_duration_ms: u64, - output_initial_delay_ms: u64, stats_delta_tx: mpsc::Sender, + mut media: SubscriberMediaConfig, + mut media_state_rx: watch::Receiver, ) -> Result<(), StreamKitError> { + // Wait for media types to be resolved before building the catalog. + // For static pipelines `resolved` is true immediately. For dynamic + // pipelines we wait until the first packet on any connected input pin + // has been processed so we know at least one media type. + if !media_state_rx.borrow().resolved { + tracing::info!("Waiting for input pin media types to be resolved..."); + let deadline = tokio::time::Instant::now() + Duration::from_secs(5); + loop { + tokio::select! { + result = media_state_rx.changed() => { + if result.is_err() { break; } + if media_state_rx.borrow().resolved { break; } + } + () = tokio::time::sleep_until(deadline) => { + tracing::warn!("Timed out waiting for media type resolution"); + break; + } + _recv = shutdown_rx.recv() => { + tracing::info!("Shutdown while waiting for media type resolution"); + return Ok(()); + } + } + } + } + + // After resolution, if only partial media types are known, wait a + // brief grace period for additional types. In dynamic pipelines a + // second input pin may receive its first packet shortly after the + // first pin resolved the state. + let needs_grace = { + let state = media_state_rx.borrow(); + state.resolved && !(state.has_audio && state.has_video) + }; + if needs_grace { + let grace = tokio::time::Instant::now() + Duration::from_millis(500); + loop { + tokio::select! 
{ + result = media_state_rx.changed() => { + if result.is_err() { break; } + let both = { + let s = media_state_rx.borrow(); + s.has_audio && s.has_video + }; + if both { break; } + } + () = tokio::time::sleep_until(grace) => { break; } + _recv = shutdown_rx.recv() => { + tracing::info!("Shutdown during media type grace period"); + return Ok(()); + } + } + } + } + + // Apply the resolved media state. + { + let state = media_state_rx.borrow(); + media.has_audio = state.has_audio; + media.has_video = state.has_video; + } + // Setup broadcast and tracks - let (_broadcast_producer, mut track_producer, _catalog_producer) = - Self::setup_subscriber_broadcast(&publish, &broadcast_name)?; + let ( + _broadcast_producer, + mut audio_track_producer, + mut video_track_producer, + _catalog_producer, + ) = Self::setup_subscriber_broadcast(&publish, &broadcast_name, &media)?; - tracing::info!("Published catalog to subscriber"); + tracing::info!( + has_audio = media.has_audio, + has_video = media.has_video, + "Published catalog to subscriber" + ); // Run the send loop let packet_count = Self::run_subscriber_send_loop( - &mut track_producer, + &mut audio_track_producer, + &mut video_track_producer, broadcast_rx, shutdown_rx, - output_group_duration_ms, - output_initial_delay_ms, + media.output_group_duration_ms, + media.output_initial_delay_ms, node_id, broadcast_name, &stats_delta_tx, ) .await?; - track_producer.track.clone().close(); + if let Some(ref mut p) = audio_track_producer { + let _ = p.track.finish(); + } + if let Some(ref mut p) = video_track_producer { + let _ = p.track.finish(); + } tracing::info!("Subscriber task finished after {} packets", packet_count); Ok(()) } - /// Setup broadcast, audio track, and catalog for subscriber + /// Setup broadcast, media tracks, and catalog for subscriber fn setup_subscriber_broadcast( publish: &moq_lite::OriginProducer, broadcast_name: &str, + media: &SubscriberMediaConfig, ) -> Result< - (moq_lite::BroadcastProducer, 
hang::container::OrderedProducer, moq_lite::TrackProducer), + ( + moq_lite::BroadcastProducer, + Option, + Option, + moq_lite::TrackProducer, + ), StreamKitError, > { // Create broadcast @@ -1075,46 +1756,114 @@ impl MoqPeerNode { StreamKitError::Runtime(format!("Failed to create broadcast '{broadcast_name}'")) })?; - // Create audio track - let audio_track = moq_lite::Track { name: "audio/data".to_string(), priority: 80 }; - let track_producer = broadcast_producer.create_track(audio_track.clone()); - let track_producer: hang::container::OrderedProducer = track_producer.into(); + // Create audio track (if audio input connected) + let audio_track = if media.has_audio { + let track = moq_lite::Track { name: "audio/data".to_string(), priority: 80 }; + let producer = broadcast_producer.create_track(track.clone()).map_err(|e| { + StreamKitError::Runtime(format!("Failed to create audio track: {e}")) + })?; + Some((track, hang::container::OrderedProducer::from(producer))) + } else { + None + }; + + // Create video track (if video input connected) + let video_track = if media.has_video { + let track = moq_lite::Track { name: "video/data".to_string(), priority: 60 }; + let producer = broadcast_producer.create_track(track.clone()).map_err(|e| { + StreamKitError::Runtime(format!("Failed to create video track: {e}")) + })?; + Some((track, hang::container::OrderedProducer::from(producer))) + } else { + None + }; // Create and publish catalog - let catalog_producer = - Self::create_and_publish_catalog(&mut broadcast_producer, &audio_track)?; - - Ok((broadcast_producer, track_producer, catalog_producer)) + let catalog_producer = Self::create_and_publish_catalog( + &mut broadcast_producer, + audio_track.as_ref().map(|(t, _)| t), + video_track.as_ref().map(|(t, _)| t), + media.video_width, + media.video_height, + )?; + + Ok(( + broadcast_producer, + audio_track.map(|(_, p)| p), + video_track.map(|(_, p)| p), + catalog_producer, + )) } - /// Create and publish the catalog with audio 
track info + /// Create and publish the catalog with audio and/or video track info fn create_and_publish_catalog( broadcast_producer: &mut moq_lite::BroadcastProducer, - audio_track: &moq_lite::Track, + audio_track: Option<&moq_lite::Track>, + video_track: Option<&moq_lite::Track>, + video_width: u32, + video_height: u32, ) -> Result { let mut audio_renditions = std::collections::BTreeMap::new(); - audio_renditions.insert( - audio_track.name.clone(), - hang::catalog::AudioConfig { - codec: hang::catalog::AudioCodec::Opus, - sample_rate: 48000, - channel_count: 1, - bitrate: Some(64_000), - description: None, - container: hang::catalog::Container::default(), - jitter: None, - }, - ); + if let Some(audio_track) = audio_track { + audio_renditions.insert( + audio_track.name.clone(), + hang::catalog::AudioConfig { + codec: hang::catalog::AudioCodec::Opus, + sample_rate: 48000, + channel_count: 1, + bitrate: Some(64_000), + description: None, + container: hang::catalog::Container::default(), + jitter: None, + }, + ); + } + + let mut video_renditions = std::collections::BTreeMap::new(); + if let Some(video_track) = video_track { + video_renditions.insert( + video_track.name.clone(), + hang::catalog::VideoConfig { + codec: hang::catalog::VideoCodec::VP9(hang::catalog::VP9 { + profile: 0, + level: 10, + bit_depth: 8, + ..hang::catalog::VP9::default() + }), + coded_width: Some(video_width), + coded_height: Some(video_height), + display_ratio_width: None, + display_ratio_height: None, + framerate: Some(30.0), + bitrate: None, + description: None, + optimize_for_latency: Some(true), + container: hang::catalog::Container::default(), + jitter: None, + }, + ); + } let catalog = hang::catalog::Catalog { audio: hang::catalog::Audio { renditions: audio_renditions }, + video: hang::catalog::Video { + renditions: video_renditions, + display: None, + rotation: None, + flip: None, + }, ..Default::default() }; - let mut catalog_producer = - 
broadcast_producer.create_track(hang::catalog::Catalog::default_track()); - let catalog_json = super::catalog_to_json(&catalog)?; - catalog_producer.write_frame(catalog_json.into_bytes()); + let mut catalog_producer = broadcast_producer + .create_track(hang::catalog::Catalog::default_track()) + .map_err(|e| StreamKitError::Runtime(format!("Failed to create catalog track: {e}")))?; + let catalog_json = catalog + .to_string() + .map_err(|e| StreamKitError::Runtime(format!("Failed to serialize catalog: {e}")))?; + catalog_producer + .write_frame(catalog_json.into_bytes()) + .map_err(|e| StreamKitError::Runtime(format!("Failed to write catalog frame: {e}")))?; Ok(catalog_producer) } @@ -1122,7 +1871,8 @@ impl MoqPeerNode { /// Run the main send loop, forwarding packets to the subscriber #[allow(clippy::too_many_arguments)] async fn run_subscriber_send_loop( - track_producer: &mut hang::container::OrderedProducer, + audio_track_producer: &mut Option, + video_track_producer: &mut Option, mut broadcast_rx: broadcast::Receiver, shutdown_rx: &mut broadcast::Receiver<()>, output_group_duration_ms: u64, @@ -1135,7 +1885,8 @@ impl MoqPeerNode { let mut last_log = std::time::Instant::now(); let mut frame_count = 0u64; let group_duration_ms = output_group_duration_ms.max(1); - let mut clock = MediaClock::new(output_initial_delay_ms); + let mut audio_clock = MediaClock::new(output_initial_delay_ms); + let mut video_clock = MediaClock::new(output_initial_delay_ms); let meter = opentelemetry::global::meter("skit_nodes"); let gap_histogram = meter .f64_histogram("moq.peer.inter_frame_ms") @@ -1146,22 +1897,26 @@ impl MoqPeerNode { opentelemetry::KeyValue::new("node_id", node_id), opentelemetry::KeyValue::new("broadcast", broadcast_name), ]; - let mut last_ts_ms: Option = None; + let mut last_audio_ts_ms: Option = None; + let mut last_video_ts_ms: Option = None; loop { tokio::select! 
{ recv_result = broadcast_rx.recv() => { match Self::handle_broadcast_recv( recv_result, - track_producer, + audio_track_producer, + video_track_producer, &mut packet_count, &mut frame_count, &mut last_log, group_duration_ms, - &mut clock, + &mut audio_clock, + &mut video_clock, &gap_histogram, &metric_labels, - &mut last_ts_ms, + &mut last_audio_ts_ms, + &mut last_video_ts_ms, stats_delta_tx, )? { SendResult::Continue => {} @@ -1178,19 +1933,22 @@ impl MoqPeerNode { Ok(packet_count) } - /// Handle a single broadcast receive result + /// Handle a single broadcast receive result, routing to the correct track producer. #[allow(clippy::too_many_arguments, clippy::cast_precision_loss)] fn handle_broadcast_recv( recv_result: Result, - track_producer: &mut hang::container::OrderedProducer, + audio_track_producer: &mut Option, + video_track_producer: &mut Option, packet_count: &mut u64, frame_count: &mut u64, last_log: &mut std::time::Instant, group_duration_ms: u64, - clock: &mut MediaClock, + audio_clock: &mut MediaClock, + video_clock: &mut MediaClock, gap_histogram: &opentelemetry::metrics::Histogram, metric_labels: &[opentelemetry::KeyValue], - last_ts_ms: &mut Option, + last_audio_ts_ms: &mut Option, + last_video_ts_ms: &mut Option, stats_delta_tx: &mpsc::Sender, ) -> Result { match recv_result { @@ -1204,9 +1962,25 @@ impl MoqPeerNode { *last_log = std::time::Instant::now(); } - let is_first = *packet_count == 1; + // Select the appropriate clock, last_ts, and track producer based on media kind + let (clock, last_ts_ms, track_producer) = match broadcast_frame.kind { + MediaKind::Audio => (audio_clock, last_audio_ts_ms, audio_track_producer), + MediaKind::Video => (video_clock, last_video_ts_ms, video_track_producer), + }; + + let Some(track_producer) = track_producer else { + // No track producer for this media kind — skip frame + return Ok(SendResult::Continue); + }; + let timestamp_ms = clock.timestamp_ms(); - let keyframe = is_first || 
clock.is_group_boundary_ms(group_duration_ms); + // For audio, use time-based group boundaries; for video, use keyframe flag + let keyframe = match broadcast_frame.kind { + MediaKind::Audio => { + *packet_count == 1 || clock.is_group_boundary_ms(group_duration_ms) + }, + MediaKind::Video => broadcast_frame.keyframe, + }; if let Some(prev) = *last_ts_ms { let gap = timestamp_ms.saturating_sub(prev); @@ -1222,18 +1996,29 @@ impl MoqPeerNode { let mut payload = hang::container::BufList::new(); payload.push_chunk(broadcast_frame.data); - let frame = hang::container::Frame { timestamp, keyframe, payload }; + if keyframe { + if let Err(e) = track_producer.keyframe() { + tracing::warn!(kind = ?broadcast_frame.kind, "Failed to signal keyframe: {e}"); + let _ = stats_delta_tx + .try_send(NodeStatsDelta { errored: 1, ..Default::default() }); + return Ok(SendResult::Stop); + } + } + + let frame = hang::container::Frame { timestamp, payload }; if let Err(e) = track_producer.write(frame) { - tracing::warn!("Failed to write MoQ frame to subscriber: {e}"); + tracing::warn!(kind = ?broadcast_frame.kind, "Failed to write MoQ frame to subscriber: {e}"); let _ = stats_delta_tx .try_send(NodeStatsDelta { errored: 1, ..Default::default() }); return Ok(SendResult::Stop); } - clock.advance_by_duration_us( - broadcast_frame.duration_us, - super::constants::DEFAULT_AUDIO_FRAME_DURATION_US, - ); + + let default_duration = match broadcast_frame.kind { + MediaKind::Audio => super::constants::DEFAULT_AUDIO_FRAME_DURATION_US, + MediaKind::Video => crate::video::DEFAULT_VIDEO_FRAME_DURATION_US, + }; + clock.advance_by_duration_us(broadcast_frame.duration_us, default_duration); Ok(SendResult::Continue) }, Err(broadcast::error::RecvError::Lagged(n)) => { diff --git a/crates/nodes/src/transport/moq/pull.rs b/crates/nodes/src/transport/moq/pull.rs index d2f1b09f..90367834 100644 --- a/crates/nodes/src/transport/moq/pull.rs +++ b/crates/nodes/src/transport/moq/pull.rs @@ -7,13 +7,12 @@ use 
super::constants::DEFAULT_AUDIO_FRAME_DURATION_US; use async_trait::async_trait; use bytes::Buf; -use moq_lite::coding::Decode; use moq_lite::AsPath; use schemars::JsonSchema; use serde::Deserialize; use std::time::Duration; use streamkit_core::timing::MediaClock; -use streamkit_core::types::{Packet, PacketMetadata, PacketType}; +use streamkit_core::types::{AudioCodec, EncodedAudioFormat, Packet, PacketMetadata, PacketType}; use streamkit_core::{ state_helpers, stats::NodeStatsTracker, InputPin, NodeContext, OutputPin, PinCardinality, ProcessorNode, StreamKitError, @@ -36,7 +35,7 @@ pub struct MoqPullConfig { } /// A node that connects to a MoQ server, subscribes to a broadcast, -/// and outputs the received media as Opus packets. +/// and outputs the received media as encoded Opus packets. /// /// This node performs catalog discovery during initialization. /// @@ -58,7 +57,10 @@ impl MoqPullNode { // Start with a single stable output pin. output_pins: vec![OutputPin { name: "out".to_string(), - produces_type: PacketType::OpusAudio, + produces_type: PacketType::EncodedAudio(EncodedAudioFormat { + codec: AudioCodec::Opus, + codec_private: None, + }), cardinality: PinCardinality::Broadcast, }], } @@ -67,7 +69,10 @@ impl MoqPullNode { fn stable_out_pin() -> OutputPin { OutputPin { name: "out".to_string(), - produces_type: PacketType::OpusAudio, + produces_type: PacketType::EncodedAudio(EncodedAudioFormat { + codec: AudioCodec::Opus, + codec_private: None, + }), cardinality: PinCardinality::Broadcast, } } @@ -81,7 +86,10 @@ impl MoqPullNode { } pins.push(OutputPin { name: track.name.clone(), - produces_type: PacketType::OpusAudio, + produces_type: PacketType::EncodedAudio(EncodedAudioFormat { + codec: AudioCodec::Opus, + codec_private: None, + }), cardinality: PinCardinality::Broadcast, }); } @@ -244,10 +252,12 @@ impl MoqPullNode { fn strip_hang_timestamp_header( mut payload: bytes::Bytes, ) -> Result<(u64, bytes::Bytes), moq_lite::Error> { - // hang protocol: frame 
payload is prefixed with a varint u64 timestamp in microseconds. + // hang protocol: frame payload is prefixed with a varint timestamp in microseconds. // We parse it and forward the remaining bytes (Opus frame data). - let timestamp_micros = u64::decode(&mut payload, moq_lite::lite::Version::Draft02)?; - Ok((timestamp_micros, payload.copy_to_bytes(payload.remaining()))) + let timestamp = hang::container::Timestamp::decode(&mut payload)?; + #[allow(clippy::cast_possible_truncation)] // MoQ timestamps fit in u64 + let timestamp_us = timestamp.as_micros() as u64; + Ok((timestamp_us, payload.copy_to_bytes(payload.remaining()))) } async fn read_next_raw_moq( @@ -316,7 +326,10 @@ impl MoqPullNode { }; // Subscribe to the catalog track - let raw_catalog_track = broadcast.subscribe_track(&hang::catalog::Catalog::default_track()); + let raw_catalog_track = + broadcast.subscribe_track(&hang::catalog::Catalog::default_track()).map_err(|e| { + StreamKitError::Runtime(format!("Failed to subscribe to catalog track: {e}")) + })?; let mut catalog_consumer = hang::catalog::CatalogConsumer::new(raw_catalog_track); // Parse the catalog to discover tracks @@ -493,7 +506,10 @@ impl MoqPullNode { tracing::info!("Subscribed to broadcast '{}'", self.config.broadcast); // First, get the catalog to find audio tracks - let raw_catalog_track = broadcast.subscribe_track(&hang::catalog::Catalog::default_track()); + let raw_catalog_track = + broadcast.subscribe_track(&hang::catalog::Catalog::default_track()).map_err(|e| { + StreamKitError::Runtime(format!("Failed to subscribe to catalog track: {e}")) + })?; let mut catalog_consumer = hang::catalog::CatalogConsumer::new(raw_catalog_track); tracing::debug!( @@ -526,7 +542,9 @@ impl MoqPullNode { // // For audio we prefer low-latency, "latest group" semantics: we always read the latest // announced group and drain it, letting moq_lite drop old groups if we're slow. 
- let mut track_consumer = broadcast.subscribe_track(audio_track); + let mut track_consumer = broadcast.subscribe_track(audio_track).map_err(|e| { + StreamKitError::Runtime(format!("Failed to subscribe to audio track: {e}")) + })?; let mut current_group: Option = None; let mut session_packet_count: u32 = 0; @@ -655,6 +673,7 @@ impl MoqPullNode { timestamp_us: Some(timestamp_us), duration_us, sequence: None, + keyframe: None, }); last_timestamp_us = Some(timestamp_us); @@ -714,6 +733,7 @@ impl MoqPullNode { timestamp_us: Some(timestamp_us), duration_us, sequence: None, + keyframe: None, }); last_timestamp_us = Some(timestamp_us); @@ -789,10 +809,10 @@ impl MoqPullNode { } #[cfg(test)] +#[allow(clippy::unwrap_used, clippy::expect_used)] mod tests { use super::*; use bytes::BytesMut; - use moq_lite::coding::Encode; #[test] fn test_output_pins_for_tracks_includes_stable_out() { @@ -812,7 +832,10 @@ mod tests { #[test] fn test_strip_hang_timestamp_header() { let mut buf = BytesMut::new(); - 123_u64.encode(&mut buf, moq_lite::lite::Version::Draft02); + hang::container::Timestamp::from_micros(123) + .expect("valid timestamp") + .encode(&mut buf) + .expect("encode succeeds"); buf.extend_from_slice(b"opus-frame-bytes"); let payload = buf.freeze(); diff --git a/crates/nodes/src/transport/moq/push.rs b/crates/nodes/src/transport/moq/push.rs index 93187da9..6e6ccaf5 100644 --- a/crates/nodes/src/transport/moq/push.rs +++ b/crates/nodes/src/transport/moq/push.rs @@ -10,7 +10,7 @@ use opentelemetry::{global, KeyValue}; use schemars::JsonSchema; use serde::Deserialize; use streamkit_core::timing::MediaClock; -use streamkit_core::types::{Packet, PacketType}; +use streamkit_core::types::{AudioCodec, EncodedAudioFormat, Packet, PacketType}; use streamkit_core::{ packet_helpers, state_helpers, stats::NodeStatsTracker, InputPin, NodeContext, OutputPin, PinCardinality, ProcessorNode, StreamKitError, @@ -80,7 +80,10 @@ impl ProcessorNode for MoqPushNode { fn input_pins(&self) -> Vec { 
vec![InputPin { name: "in".to_string(), - accepts_types: vec![PacketType::OpusAudio], + accepts_types: vec![PacketType::EncodedAudio(EncodedAudioFormat { + codec: AudioCodec::Opus, + codec_private: None, + })], cardinality: PinCardinality::One, }] } @@ -146,7 +149,9 @@ impl ProcessorNode for MoqPushNode { // Match @moq/hang defaults for interoperability. let audio_track = moq_lite::Track { name: "audio/data".to_string(), priority: 80 }; - let track_producer = broadcast.create_track(audio_track.clone()); + let track_producer = broadcast + .create_track(audio_track.clone()) + .map_err(|e| StreamKitError::Runtime(format!("Failed to create audio track: {e}")))?; let mut track_producer: hang::container::OrderedProducer = track_producer.into(); // Create and publish a catalog describing our audio track @@ -170,12 +175,15 @@ impl ProcessorNode for MoqPushNode { }; // Create catalog track and publish the catalog data - let mut catalog_producer = broadcast.create_track(hang::catalog::Catalog::default_track()); - let catalog_json = match super::catalog_to_json(&catalog) { + let mut catalog_producer = broadcast + .create_track(hang::catalog::Catalog::default_track()) + .map_err(|e| StreamKitError::Runtime(format!("Failed to create catalog track: {e}")))?; + let catalog_json = match catalog.to_string() { Ok(json) => json, Err(e) => { - state_helpers::emit_failed(&context.state_tx, &node_name, e.to_string()); - return Err(e); + let err = StreamKitError::Runtime(format!("Failed to serialize catalog: {e}")); + state_helpers::emit_failed(&context.state_tx, &node_name, err.to_string()); + return Err(err); }, }; let catalog_data = catalog_json.into_bytes(); // Avoid intermediate Vec allocation @@ -186,7 +194,9 @@ impl ProcessorNode for MoqPushNode { ); // Write the catalog frame - catalog_producer.write_frame(catalog_data); + catalog_producer + .write_frame(catalog_data) + .map_err(|e| StreamKitError::Runtime(format!("Failed to write catalog frame: {e}")))?; // Keep the catalog 
track producer alive for the lifetime of the broadcast. // If dropped, the underlying moq-lite track gets cancelled and watchers will go "offline". let _catalog_producer = catalog_producer; @@ -261,7 +271,18 @@ impl ProcessorNode for MoqPushNode { let mut payload = hang::container::BufList::new(); payload.push_chunk(data); - let frame = hang::container::Frame { timestamp, keyframe, payload }; + if keyframe { + if let Err(e) = track_producer.keyframe() { + let err_msg = format!("Failed to signal keyframe: {e}"); + tracing::warn!("{err_msg}"); + stats_tracker.errored(); + stats_tracker.force_send(); + state_helpers::emit_failed(&context.state_tx, &node_name, &err_msg); + return Err(StreamKitError::Runtime(err_msg)); + } + } + + let frame = hang::container::Frame { timestamp, payload }; if let Err(e) = track_producer.write(frame) { let err_msg = format!("Failed to write MoQ frame: {e}"); @@ -312,7 +333,7 @@ impl ProcessorNode for MoqPushNode { state_helpers::emit_stopped(&context.state_tx, &node_name, "input_closed"); // Close the track when done (best-effort) - track_producer.track.clone().close(); + let _ = track_producer.track.finish(); tracing::info!("MoqPushNode finished after sending {} packets", packet_count); Ok(()) diff --git a/crates/nodes/src/video/colorbars.rs b/crates/nodes/src/video/colorbars.rs new file mode 100644 index 00000000..d935fe3d --- /dev/null +++ b/crates/nodes/src/video/colorbars.rs @@ -0,0 +1,782 @@ +// SPDX-FileCopyrightText: © 2025 StreamKit Contributors +// +// SPDX-License-Identifier: MPL-2.0 + +//! SMPTE EIA 75% color bars video generator. +//! +//! Produces raw video frames with the standard 7-bar test pattern. +//! Supports RGBA8 (default), NV12, and I420 pixel formats. +//! Configurable resolution, frame rate, and frame count. +//! +//! - `frame_count > 0`: batch mode — emits exactly N frames with synthetic timestamps (oneshot). +//! 
- `frame_count == 0`: real-time mode — emits indefinitely, paced by `tokio::time::interval` (dynamic). + +use async_trait::async_trait; +use schemars::JsonSchema; +use serde::Deserialize; +use streamkit_core::stats::NodeStatsTracker; +use streamkit_core::types::{Packet, PacketMetadata, PacketType, PixelFormat, VideoFormat}; +use streamkit_core::{ + config_helpers, state_helpers, InputPin, NodeContext, NodeRegistry, OutputPin, PinCardinality, + ProcessorNode, StreamKitError, +}; + +use schemars::schema_for; +use streamkit_core::registry::StaticPins; + +const fn default_width() -> u32 { + 640 +} + +const fn default_height() -> u32 { + 480 +} + +const fn default_fps() -> u32 { + 30 +} + +const fn default_frame_count() -> u32 { + 0 +} + +/// Configuration for the SMPTE color bars generator. +#[derive(Debug, Clone, Deserialize, JsonSchema)] +#[serde(default)] +pub struct ColorBarsConfig { + /// Frame width in pixels. + #[serde(default = "default_width")] + pub width: u32, + /// Frame height in pixels. + #[serde(default = "default_height")] + pub height: u32, + /// Frames per second. + #[serde(default = "default_fps")] + pub fps: u32, + /// Total frames to generate. 0 = infinite (real-time pacing). + #[serde(default = "default_frame_count")] + pub frame_count: u32, + /// Output pixel format. Supported: "rgba8" (default), "nv12", and "i420". + #[serde(default = "default_pixel_format")] + pub pixel_format: String, + /// When `true`, draws the current wall-clock time (`HH:MM:SS.mmm`) + /// onto each generated frame using a monospace font. + #[serde(default)] + pub draw_time: bool, + /// Optional filesystem path to a custom TTF/OTF font used for the + /// `draw_time` overlay. When omitted the bundled DejaVu Sans Mono + /// font (embedded in the binary) is used. + #[serde(default)] + pub draw_time_font_path: Option, + /// When `true`, horizontally scrolls the color bars each frame so that + /// every frame differs substantially from the previous one. 
Useful for + /// encoding benchmarks where static content would compress to nearly + /// nothing. + #[serde(default)] + pub animate: bool, +} + +fn default_pixel_format() -> String { + "rgba8".to_string() +} + +// Re-export the shared parse_pixel_format from the parent module. +use super::parse_pixel_format; + +impl Default for ColorBarsConfig { + fn default() -> Self { + Self { + width: default_width(), + height: default_height(), + fps: default_fps(), + frame_count: default_frame_count(), + pixel_format: default_pixel_format(), + draw_time: false, + draw_time_font_path: None, + animate: false, + } + } +} + +/// Source node that generates SMPTE EIA 75% color bar frames. +/// +/// No input pins. Outputs `PacketType::RawVideo` on `"out"` in the +/// configured pixel format (I420 or RGBA8). +/// Follows the Ready → Start lifecycle (like `FileReadNode`). +pub struct ColorBarsNode { + config: ColorBarsConfig, + /// Resolved pixel format from the config string. + pixel_format: PixelFormat, +} + +#[async_trait] +impl ProcessorNode for ColorBarsNode { + fn input_pins(&self) -> Vec { + vec![] + } + + fn output_pins(&self) -> Vec { + vec![OutputPin { + name: "out".to_string(), + produces_type: PacketType::RawVideo(VideoFormat { + width: None, + height: None, + pixel_format: self.pixel_format, + }), + cardinality: PinCardinality::Broadcast, + }] + } + + async fn run(self: Box, mut context: NodeContext) -> Result<(), StreamKitError> { + let node_name = context.output_sender.node_name().to_string(); + state_helpers::emit_initializing(&context.state_tx, &node_name); + + let width = self.config.width; + let height = self.config.height; + let fps = self.config.fps; + let frame_count = self.config.frame_count; + let duration_us = if fps > 0 { 1_000_000 / u64::from(fps) } else { 33_333 }; + + tracing::info!( + "ColorBarsNode: {}x{} @ {} fps, frame_count={}", + width, + height, + fps, + frame_count + ); + + let pixel_format = self.pixel_format; + + // Pre-load the monospace font for 
draw_time (once, if enabled). + let draw_time_font = if self.config.draw_time { + // If the user specified a custom font path, try that first; + // otherwise use the compile-time embedded DejaVu Sans Mono. + let font_bytes = self.config.draw_time_font_path.as_ref().map_or_else( + || crate::video::fonts::DEFAULT_MONO_FONT_DATA.to_vec(), + |path| match std::fs::read(path) { + Ok(bytes) => { + tracing::info!("draw_time: loaded custom font from {path}"); + bytes + }, + Err(e) => { + tracing::warn!( + "draw_time: failed to read custom font '{path}': {e}, \ + falling back to bundled DejaVu Sans Mono" + ); + crate::video::fonts::DEFAULT_MONO_FONT_DATA.to_vec() + }, + }, + ); + + match fontdue::Font::from_bytes(font_bytes, fontdue::FontSettings::default()) { + Ok(f) => { + tracing::info!("draw_time enabled: font ready"); + Some(f) + }, + Err(e) => { + tracing::warn!("draw_time: failed to parse font: {e}"); + None + }, + } + } else { + None + }; + + // Pre-generate the color bar pattern into a template buffer. + let layout = streamkit_core::types::VideoLayout::packed(width, height, pixel_format); + let total_bytes = layout.total_bytes(); + let mut template = vec![0u8; total_bytes]; + match pixel_format { + PixelFormat::I420 => { + generate_smpte_colorbars_i420(width, height, &mut template, &layout); + }, + PixelFormat::Nv12 => { + generate_smpte_colorbars_nv12(width, height, &mut template, &layout); + }, + PixelFormat::Rgba8 => generate_smpte_colorbars_rgba8(width, height, &mut template), + } + + // Source nodes emit Ready state and wait for Start signal. 
+ state_helpers::emit_ready(&context.state_tx, &node_name); + tracing::info!("ColorBarsNode ready, waiting for start signal"); + + loop { + match context.control_rx.recv().await { + Some(streamkit_core::control::NodeControlMessage::Start) => { + tracing::info!("ColorBarsNode received start signal"); + break; + }, + Some(streamkit_core::control::NodeControlMessage::UpdateParams(_)) => {}, + Some(streamkit_core::control::NodeControlMessage::Shutdown) => { + tracing::info!("ColorBarsNode received shutdown before start"); + return Ok(()); + }, + None => { + tracing::warn!("Control channel closed before start signal received"); + return Ok(()); + }, + } + } + + state_helpers::emit_running(&context.state_tx, &node_name); + + let mut stats_tracker = NodeStatsTracker::new(node_name.clone(), context.stats_tx.clone()); + + // Set up real-time pacing for dynamic (frame_count == 0) mode. + let mut interval = if frame_count == 0 && fps > 0 { + let period = std::time::Duration::from_micros(duration_us); + Some(tokio::time::interval(period)) + } else { + None + }; + + let mut seq: u64 = 0; + + loop { + // Honour finite frame count. + if frame_count > 0 && seq >= u64::from(frame_count) { + tracing::info!("ColorBarsNode finished after {} frames", seq); + break; + } + + // Check cancellation. + if let Some(token) = &context.cancellation_token { + if token.is_cancelled() { + tracing::info!("ColorBarsNode cancelled after {} frames", seq); + break; + } + } + + // Pace in real-time mode. + if let Some(ref mut iv) = interval { + tokio::select! 
{ + _ = iv.tick() => {}, + Some(msg) = context.control_rx.recv() => { + match msg { + streamkit_core::control::NodeControlMessage::Shutdown => { + tracing::info!("ColorBarsNode received shutdown during generation"); + break; + }, + streamkit_core::control::NodeControlMessage::UpdateParams(_) + | streamkit_core::control::NodeControlMessage::Start => {}, + } + continue; + } + } + } + + let timestamp_us = seq * duration_us; + let metadata = Some(PacketMetadata { + timestamp_us: Some(timestamp_us), + duration_us: Some(duration_us), + sequence: Some(seq), + keyframe: Some(true), + }); + + // Allocate frame from pool if available, otherwise from vec. + let animate = self.config.animate; + let frame = if let Some(pool) = &context.video_pool { + let mut pooled = pool.get(total_bytes); + #[allow(clippy::cast_possible_truncation)] + if animate { + let offset_px = seq as usize * ANIMATE_SCROLL_PX; + scroll_frame( + &template, + pooled.as_mut_slice(), + pixel_format, + &layout, + offset_px, + ); + } else { + pooled.as_mut_slice()[..total_bytes].copy_from_slice(&template); + } + if let Some(ref font) = draw_time_font { + stamp_time(pooled.as_mut_slice(), width, height, pixel_format, &layout, font); + } + streamkit_core::types::VideoFrame::from_pooled( + width, + height, + pixel_format, + pooled, + metadata, + )? + } else { + #[allow(clippy::cast_possible_truncation)] + let mut data = if animate { + let offset_px = seq as usize * ANIMATE_SCROLL_PX; + let mut buf = vec![0u8; total_bytes]; + scroll_frame(&template, &mut buf, pixel_format, &layout, offset_px); + buf + } else { + template.clone() + }; + if let Some(ref font) = draw_time_font { + stamp_time(&mut data, width, height, pixel_format, &layout, font); + } + streamkit_core::types::VideoFrame::with_metadata( + width, + height, + pixel_format, + data, + metadata, + )? 
+ }; + + if context.output_sender.send("out", Packet::Video(frame)).await.is_err() { + tracing::debug!("Output channel closed, stopping ColorBarsNode"); + break; + } + + stats_tracker.sent(); + stats_tracker.maybe_send(); + seq += 1; + } + + stats_tracker.force_send(); + state_helpers::emit_stopped(&context.state_tx, &node_name, "completed"); + Ok(()) + } +} + +// ── SMPTE color bar generation ────────────────────────────────────────────── + +/// SMPTE EIA 75% color bars (ITU-R BT.601 Y'CbCr). +/// +/// Seven equal-width vertical bars, left to right: +/// White, Yellow, Cyan, Green, Magenta, Red, Blue +/// +/// 75% amplitude values (studio range): +/// | Bar | Y | U (Cb) | V (Cr) | +/// |---------|------|----------|----------| +/// | White | 180 | 128 | 128 | +/// | Yellow | 162 | 44 | 142 | +/// | Cyan | 131 | 156 | 44 | +/// | Green | 112 | 72 | 58 | +/// | Magenta | 84 | 184 | 198 | +/// | Red | 65 | 100 | 212 | +/// | Blue | 35 | 212 | 114 | +const SMPTE_BARS_YUV: [(u8, u8, u8); 7] = [ + (180, 128, 128), // white + (162, 44, 142), // yellow + (131, 156, 44), // cyan + (112, 72, 58), // green + (84, 184, 198), // magenta + (65, 100, 212), // red + (35, 212, 114), // blue +]; + +/// SMPTE EIA 75% color bars in RGBA8 format. +/// +/// Same bar order and approximate 75% amplitude as the YUV table, +/// converted to full-range RGB. +const SMPTE_BARS_RGBA: [(u8, u8, u8, u8); 7] = [ + (191, 191, 191, 255), // white (75%) + (191, 191, 0, 255), // yellow + (0, 191, 191, 255), // cyan + (0, 191, 0, 255), // green + (191, 0, 191, 255), // magenta + (191, 0, 0, 255), // red + (0, 0, 191, 255), // blue +]; + +/// Fills an RGBA8 buffer with SMPTE 75% color bars. 
+fn generate_smpte_colorbars_rgba8(width: u32, height: u32, data: &mut [u8]) { + let bar_count = SMPTE_BARS_RGBA.len(); + let stride = width as usize * 4; + for row in 0..height as usize { + for col in 0..width as usize { + let bar_idx = col * bar_count / width as usize; + let (r, g, b, a) = SMPTE_BARS_RGBA[bar_idx]; + let offset = row * stride + col * 4; + data[offset] = r; + data[offset + 1] = g; + data[offset + 2] = b; + data[offset + 3] = a; + } + } +} + +/// Fills an NV12 buffer with SMPTE 75% color bars. +/// +/// Same YUV values as I420 but U and V are interleaved in a single chroma plane. +fn generate_smpte_colorbars_nv12( + width: u32, + height: u32, + data: &mut [u8], + layout: &streamkit_core::types::VideoLayout, +) { + let planes = layout.planes(); + let y_plane = planes[0]; + let uv_plane = planes[1]; + + let bar_count = SMPTE_BARS_YUV.len(); + + // Fill Y plane (identical to I420). + for row in 0..height as usize { + for col in 0..width as usize { + let bar_idx = col * bar_count / width as usize; + let (y, _, _) = SMPTE_BARS_YUV[bar_idx]; + data[y_plane.offset + row * y_plane.stride + col] = y; + } + } + + // Fill interleaved UV plane (half resolution). + let chroma_w = (width + 1) as usize / 2; + let chroma_h = uv_plane.height as usize; + for row in 0..chroma_h { + for col in 0..chroma_w { + let src_col = col * 2; + let bar_idx = src_col * bar_count / width as usize; + let (_, u, v) = SMPTE_BARS_YUV[bar_idx]; + let offset = uv_plane.offset + row * uv_plane.stride + col * 2; + data[offset] = u; + data[offset + 1] = v; + } + } +} + +/// Fills an I420 buffer with SMPTE 75% color bars. +fn generate_smpte_colorbars_i420( + width: u32, + height: u32, + data: &mut [u8], + layout: &streamkit_core::types::VideoLayout, +) { + let planes = layout.planes(); + let y_plane = planes[0]; + let u_plane = planes[1]; + let v_plane = planes[2]; + + let bar_count = SMPTE_BARS_YUV.len(); + + // Fill Y plane. 
+ for row in 0..height as usize { + for col in 0..width as usize { + let bar_idx = col * bar_count / width as usize; + let (y, _, _) = SMPTE_BARS_YUV[bar_idx]; + data[y_plane.offset + row * y_plane.stride + col] = y; + } + } + + // Fill U and V planes (half resolution for I420). + let chroma_w = u_plane.width as usize; + let chroma_h = u_plane.height as usize; + for row in 0..chroma_h { + for col in 0..chroma_w { + let src_col = col * 2; + let bar_idx = src_col * bar_count / width as usize; + let (_, u, v) = SMPTE_BARS_YUV[bar_idx]; + data[u_plane.offset + row * u_plane.stride + col] = u; + data[v_plane.offset + row * v_plane.stride + col] = v; + } + } +} + +// ── Animation (horizontal scroll) ─────────────────────────────────────────── + +/// Pixels scrolled per frame when `animate` is enabled. +const ANIMATE_SCROLL_PX: usize = 4; + +/// Horizontally rotate a single plane by `offset_bytes`, writing into `dst`. +#[allow(clippy::cast_possible_truncation)] +fn rotate_plane_rows( + src: &[u8], + dst: &mut [u8], + plane_offset: usize, + stride: usize, + data_width: usize, + height: usize, + offset_bytes: usize, +) { + let off = offset_bytes % data_width; + if off == 0 { + let len = stride * height; + dst[plane_offset..plane_offset + len] + .copy_from_slice(&src[plane_offset..plane_offset + len]); + return; + } + for row in 0..height { + let base = plane_offset + row * stride; + dst[base..base + data_width - off].copy_from_slice(&src[base + off..base + data_width]); + dst[base + data_width - off..base + data_width].copy_from_slice(&src[base..base + off]); + } +} + +/// Scroll the entire frame (all planes) by `offset_px` luma pixels. +#[allow(clippy::cast_possible_truncation)] +fn scroll_frame( + template: &[u8], + dst: &mut [u8], + pixel_format: PixelFormat, + layout: &streamkit_core::types::VideoLayout, + offset_px: usize, +) { + // Round down to even so chroma stays aligned with 4:2:0 subsampling. 
+ let offset_px = offset_px & !1; + let planes = layout.planes(); + + match pixel_format { + PixelFormat::Rgba8 => { + let p = planes[0]; + rotate_plane_rows( + template, + dst, + p.offset, + p.stride, + p.stride, + p.height as usize, + offset_px * 4, + ); + }, + PixelFormat::I420 => { + let y = planes[0]; + rotate_plane_rows( + template, + dst, + y.offset, + y.stride, + y.stride, + y.height as usize, + offset_px, + ); + let chroma_off = offset_px / 2; + let u = planes[1]; + rotate_plane_rows( + template, + dst, + u.offset, + u.stride, + u.stride, + u.height as usize, + chroma_off, + ); + let v = planes[2]; + rotate_plane_rows( + template, + dst, + v.offset, + v.stride, + v.stride, + v.height as usize, + chroma_off, + ); + }, + PixelFormat::Nv12 => { + let y = planes[0]; + rotate_plane_rows( + template, + dst, + y.offset, + y.stride, + y.stride, + y.height as usize, + offset_px, + ); + // UV plane: each chroma position is 2 bytes (U+V interleaved). + let uv = planes[1]; + let chroma_off_bytes = (offset_px / 2) * 2; + rotate_plane_rows( + template, + dst, + uv.offset, + uv.stride, + uv.stride, + uv.height as usize, + chroma_off_bytes, + ); + }, + } +} + +// ── draw_time stamping ────────────────────────────────────────────────────── + +/// Font size (px) used for the wall-clock timestamp overlay. +const DRAW_TIME_FONT_SIZE: f32 = 24.0; + +/// Stamp the current wall-clock time (`HH:MM:SS.mmm`) onto a frame buffer. +/// +/// Works for both RGBA8 and I420 pixel formats. For I420 the text is +/// rasterized into a tiny RGBA scratch area and then each lit pixel is +/// converted to YUV and poked into the Y/U/V planes. 
+#[allow( + clippy::cast_possible_truncation, + clippy::cast_sign_loss, + clippy::cast_precision_loss, + clippy::cast_possible_wrap +)] +fn stamp_time( + data: &mut [u8], + width: u32, + height: u32, + pixel_format: PixelFormat, + layout: &streamkit_core::types::VideoLayout, + font: &fontdue::Font, +) { + use std::time::SystemTime; + + let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap_or_default(); + let total_secs = now.as_secs(); + let millis = now.subsec_millis(); + let secs = total_secs % 60; + let mins = (total_secs / 60) % 60; + let hrs = (total_secs / 3600) % 24; + let time_str = format!("{hrs:02}:{mins:02}:{secs:02}.{millis:03}"); + + // Placement: bottom-left with a small margin. + let margin_x: i32 = 8; + let margin_y: i32 = 8; + let origin_y = height as i32 - margin_y - DRAW_TIME_FONT_SIZE as i32; + + match pixel_format { + PixelFormat::Rgba8 => { + // White, fully opaque text. + super::blit_text_rgba( + data, + width, + height, + font, + DRAW_TIME_FONT_SIZE, + &time_str, + margin_x, + origin_y, + [255, 255, 255, 255], + ); + }, + PixelFormat::I420 | PixelFormat::Nv12 => { + // YUV formats need direct plane manipulation — no shared RGBA + // utility applies here. 
+ let (ref_metrics, _) = font.rasterize('A', DRAW_TIME_FONT_SIZE); + let baseline_y = ref_metrics.height as f32; + + let mut cursor_x: f32 = 0.0; + + for ch in time_str.chars() { + let (metrics, bitmap) = font.rasterize(ch, DRAW_TIME_FONT_SIZE); + + let gx = margin_x + (cursor_x + metrics.xmin as f32) as i32; + let gy = + origin_y + (baseline_y - metrics.ymin as f32) as i32 - metrics.height as i32; + + for row in 0..metrics.height { + let dst_y = gy + row as i32; + if dst_y < 0 || dst_y >= height as i32 { + continue; + } + for col in 0..metrics.width { + let dst_x = gx + col as i32; + if dst_x < 0 || dst_x >= width as i32 { + continue; + } + let coverage = bitmap[row * metrics.width + col]; + if coverage == 0 { + continue; + } + + let px = dst_x as usize; + let py = dst_y as usize; + + let planes = layout.planes(); + let y_plane = planes[0]; + + // White in YUV = Y:235, U:128, V:128 + let alpha = u16::from(coverage); + let inv = 255 - alpha; + + let y_off = y_plane.offset + py * y_plane.stride + px; + let old_y = u16::from(data[y_off]); + data[y_off] = ((235 * alpha + old_y * inv + 128) / 255) as u8; + + // Chroma planes are half-resolution; update only once + // per 2×2 block (when both coords are even). 
+ if px.is_multiple_of(2) && py.is_multiple_of(2) { + let cx = px / 2; + let cy = py / 2; + match pixel_format { + PixelFormat::I420 => { + let u_plane = planes[1]; + let v_plane = planes[2]; + let u_off = u_plane.offset + cy * u_plane.stride + cx; + let v_off = v_plane.offset + cy * v_plane.stride + cx; + let old_u = u16::from(data[u_off]); + let old_v = u16::from(data[v_off]); + data[u_off] = ((128 * alpha + old_u * inv + 128) / 255) as u8; + data[v_off] = ((128 * alpha + old_v * inv + 128) / 255) as u8; + }, + PixelFormat::Nv12 => { + let uv_plane = planes[1]; + let uv_off = uv_plane.offset + cy * uv_plane.stride + cx * 2; + let old_u = u16::from(data[uv_off]); + let old_v = u16::from(data[uv_off + 1]); + data[uv_off] = ((128 * alpha + old_u * inv + 128) / 255) as u8; + data[uv_off + 1] = + ((128 * alpha + old_v * inv + 128) / 255) as u8; + }, + PixelFormat::Rgba8 => unreachable!(), + } + } + } + } + + cursor_x += metrics.advance_width; + if (margin_x as f32 + cursor_x) >= width as f32 { + break; + } + } + }, + } +} + +// ── Registration ──────────────────────────────────────────────────────────── + +#[allow(clippy::expect_used, clippy::missing_panics_doc)] +pub fn register_colorbars_nodes(registry: &mut NodeRegistry) { + let default_node = + ColorBarsNode { config: ColorBarsConfig::default(), pixel_format: PixelFormat::Nv12 }; + registry.register_static_with_description( + "video::colorbars", + |params| { + let config: ColorBarsConfig = config_helpers::parse_config_optional(params)?; + let pixel_format = parse_pixel_format(&config.pixel_format)?; + Ok(Box::new(ColorBarsNode { config, pixel_format })) + }, + serde_json::to_value(schema_for!(ColorBarsConfig)) + .expect("ColorBarsConfig schema should serialize to JSON"), + StaticPins { inputs: default_node.input_pins(), outputs: default_node.output_pins() }, + vec!["video".to_string(), "generators".to_string()], + false, + "Generates SMPTE EIA 75% color bar test frames. 
\ + Supports NV12 (default), I420, and RGBA8 pixel formats via the pixel_format config. \ + Use with a video encoder for pipeline testing and validation.", + ); +} + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use super::*; + + #[test] + fn test_smpte_colorbars_i420_dimensions() { + let width = 640u32; + let height = 480u32; + let layout = streamkit_core::types::VideoLayout::packed(width, height, PixelFormat::I420); + let total = layout.total_bytes(); + let mut data = vec![0u8; total]; + generate_smpte_colorbars_i420(width, height, &mut data, &layout); + + // Y plane: first pixel should be white (Y=180). + assert_eq!(data[0], 180); + // Last bar (rightmost column) should be blue (Y=35). + let last_y_col = (width - 1) as usize; + assert_eq!(data[last_y_col], 35); + } + + #[test] + fn test_colorbars_config_defaults() { + let config = ColorBarsConfig::default(); + assert_eq!(config.width, 640); + assert_eq!(config.height, 480); + assert_eq!(config.fps, 30); + assert_eq!(config.frame_count, 0); + } +} diff --git a/crates/nodes/src/video/compositor/config.rs b/crates/nodes/src/video/compositor/config.rs new file mode 100644 index 00000000..0ef29c91 --- /dev/null +++ b/crates/nodes/src/video/compositor/config.rs @@ -0,0 +1,311 @@ +// SPDX-FileCopyrightText: © 2025 StreamKit Contributors +// +// SPDX-License-Identifier: MPL-2.0 + +//! Configuration types for the video compositor node. + +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use smallvec::SmallVec; +use std::collections::HashMap; + +// ── Configuration ─────────────────────────────────────────────────────────── + +const fn default_width() -> u32 { + 1280 +} + +const fn default_height() -> u32 { + 720 +} + +const fn default_fps() -> u32 { + 30 +} + +/// Pixel-space rectangle for positioning a layer on the output canvas. +/// +/// `x` and `y` are signed to allow off-screen positioning (e.g. for +/// slide-in effects or rotation around the rect centre). 
+#[derive(Deserialize, Debug, Clone, JsonSchema)] +pub struct Rect { + pub x: i32, + pub y: i32, + pub width: u32, + pub height: u32, +} + +/// Common spatial and visual properties shared by all overlay types. +/// +/// Flattened into each overlay config via `#[serde(flatten)]` so the JSON +/// shape stays identical (fields remain at the top level). +#[derive(Deserialize, Debug, Clone, JsonSchema)] +pub struct OverlayTransform { + /// Destination rectangle on the output canvas. + pub rect: Rect, + /// Opacity multiplier (0.0 = fully transparent, 1.0 = fully opaque). + #[serde(default = "default_opacity")] + pub opacity: f32, + /// Clockwise rotation in degrees around the rect centre. Default 0.0. + #[serde(default)] + pub rotation_degrees: f32, + /// Visual stacking order. Lower values are drawn first (bottom); + /// higher values are drawn on top. Default 0. + #[serde(default = "default_z_index")] + pub z_index: i32, + /// Mirror the layer horizontally (flip left ↔ right). Default `false`. + #[serde(default)] + pub mirror_horizontal: bool, + /// Mirror the layer vertically (flip top ↔ bottom). Default `false`. + #[serde(default)] + pub mirror_vertical: bool, +} + +impl Default for OverlayTransform { + fn default() -> Self { + Self { + rect: Rect { x: 0, y: 0, width: 0, height: 0 }, + opacity: default_opacity(), + rotation_degrees: 0.0, + z_index: default_z_index(), + mirror_horizontal: false, + mirror_vertical: false, + } + } +} + +/// Configuration for a static image overlay (decoded once at init). +#[derive(Deserialize, Debug, Clone, JsonSchema)] +pub struct ImageOverlayConfig { + /// Base64-encoded image data (PNG or JPEG). Decoded once during + /// initialization, not per-frame. + pub data_base64: String, + /// Spatial and visual properties (rect, opacity, rotation, z_index). + #[serde(flatten)] + pub transform: OverlayTransform, +} + +/// Configuration for a text overlay (rasterized once per `UpdateParams`). 
+#[derive(Deserialize, Debug, Clone, JsonSchema)]
+pub struct TextOverlayConfig {
+    /// The text string to render.
+    pub text: String,
+    /// Spatial and visual properties (rect, opacity, rotation, z_index).
+    #[serde(flatten)]
+    pub transform: OverlayTransform,
+    /// RGBA colour, e.g. `[255, 255, 255, 255]`.
+    #[serde(default = "default_text_color")]
+    pub color: [u8; 4],
+    /// Font size in pixels.
+    #[serde(default = "default_font_size")]
+    pub font_size: u32,
+    /// Optional filesystem path to a TTF/OTF font file.
+    /// Use this for external or system-installed fonts not in the bundled set.
+    /// When omitted, a bundled default font (DejaVu Sans) is used.
+    #[serde(default)]
+    pub font_path: Option<String>,
+    /// Optional base64-encoded TTF/OTF font data.
+    /// Takes precedence over `font_path` when both are provided.
+    #[serde(default)]
+    pub font_data_base64: Option<String>,
+    /// Named font from the bundled set (embedded in the binary at compile
+    /// time — guaranteed to work without system font packages).
+    /// Takes precedence over `font_path` but not `font_data_base64`.
+    /// Available names: "dejavu-sans", "dejavu-sans-bold",
+    /// "dejavu-sans-mono", "dejavu-sans-mono-bold",
+    /// "dejavu-serif", "dejavu-serif-bold".
+    #[serde(default)]
+    pub font_name: Option<String>,
+}
+
+pub(crate) const fn default_opacity() -> f32 {
+    1.0
+}
+
+pub(crate) const fn default_z_index() -> i32 {
+    0
+}
+
+const fn default_text_color() -> [u8; 4] {
+    [255, 255, 255, 255]
+}
+
+const fn default_font_size() -> u32 {
+    24
+}
+
+/// Layer configuration for a single compositing input.
+#[derive(Deserialize, Debug, Clone, JsonSchema)]
+pub struct LayerConfig {
+    /// Destination rectangle on the output canvas. If `None`, the input is
+    /// scaled to fill the entire canvas.
+    pub rect: Option<Rect>,
+    /// Opacity (0.0 .. 1.0). Default 1.0.
+    #[serde(default = "default_opacity")]
+    pub opacity: f32,
+    /// Visual stacking order. Lower values are drawn first (bottom);
+    /// higher values are drawn on top.
Ties are broken by slot index + /// (pin insertion order). Default 0. + #[serde(default = "default_z_index")] + pub z_index: i32, + /// Clockwise rotation in degrees. Default 0.0 (no rotation). + /// The layer is rotated around its destination rect centre. + #[serde(default)] + pub rotation_degrees: f32, + /// Mirror the layer horizontally (flip left ↔ right). Default `false`. + #[serde(default)] + pub mirror_horizontal: bool, + /// Mirror the layer vertically (flip top ↔ bottom). Default `false`. + #[serde(default)] + pub mirror_vertical: bool, +} + +impl Default for LayerConfig { + fn default() -> Self { + Self { + rect: None, + opacity: default_opacity(), + z_index: default_z_index(), + rotation_degrees: 0.0, + mirror_horizontal: false, + mirror_vertical: false, + } + } +} + +/// Configuration for the video compositor node. +/// +/// The compositor supports an arbitrary number of dynamic video inputs +/// (created at runtime via `PinManagementMessage`) plus static image/text +/// overlays configured here. +#[derive(Deserialize, Debug, Clone, JsonSchema)] +#[serde(default)] +pub struct CompositorConfig { + /// Output canvas width in pixels. + #[serde(default = "default_width")] + pub width: u32, + /// Output canvas height in pixels. + #[serde(default = "default_height")] + pub height: u32, + /// Output frame rate. The compositor ticks at this fixed rate + /// regardless of input frame rates, compositing with the latest + /// available frame from each input. + #[serde(default = "default_fps")] + pub fps: u32, + /// Number of input pins to pre-create. + /// Required for stateless/oneshot pipelines where pins must exist before + /// graph building. Optional for dynamic pipelines where pins are created + /// on-demand. If specified, pins will be named in_0, in_1, ..., in_{N-1}. + pub num_inputs: Option, + /// Per-layer configuration, keyed by pin name (e.g. `"in_0"`). + /// Layers without an entry here are scaled to fill the canvas. 
+    #[serde(default)]
+    pub layers: HashMap<String, LayerConfig>,
+    /// Static image overlays (decoded once during init).
+    #[serde(default)]
+    pub image_overlays: Vec<ImageOverlayConfig>,
+    /// Text overlays (rasterized once per `UpdateParams`).
+    #[serde(default)]
+    pub text_overlays: Vec<TextOverlayConfig>,
+}
+
+impl Default for CompositorConfig {
+    fn default() -> Self {
+        Self {
+            width: default_width(),
+            height: default_height(),
+            fps: default_fps(),
+            num_inputs: None,
+            layers: HashMap::new(),
+            image_overlays: Vec::new(),
+            text_overlays: Vec::new(),
+        }
+    }
+}
+
+// ── Server-computed layout types ─────────────────────────────────────────
+// These are emitted via the view data channel so the frontend can render
+// overlays / layers at server-computed positions (server is source of truth
+// in Monitor view).
+
+/// Server-computed layout for a single video layer.
+#[derive(Serialize, Clone, Debug, PartialEq)]
+pub struct ResolvedLayer {
+    /// Pin name (e.g. "in_0").
+    pub id: String,
+    pub x: i32,
+    pub y: i32,
+    pub width: u32,
+    pub height: u32,
+    pub opacity: f32,
+    pub z_index: i32,
+    pub rotation_degrees: f32,
+    pub mirror_horizontal: bool,
+    pub mirror_vertical: bool,
+}
+
+/// Server-computed layout for a single overlay (text or image).
+#[derive(Serialize, Clone, Debug, PartialEq)]
+pub struct ResolvedOverlay {
+    pub index: usize,
+    pub x: i32,
+    pub y: i32,
+    /// Resolved width after text wrapping / image aspect-fit.
+    pub width: u32,
+    /// Resolved height after text wrapping / image aspect-fit.
+    pub height: u32,
+    pub opacity: f32,
+    pub z_index: i32,
+    pub rotation_degrees: f32,
+    pub mirror_horizontal: bool,
+    pub mirror_vertical: bool,
+}
+
+/// The complete server-computed compositor layout, serialized as view data.
+#[derive(Serialize, Clone, Debug, PartialEq)] +pub struct CompositorLayout { + pub canvas_width: u32, + pub canvas_height: u32, + pub layers: SmallVec<[ResolvedLayer; 8]>, + pub text_overlays: SmallVec<[ResolvedOverlay; 8]>, + pub image_overlays: SmallVec<[ResolvedOverlay; 8]>, +} + +impl CompositorConfig { + /// Validate compositor parameters. + /// + /// # Errors + /// + /// Returns an error string if width/height are zero or if opacity values + /// are out of range. + pub fn validate(&self) -> Result<(), String> { + if self.width == 0 || self.height == 0 { + return Err("Canvas width and height must be > 0".to_string()); + } + if self.fps == 0 { + return Err("Output fps must be > 0".to_string()); + } + for (name, layer) in &self.layers { + if !layer.opacity.is_finite() || layer.opacity < 0.0 || layer.opacity > 1.0 { + return Err(format!("Layer '{name}' opacity must be in [0.0, 1.0]")); + } + } + for (i, img) in self.image_overlays.iter().enumerate() { + if !img.transform.opacity.is_finite() + || img.transform.opacity < 0.0 + || img.transform.opacity > 1.0 + { + return Err(format!("Image overlay {i} opacity must be in [0.0, 1.0]")); + } + } + for (i, txt) in self.text_overlays.iter().enumerate() { + if !txt.transform.opacity.is_finite() + || txt.transform.opacity < 0.0 + || txt.transform.opacity > 1.0 + { + return Err(format!("Text overlay {i} opacity must be in [0.0, 1.0]")); + } + } + Ok(()) + } +} diff --git a/crates/nodes/src/video/compositor/kernel.rs b/crates/nodes/src/video/compositor/kernel.rs new file mode 100644 index 00000000..159aab3d --- /dev/null +++ b/crates/nodes/src/video/compositor/kernel.rs @@ -0,0 +1,392 @@ +// SPDX-FileCopyrightText: © 2025 StreamKit Contributors +// +// SPDX-License-Identifier: MPL-2.0 + +//! Compositing kernel — runs on a persistent blocking thread. +//! +//! Contains the data types exchanged between the async node loop and the +//! blocking compositing thread, plus the core `composite_frame` function +//! 
that blits layers and overlays onto an RGBA8 canvas.
+
+use std::sync::Arc;
+use streamkit_core::types::PixelFormat;
+
+use super::config::Rect;
+use super::overlay::DecodedOverlay;
+use super::pixel_ops::{
+    all_alpha_opaque, i420_to_rgba8_buf, nv12_to_rgba8_buf, scale_blit_rgba_rotated,
+};
+
+// ── Compositing kernel (runs on a persistent blocking thread) ────────────────
+
+// ── YUV → RGBA conversion cache ─────────────────────────────────────────────
+
+/// Cached RGBA conversion result for a single layer slot.
+struct CachedConversion {
+    /// Identity of the source data (`Arc::as_ptr` cast to `usize`).
+    /// When the `Arc` pointer hasn't changed between frames
+    /// the underlying data is identical and the conversion can be skipped.
+    data_identity: usize,
+    width: u32,
+    height: u32,
+    /// Pre-converted RGBA8 data, stored as a plain `Vec<u8>`.
+    rgba: Vec<u8>,
+}
+
+/// Per-slot cache for YUV → RGBA conversions.
+///
+/// Avoids redundant per-frame I420/NV12 → RGBA8 conversion when the source
+/// `Arc` hasn't changed since the previous frame.
+///
+/// Also caches the first-layer alpha-scan result so that the canvas-clear
+/// skip check doesn't re-scan every frame when the source hasn't changed.
+pub struct ConversionCache {
+    entries: Vec<Option<CachedConversion>>,
+    /// Cached result of the alpha-opaqueness scan for the first visible layer.
+    /// `(data_identity, all_opaque)` — valid when the `Arc` pointer matches.
+    first_layer_alpha_cache: Option<(usize, bool)>,
+}
+
+impl Default for ConversionCache {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ConversionCache {
+    pub const fn new() -> Self {
+        Self { entries: Vec::new(), first_layer_alpha_cache: None }
+    }
+
+    /// Check whether the first visible layer's source data is fully opaque.
+    ///
+    /// For I420/NV12 layers, the converted RGBA always has alpha == 255, so
+    /// we return `true` immediately without scanning. For RGBA layers we
+    /// scan once and cache the result keyed by `Arc::as_ptr`.
+ fn first_layer_all_opaque(&mut self, layer: &LayerSnapshot, rgba_data: &[u8]) -> bool { + // I420/NV12 → RGBA conversion always writes alpha = 255. + if layer.pixel_format != PixelFormat::Rgba8 { + return true; + } + + let identity = Arc::as_ptr(&layer.data) as usize; + if let Some((cached_id, cached_result)) = self.first_layer_alpha_cache { + if cached_id == identity { + return cached_result; + } + } + + let all_opaque = all_alpha_opaque(rgba_data); + self.first_layer_alpha_cache = Some((identity, all_opaque)); + all_opaque + } + + /// Return a previously-cached RGBA slice for `slot_idx`. + /// + /// # Panics + /// + /// Panics if the slot has not been populated by a prior `get_or_convert` + /// call for the same `layer`. This is only called in the second pass of + /// `composite_frame` after the first pass has ensured every non-RGBA + /// layer has been converted. + fn get_cached(&self, slot_idx: usize, layer: &LayerSnapshot) -> &[u8] { + #[allow(clippy::expect_used)] + let cached = + self.entries[slot_idx].as_ref().expect("get_cached called before get_or_convert"); + let needed = layer.width as usize * layer.height as usize * 4; + &cached.rgba[..needed] + } + + /// Look up or perform a YUV→RGBA conversion for layer at `slot_idx`. + /// Returns a slice of RGBA8 data. + fn get_or_convert(&mut self, slot_idx: usize, layer: &LayerSnapshot) -> &[u8] { + let identity = Arc::as_ptr(&layer.data) as usize; + + // Ensure the cache Vec is large enough. + if self.entries.len() <= slot_idx { + self.entries.resize_with(slot_idx + 1, || None); + } + + // Check if the cached entry is still valid. + let needs_convert = self.entries[slot_idx].as_ref().is_none_or(|cached| { + cached.data_identity != identity + || cached.width != layer.width + || cached.height != layer.height + }); + + if needs_convert { + let needed = layer.width as usize * layer.height as usize * 4; + // Reuse the existing allocation if possible. 
+ let mut rgba = self.entries[slot_idx].take().map(|c| c.rgba).unwrap_or_default(); + if rgba.len() < needed { + rgba.resize(needed, 0); + } + + match layer.pixel_format { + PixelFormat::I420 => { + i420_to_rgba8_buf(layer.data.as_slice(), layer.width, layer.height, &mut rgba); + }, + PixelFormat::Nv12 => { + nv12_to_rgba8_buf(layer.data.as_slice(), layer.width, layer.height, &mut rgba); + }, + PixelFormat::Rgba8 => { + // Should not be called for RGBA, but handle gracefully. + rgba[..needed].copy_from_slice(&layer.data.as_slice()[..needed]); + }, + } + + self.entries[slot_idx] = Some(CachedConversion { + data_identity: identity, + width: layer.width, + height: layer.height, + rgba, + }); + } + + // SAFETY: we just inserted into this slot above when `needs_convert` was true, + // and the slot was already `Some` when `needs_convert` was false. + #[allow(clippy::expect_used)] + let cached = self.entries[slot_idx].as_ref().expect("just inserted"); + let needed = layer.width as usize * layer.height as usize * 4; + &cached.rgba[..needed] + } +} + +/// Snapshot of one input layer's data for the blocking compositor thread. +pub struct LayerSnapshot { + pub data: Arc, + pub width: u32, + pub height: u32, + pub pixel_format: PixelFormat, + pub rect: Option, + pub opacity: f32, + /// Visual stacking order. Retained in the snapshot for diagnostic / + /// logging purposes even though sorting now happens before snapshot + /// construction. + #[allow(dead_code)] + pub z_index: i32, + /// Clockwise rotation in degrees around the destination rect centre. + /// Default `0.0` means no rotation. + pub rotation_degrees: f32, + /// Mirror horizontally (flip left ↔ right). + pub mirror_horizontal: bool, + /// Mirror vertically (flip top ↔ bottom). + pub mirror_vertical: bool, +} + +/// Work item sent from the async loop to the persistent compositing thread. 
+pub struct CompositeWorkItem { + pub canvas_w: u32, + pub canvas_h: u32, + pub layers: Vec>, + /// Shared, immutable overlay lists. Using `Arc<[…]>` means cloning + /// into the work item each frame is a single ref-count bump instead + /// of cloning the entire `Vec`. + pub image_overlays: Arc<[Arc]>, + pub text_overlays: Arc<[Arc]>, + pub video_pool: Option>, +} + +/// Result sent back from the compositing thread to the async loop. +pub struct CompositeResult { + pub rgba_data: streamkit_core::frame_pool::PooledVideoData, +} + +/// A resolved, ready-to-blit item. Unifies video layers and decoded +/// overlays into a single type for the z-sorted compositing loop. +struct BlitItem<'a> { + src_data: &'a [u8], + src_width: u32, + src_height: u32, + dst_rect: Rect, + opacity: f32, + rotation_degrees: f32, + /// When `true`, all source pixels have alpha == 255. Allows the blit + /// function to skip per-row alpha scanning and always use the memcpy path. + src_opaque: bool, + /// `(z_index, insertion_order)` for stable sorting. + sort_key: (i32, usize), + mirror_horizontal: bool, + mirror_vertical: bool, +} + +/// Composite all layers + overlays onto a fresh RGBA8 canvas buffer. +/// Allocates from the video pool if available. +/// +/// `conversion_cache` caches YUV→RGBA8 conversions across frames so that +/// unchanged layers skip the conversion entirely. +pub fn composite_frame( + canvas_w: u32, + canvas_h: u32, + layers: &[Option], + image_overlays: &[Arc], + text_overlays: &[Arc], + video_pool: Option<&streamkit_core::VideoFramePool>, + conversion_cache: &mut ConversionCache, +) -> streamkit_core::frame_pool::PooledVideoData { + let total_bytes = (canvas_w as usize) * (canvas_h as usize) * 4; + + let mut pooled = video_pool.map_or_else( + || streamkit_core::frame_pool::PooledVideoData::from_vec(vec![0u8; total_bytes]), + |pool| pool.get(total_bytes), + ); + + let buf = pooled.as_mut_slice(); + + // Two-pass source resolution. 
+ // + // Pass 1: populate the conversion cache for every non-RGBA layer. + // `slot_idx` uses the position in the `layers` slice (which preserves + // `None` holes) so that cache indices stay stable even when some slots + // have no frame. + for (slot_idx, entry) in layers.iter().enumerate() { + if let Some(layer) = entry { + if layer.pixel_format != PixelFormat::Rgba8 { + conversion_cache.get_or_convert(slot_idx, layer); + } + } + } + + // Between pass 1 and pass 2: check whether the first layer allows + // skipping the canvas clear. We do the alpha-opaqueness check here + // while `conversion_cache` is still mutably available. The result + // is a simple bool so no borrows leak into pass 2. + let skip_clear = + layers.iter().enumerate().find_map(|(i, e)| e.as_ref().map(|l| (i, l))).is_some_and( + |(_slot_idx, layer)| { + // Quick checks that don't need the pixel data. + if layer.opacity < 1.0 || layer.rotation_degrees.abs() >= 0.01 { + return false; + } + let covers = layer.rect.as_ref().is_none_or(|r| { + r.x <= 0 + && r.y <= 0 + && i64::from(r.width) + i64::from(r.x) >= i64::from(canvas_w) + && i64::from(r.height) + i64::from(r.y) >= i64::from(canvas_h) + }); + if !covers { + return false; + } + // Alpha check — needs mutable access to conversion_cache. + match layer.pixel_format { + // I420/NV12 → RGBA conversion always writes alpha = 255. + PixelFormat::I420 | PixelFormat::Nv12 => true, + PixelFormat::Rgba8 => { + conversion_cache.first_layer_all_opaque(layer, layer.data.as_slice()) + }, + } + }, + ); + if !skip_clear { + buf[..total_bytes].fill(0); + } + + // Pass 2: build resolved references. The mutable borrow of + // `conversion_cache` from pass 1 is released, so we can now take + // shared references into the cache alongside references into `layers`. 
+ let resolved: Vec> = layers + .iter() + .enumerate() + .map(|(slot_idx, entry)| { + entry.as_ref().map(|layer| { + let src_data: &[u8] = match layer.pixel_format { + PixelFormat::Rgba8 => layer.data.as_slice(), + PixelFormat::I420 | PixelFormat::Nv12 => { + // Cache was populated in pass 1; this is a shared + // read that cannot fail. + conversion_cache.get_cached(slot_idx, layer) + }, + }; + (layer, src_data) + }) + }) + .collect(); + + // ── Unified z-sorted blit ───────────────────────────────────────────── + // + // Collect all blittable items (video layers + image/text overlays) into + // a single list, sort by (z_index, insertion_order), then blit in order. + // This replaces the former three separate loops and allows overlays to + // be interleaved with video layers via z_index. + + let mut items: Vec> = Vec::new(); + let mut insertion_order: usize = 0; + + // Video layers. + for (layer, src_data) in resolved.iter().flatten() { + let dst_rect = + layer.rect.clone().unwrap_or(Rect { x: 0, y: 0, width: canvas_w, height: canvas_h }); + // NV12/I420 → RGBA8 conversion always writes alpha = 255. + let src_opaque = layer.pixel_format != PixelFormat::Rgba8; + items.push(BlitItem { + src_data, + src_width: layer.width, + src_height: layer.height, + dst_rect, + opacity: layer.opacity, + rotation_degrees: layer.rotation_degrees, + src_opaque, + sort_key: (layer.z_index, insertion_order), + mirror_horizontal: layer.mirror_horizontal, + mirror_vertical: layer.mirror_vertical, + }); + insertion_order += 1; + } + + // Image overlays. + for ov in image_overlays { + items.push(BlitItem { + src_data: &ov.rgba_data, + src_width: ov.width, + src_height: ov.height, + dst_rect: ov.rect.clone(), + opacity: ov.opacity, + rotation_degrees: ov.rotation_degrees, + src_opaque: false, + sort_key: (ov.z_index, insertion_order), + mirror_horizontal: ov.mirror_horizontal, + mirror_vertical: ov.mirror_vertical, + }); + insertion_order += 1; + } + + // Text overlays. 
+ for ov in text_overlays { + items.push(BlitItem { + src_data: &ov.rgba_data, + src_width: ov.width, + src_height: ov.height, + dst_rect: ov.rect.clone(), + opacity: ov.opacity, + rotation_degrees: ov.rotation_degrees, + src_opaque: false, + sort_key: (ov.z_index, insertion_order), + mirror_horizontal: ov.mirror_horizontal, + mirror_vertical: ov.mirror_vertical, + }); + insertion_order += 1; + } + + // Stable sort: lower z_index drawn first (bottom), ties broken by + // insertion order (video layers first, then image, then text). + items.sort_by_key(|item| item.sort_key); + + for item in &items { + scale_blit_rgba_rotated( + buf, + canvas_w, + canvas_h, + item.src_data, + item.src_width, + item.src_height, + &item.dst_rect, + item.opacity, + item.rotation_degrees, + item.src_opaque, + item.mirror_horizontal, + item.mirror_vertical, + ); + } + + pooled +} diff --git a/crates/nodes/src/video/compositor/mod.rs b/crates/nodes/src/video/compositor/mod.rs new file mode 100644 index 00000000..d7de3d11 --- /dev/null +++ b/crates/nodes/src/video/compositor/mod.rs @@ -0,0 +1,2038 @@ +// SPDX-FileCopyrightText: © 2025 StreamKit Contributors +// +// SPDX-License-Identifier: MPL-2.0 + +//! Video compositor node. +//! +//! Composites multiple raw video inputs onto a single RGBA8 output canvas with +//! optional image and text overlays. Supports dynamic pin creation for +//! attaching arbitrary inputs at runtime. +//! +//! - Inputs accept `RawVideo(RGBA8)` with wildcard dimensions. +//! - Output produces `RawVideo(RGBA8)` at the configured canvas size. +//! - Heavy compositing work runs on a persistent blocking thread (via +//! `spawn_blocking`) to avoid blocking the async runtime and to keep CPU +//! caches warm across frames. +//! - Row-level parallelism via `rayon` for blitting and pixel-format +//! conversion. +//! - Image overlays are decoded once during initialization (PNG/JPEG via the +//! `image` crate). +//! 
- Text overlays are rasterized via `fontdue` once per `UpdateParams`, not +//! per frame. +//! +//! # Future work +//! - GPU-accelerated compositing via `wgpu`. +//! - Bilinear / Lanczos scaling (MVP uses nearest-neighbor). + +pub mod config; +pub mod kernel; +pub mod overlay; +pub mod pixel_ops; + +use async_trait::async_trait; +use config::{CompositorConfig, CompositorLayout, ResolvedLayer, ResolvedOverlay}; +use kernel::{CompositeResult, CompositeWorkItem, LayerSnapshot}; +use opentelemetry::{global, KeyValue}; +use overlay::{decode_image_overlay, rasterize_text_overlay, DecodedOverlay}; +use schemars::schema_for; +use smallvec::SmallVec; +use std::collections::HashMap; +use std::sync::Arc; +use streamkit_core::control::NodeControlMessage; +use streamkit_core::pins::PinManagementMessage; +use streamkit_core::registry::StaticPins; +use streamkit_core::stats::NodeStatsTracker; +use streamkit_core::types::{ + Packet, PacketMetadata, PacketType, PixelFormat, VideoFormat, VideoFrame, +}; +use streamkit_core::{ + config_helpers, state_helpers, view_data_helpers, InputPin, NodeContext, NodeRegistry, + OutputPin, PinCardinality, ProcessorNode, StreamKitError, +}; +use tokio::sync::mpsc; + +use kernel::{composite_frame, ConversionCache}; + +// ── Input slot ────────────────────────────────────────────────────────────── + +/// Holds a receiver and the most-recently-received frame for one input layer. +struct InputSlot { + name: String, + rx: mpsc::Receiver, + latest_frame: Option, +} + +// ── Cached layer config ───────────────────────────────────────────────────── + +/// Pre-resolved layer configuration for a single slot. +/// Rebuilt only when compositor config or pin set changes, avoiding +/// per-frame `HashMap` lookups and `sort_by` calls. 
+#[derive(Clone)] +struct ResolvedSlotConfig { + rect: Option, + opacity: f32, + z_index: i32, + rotation_degrees: f32, + /// When `true`, the source is fitted within the destination rect + /// while preserving its aspect ratio (letterbox / pillarbox). + /// Used by auto-PiP layers to avoid stretching. + aspect_fit: bool, + mirror_horizontal: bool, + mirror_vertical: bool, +} + +/// Rebuild the per-slot resolved configs and the z-sorted draw order. +/// +/// Called once at startup and whenever `UpdateParams` or pin management +/// changes the layer set. The returned draw order is a list of slot +/// indices sorted by `(z_index, slot_index)`. +fn rebuild_layer_cache( + slots: &[InputSlot], + config: &CompositorConfig, +) -> (Vec, Vec) { + let num_slots = slots.len(); + let mut configs: Vec = Vec::with_capacity(num_slots); + for (idx, slot) in slots.iter().enumerate() { + let layer_cfg = config.layers.get(&slot.name); + #[allow(clippy::option_if_let_else)] + let (rect, opacity, z_index, rotation_degrees, aspect_fit, mirror_h, mirror_v) = + if let Some(lc) = layer_cfg { + ( + lc.rect.clone(), + lc.opacity, + lc.z_index, + lc.rotation_degrees, + false, + lc.mirror_horizontal, + lc.mirror_vertical, + ) + } else if idx > 0 && num_slots > 1 { + // Auto-PiP: non-first layers without explicit config. 
+ let pip_w = config.width / 3; + let pip_h = config.height / 3; + #[allow(clippy::cast_possible_wrap)] + let pip_x = (config.width - pip_w - 20) as i32; + #[allow(clippy::cast_possible_wrap)] + let pip_y = (config.height - pip_h - 20) as i32; + #[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)] + ( + Some(config::Rect { x: pip_x, y: pip_y, width: pip_w, height: pip_h }), + 1.0, + idx as i32, + 0.0, + true, // preserve source aspect ratio within PiP bounds + false, + false, + ) + } else { + (None, 1.0, 0, 0.0, false, false, false) + }; + configs.push(ResolvedSlotConfig { + rect, + opacity, + z_index, + rotation_degrees, + aspect_fit, + mirror_horizontal: mirror_h, + mirror_vertical: mirror_v, + }); + } + + // Pre-sort by (z_index, slot_index). + let mut draw_order: Vec = (0..num_slots).collect(); + draw_order.sort_by(|&a, &b| configs[a].z_index.cmp(&configs[b].z_index).then(a.cmp(&b))); + + (configs, draw_order) +} + +/// Compute a destination rect that fits `src_w × src_h` within `bounds` +/// while preserving the source aspect ratio. The fitted rect is centred +/// within the bounds. +fn fit_rect_preserving_aspect(src_w: u32, src_h: u32, bounds: &config::Rect) -> config::Rect { + if src_w == 0 || src_h == 0 || bounds.width == 0 || bounds.height == 0 { + return bounds.clone(); + } + let scale_w = f64::from(bounds.width) / f64::from(src_w); + let scale_h = f64::from(bounds.height) / f64::from(src_h); + let scale = scale_w.min(scale_h); + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] + let fit_w = (f64::from(src_w) * scale).round() as u32; + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] + let fit_h = (f64::from(src_h) * scale).round() as u32; + // Centre within the bounding rect. 
+ #[allow(clippy::cast_possible_wrap)] + let offset_x = (bounds.width.saturating_sub(fit_w) / 2) as i32; + #[allow(clippy::cast_possible_wrap)] + let offset_y = (bounds.height.saturating_sub(fit_h) / 2) as i32; + config::Rect { x: bounds.x + offset_x, y: bounds.y + offset_y, width: fit_w, height: fit_h } +} + +// ── Node ──────────────────────────────────────────────────────────────────── + +/// Composites multiple raw video inputs onto a single RGBA8 canvas with +/// optional image/text overlays. +/// +/// Inputs are dynamic (`PinCardinality::Dynamic`) and can be attached at +/// runtime. Each input accepts `RawVideo(RGBA8)` or `RawVideo(I420)` with +/// wildcard dimensions. +/// +/// Output `"out"` always produces `RawVideo(RGBA8)` at the configured canvas +/// size. Downstream nodes (e.g. the VP9 encoder) are responsible for any +/// further format conversion. +pub struct CompositorNode { + config: CompositorConfig, + /// Current input pins (may grow dynamically). + input_pins: Vec, + /// Next input ID for dynamic pin naming. + next_input_id: usize, +} + +impl CompositorNode { + #[must_use] + pub fn new(config: CompositorConfig) -> Self { + let (input_pins, next_input_id) = config.num_inputs.map_or_else( + || { + // Dynamic mode - start with no pins + (Vec::new(), 0) + }, + |num_inputs| { + // Pre-create pins for stateless/oneshot pipelines. + // Follow the YAML convention: single input uses "in", + // multiple inputs use "in_0", "in_1", etc. + let mut pins = Vec::with_capacity(num_inputs); + if num_inputs == 1 { + pins.push(Self::make_input_pin("in".to_string())); + } else { + for i in 0..num_inputs { + pins.push(Self::make_input_pin(format!("in_{i}"))); + } + } + (pins, num_inputs) + }, + ); + + Self { config, input_pins, next_input_id } + } + + /// The set of video packet types accepted by compositor input pins. 
+ fn accepted_video_types() -> Vec { + vec![ + PacketType::RawVideo(VideoFormat { + width: None, + height: None, + pixel_format: PixelFormat::Rgba8, + }), + PacketType::RawVideo(VideoFormat { + width: None, + height: None, + pixel_format: PixelFormat::I420, + }), + PacketType::RawVideo(VideoFormat { + width: None, + height: None, + pixel_format: PixelFormat::Nv12, + }), + ] + } + + /// Returns the definition-time pins for registry (dynamic template). + pub fn definition_pins() -> (Vec, Vec) { + let inputs = vec![InputPin { + name: "in".to_string(), + accepts_types: Self::accepted_video_types(), + cardinality: PinCardinality::Dynamic { prefix: "in".to_string() }, + }]; + + let outputs = vec![OutputPin { + name: "out".to_string(), + produces_type: PacketType::RawVideo(VideoFormat { + width: None, + height: None, + pixel_format: PixelFormat::Rgba8, + }), + cardinality: PinCardinality::Broadcast, + }]; + + (inputs, outputs) + } + + /// Create a concrete `InputPin` for a given name. + fn make_input_pin(name: String) -> InputPin { + InputPin { + name, + accepts_types: Self::accepted_video_types(), + cardinality: PinCardinality::One, + } + } +} + +#[async_trait] +impl ProcessorNode for CompositorNode { + fn input_pins(&self) -> Vec { + self.input_pins.clone() + } + + fn output_pins(&self) -> Vec { + vec![OutputPin { + name: "out".to_string(), + produces_type: PacketType::RawVideo(VideoFormat { + width: Some(self.config.width), + height: Some(self.config.height), + pixel_format: PixelFormat::Rgba8, + }), + cardinality: PinCardinality::Broadcast, + }] + } + + fn supports_dynamic_pins(&self) -> bool { + true + } + + #[allow(clippy::too_many_lines, clippy::cognitive_complexity)] + async fn run(mut self: Box, mut context: NodeContext) -> Result<(), StreamKitError> { + let node_name = context.output_sender.node_name().to_string(); + state_helpers::emit_initializing(&context.state_tx, &node_name); + + tracing::info!( + "CompositorNode starting: {}x{} canvas, {} image overlays, 
{} text overlays", + self.config.width, + self.config.height, + self.config.image_overlays.len(), + self.config.text_overlays.len(), + ); + + // Decode image overlays (once). Wrap in Arc so per-frame clones + // into the work item are cheap reference-count bumps. + // + // `image_overlay_cfg_indices` records, for each successfully decoded + // overlay, the index of the originating `ImageOverlayConfig` in + // `config.image_overlays`. This allows the cache in + // `apply_update_params` to map decoded bitmaps back to their configs + // without relying on dimension-matching heuristics. + let mut image_overlays_vec: Vec> = + Vec::with_capacity(self.config.image_overlays.len()); + let mut image_overlay_cfg_indices: Vec = + Vec::with_capacity(self.config.image_overlays.len()); + for (i, img_cfg) in self.config.image_overlays.iter().enumerate() { + match decode_image_overlay(img_cfg) { + Ok(overlay) => { + tracing::info!( + "Decoded image overlay {}: {}x{} -> rect ({},{} {}x{})", + i, + overlay.width, + overlay.height, + overlay.rect.x, + overlay.rect.y, + overlay.rect.width, + overlay.rect.height, + ); + image_overlays_vec.push(Arc::new(overlay)); + image_overlay_cfg_indices.push(i); + }, + Err(e) => { + tracing::warn!("Failed to decode image overlay {}: {}", i, e); + }, + } + } + + // Rasterize text overlays (once; re-done on UpdateParams). Also Arc-wrapped. + let mut text_overlays_vec: Vec> = + Vec::with_capacity(self.config.text_overlays.len()); + for txt_cfg in &self.config.text_overlays { + text_overlays_vec.push(Arc::new(rasterize_text_overlay(txt_cfg))); + } + + // Wrap in Arc<[...]> so per-frame clones into the work item are + // a single ref-count bump instead of cloning the entire Vec. + let mut image_overlays: Arc<[Arc]> = Arc::from(image_overlays_vec); + let mut text_overlays: Arc<[Arc]> = Arc::from(text_overlays_vec); + + // Collect initial input slots from pre-connected pins. 
+ let mut slots: Vec = Vec::new(); + for pin_name in context.inputs.keys() { + let pin = Self::make_input_pin(pin_name.clone()); + self.input_pins.push(pin); + // Track next_input_id for dynamically named pins. + if let Some(num_str) = pin_name.strip_prefix("in_") { + if let Ok(n) = num_str.parse::() { + self.next_input_id = self.next_input_id.max(n + 1); + } + } + } + // Drain all pre-connected inputs into slots. + // IMPORTANT: HashMap::drain() has non-deterministic iteration order, + // so we must sort by pin name to ensure stable slot ordering. + // The slot index determines layer stacking (idx 0 = background, + // idx > 0 = auto-PiP), so non-deterministic order would randomly + // swap which input becomes the background vs. the PiP overlay. + let mut pre_inputs: Vec<(String, mpsc::Receiver)> = + context.inputs.drain().collect(); + pre_inputs.sort_by(|(a, _), (b, _)| { + // Sort numerically by the suffix of "in_N" pin names so that + // in_0 < in_1 < ... < in_10. Fall back to lexicographic order + // for non-standard pin names. + let a_num = a.strip_prefix("in_").and_then(|s| s.parse::().ok()); + let b_num = b.strip_prefix("in_").and_then(|s| s.parse::().ok()); + match (a_num, b_num) { + (Some(an), Some(bn)) => an.cmp(&bn), + _ => a.cmp(b), + } + }); + for (name, rx) in pre_inputs { + tracing::info!("CompositorNode: pre-connected input '{}'", name); + slots.push(InputSlot { name, rx, latest_frame: None }); + } + + // Pin management channel (optional). + let mut pin_mgmt_rx = context.pin_management_rx.take(); + + state_helpers::emit_running(&context.state_tx, &node_name); + + let mut stats_tracker = NodeStatsTracker::new(node_name.clone(), context.stats_tx.clone()); + + // Shared state for the compositing thread. 
+ let video_pool = context.video_pool.clone(); + + // ── Persistent compositing thread ─────────────────────────────── + // Instead of spawning a new blocking task per frame, we keep a + // single long-lived thread that processes compositing work items + // sent via a channel. This avoids per-frame thread-pool + // scheduling overhead and keeps CPU caches warm. + let (work_tx, mut work_rx) = tokio::sync::mpsc::channel::(2); + let (result_tx, mut result_rx) = tokio::sync::mpsc::channel::(2); + + let composite_thread = tokio::task::spawn_blocking(move || { + // Per-slot cache for YUV→RGBA conversions. Avoids redundant + // conversion when the source Arc hasn't changed between frames. + let mut conversion_cache = ConversionCache::new(); + + while let Some(work) = work_rx.blocking_recv() { + let rgba_buf = composite_frame( + work.canvas_w, + work.canvas_h, + &work.layers, + &work.image_overlays, + &work.text_overlays, + work.video_pool.as_deref(), + &mut conversion_cache, + ); + let result = CompositeResult { rgba_data: rgba_buf }; + if result_tx.blocking_send(result).is_err() { + break; + } + } + }); + + let mut output_seq: u64 = 0; + let mut stop_reason: &str = "shutdown"; + + // ── OpenTelemetry metrics ─────────────────────────────────────── + let meter = global::meter("skit_nodes"); + let frames_dropped_counter = meter + .u64_counter("compositor.frames_dropped") + .with_description("Frames dropped by the compositor to keep up with real-time input") + .build(); + let otel_attrs = [KeyValue::new("node", node_name.clone())]; + + // ── Fixed-rate tick ────────────────────────────────────────────── + // The compositor runs at a fixed fps regardless of input rates, + // like the audio clocked mixer. On each tick it drains all inputs + // to their latest frame and composites. Inputs that haven't + // delivered a new frame since the last tick reuse their previous + // frame. This guarantees a constant output rate and decouples + // the compositor from input timing. 
+ let tick_duration = + std::time::Duration::from_nanos(1_000_000_000u64 / u64::from(self.config.fps)); + let mut tick = tokio::time::interval(tick_duration); + tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + + // ── Cached layer config + draw order ──────────────────────────── + let mut layer_configs_dirty = true; + let mut resolved_configs: Vec = Vec::new(); + let mut sorted_draw_order: Vec = Vec::new(); + + // ── View data (server-driven layout) ──────────────────────────── + let view_data_tx = context.view_data_tx.clone(); + let mut last_layout: Option = None; + + loop { + // ── Wait for the next tick, or handle control / pin msgs ──── + tokio::select! { + biased; + + // Control messages (highest priority). + Some(ctrl_msg) = context.control_rx.recv() => { + match ctrl_msg { + NodeControlMessage::Shutdown => { + tracing::info!("CompositorNode received shutdown"); + break; + }, + NodeControlMessage::UpdateParams(params) => { + let old_fps = self.config.fps; + Self::apply_update_params( + &mut self.config, + &mut image_overlays, + &mut image_overlay_cfg_indices, + &mut text_overlays, + params, + &mut stats_tracker, + ); + layer_configs_dirty = true; + if self.config.fps != old_fps { + let new_duration = std::time::Duration::from_nanos( + 1_000_000_000u64 / u64::from(self.config.fps), + ); + tick = tokio::time::interval(new_duration); + tick.set_missed_tick_behavior( + tokio::time::MissedTickBehavior::Skip, + ); + tracing::info!("Compositor fps changed: {} → {}", old_fps, self.config.fps); + } + }, + NodeControlMessage::Start => {}, + } + continue; + } + + // Pin management. + Some(msg) = async { + match &mut pin_mgmt_rx { + Some(rx) => rx.recv().await, + None => std::future::pending().await, + } + } => { + Self::handle_pin_management( + &mut self, + msg, + &mut slots, + ); + layer_configs_dirty = true; + continue; + } + + // Fixed-rate tick — time to composite. 
+ _ = tick.tick() => {} + } + + // ── Drain each slot to its latest frame (non-blocking) ────── + for slot in &mut slots { + let mut latest: Option = None; + let mut dropped: u64 = 0; + while let Ok(Packet::Video(frame)) = slot.rx.try_recv() { + if latest.is_some() { + dropped += 1; + } + latest = Some(frame); + } + if dropped > 0 { + frames_dropped_counter.add(dropped, &otel_attrs); + stats_tracker.discarded_n(dropped); + } + if let Some(frame) = latest { + slot.latest_frame = Some(frame); + } + } + + // Nothing to composite if no slot has ever received a frame. + if !slots.iter().any(|s| s.latest_frame.is_some()) { + continue; + } + + // Check for closed input channels. + let mut i = 0; + while i < slots.len() { + // A slot whose channel is closed AND has no buffered frame + // can be removed. We detect closure by a failed try_recv + // returning Disconnected — but try_recv above already + // drained. Use a zero-capacity poll instead: + if slots[i].rx.is_closed() { + tracing::info!("CompositorNode: input '{}' closed", slots[i].name); + slots.remove(i); + layer_configs_dirty = true; + } else { + i += 1; + } + } + if slots.is_empty() { + stop_reason = "all_inputs_closed"; + break; + } + + // ── Rebuild layer config cache if needed ───────────────────── + if layer_configs_dirty { + let (cfgs, order) = rebuild_layer_cache(&slots, &self.config); + resolved_configs = cfgs; + sorted_draw_order = order; + layer_configs_dirty = false; + + // Emit layout via view data if it changed. 
+ let layout = Self::build_layout( + &self.config, + &slots, + &resolved_configs, + &image_overlays, + &text_overlays, + ); + if last_layout.as_ref() != Some(&layout) { + if let Ok(json) = serde_json::to_value(&layout) { + view_data_helpers::emit_view_data(&view_data_tx, &node_name, json); + } + last_layout = Some(layout); + } + } + + // ── Send work to persistent compositing thread ───────────── + // Build layer snapshots in pre-sorted draw order using the + // cached per-slot configs (no HashMap lookup, no sort). + let layers: Vec> = sorted_draw_order + .iter() + .map(|&idx| { + slots[idx].latest_frame.as_ref().map(|f| { + let cfg = &resolved_configs[idx]; + let rect = if cfg.aspect_fit { + // Fit the source within the destination rect + // while preserving its aspect ratio. + cfg.rect + .as_ref() + .map(|r| fit_rect_preserving_aspect(f.width, f.height, r)) + } else { + cfg.rect.clone() + }; + LayerSnapshot { + data: f.data.clone(), + width: f.width, + height: f.height, + pixel_format: f.pixel_format, + rect, + opacity: cfg.opacity, + z_index: cfg.z_index, + rotation_degrees: cfg.rotation_degrees, + mirror_horizontal: cfg.mirror_horizontal, + mirror_vertical: cfg.mirror_vertical, + } + }) + }) + .collect(); + + stats_tracker.received(); + + let work_item = CompositeWorkItem { + canvas_w: self.config.width, + canvas_h: self.config.height, + layers, + image_overlays: image_overlays.clone(), + text_overlays: text_overlays.clone(), + video_pool: video_pool.clone(), + }; + + // Send work to the compositing thread. The work channel has + // capacity 2, so at most one item can be in-flight while we + // submit the next. Use try_send to avoid blocking — if the + // compositing thread hasn't finished the previous frame yet, + // drop this one to stay real-time. + match work_tx.try_send(work_item) { + Ok(()) => {}, + Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => { + // Compositing thread is still busy — skip this frame. 
+ frames_dropped_counter.add(1, &otel_attrs); + stats_tracker.discarded(); + continue; + }, + Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => { + tracing::debug!("Compositing thread gone, stopping CompositorNode"); + stop_reason = "compositor_thread_gone"; + break; + }, + } + + let Some(composite_result) = result_rx.recv().await else { + tracing::debug!("Compositing result channel closed"); + stop_reason = "compositor_thread_gone"; + break; + }; + + // Build metadata from the first available input frame. + let src_metadata = + slots.iter().find_map(|s| s.latest_frame.as_ref()).and_then(|f| f.metadata.clone()); + + let metadata = Some(PacketMetadata { + timestamp_us: src_metadata.as_ref().and_then(|m| m.timestamp_us), + duration_us: src_metadata.as_ref().and_then(|m| m.duration_us), + sequence: Some(output_seq), + // Don't set keyframe — the compositor outputs raw RGBA, not + // encoded video. Downstream encoders (VP9) decide their own + // keyframe placement via kf_max_dist. Setting this to true + // caused every frame to be force-keyframed via VPX_EFLAG_FORCE_KF, + // creating one MoQ group per frame and overwhelming the browser. + keyframe: None, + }); + + let out_frame = VideoFrame::from_pooled( + self.config.width, + self.config.height, + PixelFormat::Rgba8, + composite_result.rgba_data, + metadata, + )?; + + // Non-blocking output send — if downstream (VP9 encoder) is + // backed up, drop the frame rather than stalling the + // compositor loop. ChannelClosed is permanent (downstream + // gone), so we stop the node. + match context.output_sender.try_send("out", Packet::Video(out_frame)) { + Ok(()) => {}, + Err(streamkit_core::node::OutputSendError::ChannelFull { .. 
}) => { + frames_dropped_counter.add(1, &otel_attrs); + stats_tracker.discarded(); + output_seq += 1; + continue; + }, + Err(_) => { + tracing::debug!("Output channel closed, stopping CompositorNode"); + stop_reason = "output_closed"; + break; + }, + } + + stats_tracker.sent(); + stats_tracker.maybe_send(); + output_seq += 1; + } + + // Drop the work sender to signal the compositing thread to exit. + // NOTE: Any composite result currently in-flight (sent to the thread + // but not yet received back via result_rx) will be lost here. This is + // acceptable for shutdown semantics — we prefer a fast exit over + // draining one extra frame that may never be forwarded downstream. + drop(work_tx); + let _ = composite_thread.await; + + stats_tracker.force_send(); + state_helpers::emit_stopped(&context.state_tx, &node_name, stop_reason); + Ok(()) + } +} + +// ── Private helpers on CompositorNode ─────────────────────────────────────── + +impl CompositorNode { + /// Build a `CompositorLayout` from the current resolved state. + /// + /// This captures the server-computed positions and dimensions for all + /// layers and overlays, which the frontend uses as the source of truth + /// in Monitor view. + fn build_layout( + config: &CompositorConfig, + slots: &[InputSlot], + resolved_configs: &[ResolvedSlotConfig], + image_overlays: &Arc<[Arc]>, + text_overlays: &Arc<[Arc]>, + ) -> CompositorLayout { + let mut layers: SmallVec<[ResolvedLayer; 8]> = SmallVec::new(); + for (idx, slot) in slots.iter().enumerate() { + if let Some(cfg) = resolved_configs.get(idx) { + let (x, y, width, height) = if cfg.aspect_fit { + // When aspect_fit is enabled (e.g. auto-PiP), emit the + // fitted rect that matches what the compositor actually + // renders, not the raw bounding box. This prevents the + // frontend from displaying a stretched layer rectangle. 
+ match (cfg.rect.as_ref(), slot.latest_frame.as_ref()) { + (Some(rect), Some(frame)) => { + let fitted = + fit_rect_preserving_aspect(frame.width, frame.height, rect); + (fitted.x, fitted.y, fitted.width, fitted.height) + }, + (Some(rect), None) => (rect.x, rect.y, rect.width, rect.height), + _ => (0, 0, config.width, config.height), + } + } else { + cfg.rect.as_ref().map_or((0, 0, config.width, config.height), |rect| { + (rect.x, rect.y, rect.width, rect.height) + }) + }; + layers.push(ResolvedLayer { + id: slot.name.clone(), + x, + y, + width, + height, + opacity: cfg.opacity, + z_index: cfg.z_index, + rotation_degrees: cfg.rotation_degrees, + mirror_horizontal: cfg.mirror_horizontal, + mirror_vertical: cfg.mirror_vertical, + }); + } + } + + let mut resolved_image_overlays: SmallVec<[ResolvedOverlay; 8]> = SmallVec::new(); + for (i, ov) in image_overlays.iter().enumerate() { + resolved_image_overlays.push(ResolvedOverlay { + index: i, + x: ov.rect.x, + y: ov.rect.y, + width: ov.rect.width, + height: ov.rect.height, + opacity: ov.opacity, + z_index: ov.z_index, + rotation_degrees: ov.rotation_degrees, + mirror_horizontal: ov.mirror_horizontal, + mirror_vertical: ov.mirror_vertical, + }); + } + + let mut resolved_text_overlays: SmallVec<[ResolvedOverlay; 8]> = SmallVec::new(); + for (i, ov) in text_overlays.iter().enumerate() { + resolved_text_overlays.push(ResolvedOverlay { + index: i, + x: ov.rect.x, + y: ov.rect.y, + width: ov.rect.width, + height: ov.rect.height, + opacity: ov.opacity, + z_index: ov.z_index, + rotation_degrees: ov.rotation_degrees, + mirror_horizontal: ov.mirror_horizontal, + mirror_vertical: ov.mirror_vertical, + }); + } + + CompositorLayout { + canvas_width: config.width, + canvas_height: config.height, + layers, + text_overlays: resolved_text_overlays, + image_overlays: resolved_image_overlays, + } + } + + fn apply_update_params( + config: &mut CompositorConfig, + image_overlays: &mut Arc<[Arc]>, + image_overlay_cfg_indices: &mut Vec, + 
text_overlays: &mut Arc<[Arc]>, + params: serde_json::Value, + stats_tracker: &mut NodeStatsTracker, + ) { + match serde_json::from_value::(params) { + Ok(new_config) => match new_config.validate() { + Ok(()) => { + tracing::info!( + old_w = config.width, + old_h = config.height, + new_w = new_config.width, + new_h = new_config.height, + "Updating compositor config" + ); + + // Re-decode image overlays only when their content or + // target rect changed. When only video-layer positions + // are updated (the common case) the existing decoded + // bitmaps are reused via Arc, avoiding redundant base64 + // decode + bilinear prescale work. + // + // The cache is keyed by (data_base64, width, height). + // `image_overlay_cfg_indices` provides an exact mapping + // from each decoded overlay back to its originating + // config index, eliminating any heuristic guessing + // about which decoded bitmap belongs to which config. + let old_imgs = image_overlays.clone(); + let old_cfgs = &config.image_overlays; + + let mut cache: HashMap<(&str, u32, u32), Vec>> = + HashMap::new(); + + // Each decoded overlay has a recorded config index in + // `image_overlay_cfg_indices`. Use this to look up + // the originating config directly — no dimension + // matching needed. 
+ for (dec_idx, decoded) in old_imgs.iter().enumerate() { + if let Some(&cfg_idx) = image_overlay_cfg_indices.get(dec_idx) { + if let Some(old_cfg) = old_cfgs.get(cfg_idx) { + let key = ( + old_cfg.data_base64.as_str(), + old_cfg.transform.rect.width, + old_cfg.transform.rect.height, + ); + cache.entry(key).or_default().push(Arc::clone(decoded)); + } + } + } + + let mut new_image_overlays: Vec> = + Vec::with_capacity(new_config.image_overlays.len()); + let mut new_cfg_indices: Vec = + Vec::with_capacity(new_config.image_overlays.len()); + for (new_idx, img_cfg) in new_config.image_overlays.iter().enumerate() { + let key = ( + img_cfg.data_base64.as_str(), + img_cfg.transform.rect.width, + img_cfg.transform.rect.height, + ); + if let Some(entries) = cache.get_mut(&key) { + if let Some(existing) = entries.pop() { + // Content and target dimensions unchanged — + // reuse the decoded bitmap. The overlay's + // rect may be smaller than the config rect + // due to aspect-ratio-preserving prescale, + // so re-centre within the new config rect. 
+ let mut ov = (*existing).clone(); + let cfg_w = img_cfg.transform.rect.width.cast_signed(); + let cfg_h = img_cfg.transform.rect.height.cast_signed(); + let ov_w = ov.rect.width.cast_signed(); + let ov_h = ov.rect.height.cast_signed(); + ov.rect.x = img_cfg.transform.rect.x + (cfg_w - ov_w) / 2; + ov.rect.y = img_cfg.transform.rect.y + (cfg_h - ov_h) / 2; + ov.opacity = img_cfg.transform.opacity; + ov.rotation_degrees = img_cfg.transform.rotation_degrees; + ov.z_index = img_cfg.transform.z_index; + ov.mirror_horizontal = img_cfg.transform.mirror_horizontal; + ov.mirror_vertical = img_cfg.transform.mirror_vertical; + new_image_overlays.push(Arc::new(ov)); + new_cfg_indices.push(new_idx); + continue; + } + } + match decode_image_overlay(img_cfg) { + Ok(ov) => { + new_image_overlays.push(Arc::new(ov)); + new_cfg_indices.push(new_idx); + }, + Err(e) => tracing::warn!("Image overlay decode failed: {e}"), + } + } + *image_overlays = Arc::from(new_image_overlays); + *image_overlay_cfg_indices = new_cfg_indices; + + // Re-rasterize text overlays. 
+ let new_text_overlays: Vec> = new_config + .text_overlays + .iter() + .map(|txt_cfg| Arc::new(rasterize_text_overlay(txt_cfg))) + .collect(); + *text_overlays = Arc::from(new_text_overlays); + + *config = new_config; + }, + Err(e) => { + tracing::warn!("Rejected invalid compositor config: {e}"); + stats_tracker.errored(); + }, + }, + Err(e) => { + tracing::warn!("Failed to deserialize compositor UpdateParams: {e}"); + stats_tracker.errored(); + }, + } + } + + fn handle_pin_management( + node: &mut Box, + msg: PinManagementMessage, + slots: &mut Vec, + ) { + match msg { + PinManagementMessage::RequestAddInputPin { suggested_name, response_tx } => { + let pin_name = suggested_name.unwrap_or_else(|| { + let name = format!("in_{}", node.next_input_id); + node.next_input_id += 1; + name + }); + let pin = Self::make_input_pin(pin_name); + node.input_pins.push(pin.clone()); + let _ = response_tx.send(Ok(pin)); + }, + PinManagementMessage::AddedInputPin { pin, channel } => { + tracing::info!("CompositorNode: activated input pin '{}'", pin.name); + slots.push(InputSlot { name: pin.name, rx: channel, latest_frame: None }); + }, + PinManagementMessage::RemoveInputPin { pin_name } => { + tracing::info!("CompositorNode: removed input pin '{}'", pin_name); + slots.retain(|s| s.name != pin_name); + node.input_pins.retain(|p| p.name != pin_name); + }, + _ => {}, + } + } +} + +// ── Registration ──────────────────────────────────────────────────────────── + +#[allow(clippy::expect_used, clippy::missing_panics_doc)] +pub fn register_compositor_nodes(registry: &mut NodeRegistry) { + let (def_inputs, def_outputs) = CompositorNode::definition_pins(); + + registry.register_static_with_description( + "video::compositor", + |params| { + let config: CompositorConfig = config_helpers::parse_config_optional(params)?; + if let Err(e) = config.validate() { + return Err(StreamKitError::Configuration(e)); + } + Ok(Box::new(CompositorNode::new(config))) + }, + 
serde_json::to_value(schema_for!(CompositorConfig)) + .expect("CompositorConfig schema should serialize to JSON"), + StaticPins { inputs: def_inputs, outputs: def_outputs }, + vec!["video".to_string(), "compositing".to_string()], + false, + "Composites multiple raw video inputs (RGBA8) onto a single canvas with \ + image and text overlays. Supports dynamic pin creation for attaching \ + arbitrary inputs at runtime.", + ); +} + +// ── Tests ─────────────────────────────────────────────────────────────────── + +#[cfg(test)] +#[allow( + clippy::unwrap_used, + clippy::expect_used, + clippy::cast_possible_truncation, + clippy::cast_sign_loss +)] +mod tests { + use super::*; + use crate::test_utils::{ + assert_state_initializing, assert_state_running, assert_state_stopped, create_test_context, + }; + use config::{LayerConfig, Rect}; + use pixel_ops::{scale_blit_rgba, scale_blit_rgba_rotated}; + use std::collections::HashMap; + use tokio::sync::mpsc; + + /// Create a solid-colour RGBA8 VideoFrame. + fn make_rgba_frame(width: u32, height: u32, r: u8, g: u8, b: u8, a: u8) -> VideoFrame { + let total = (width as usize) * (height as usize) * 4; + let mut data = vec![0u8; total]; + for pixel in data.chunks_exact_mut(4) { + pixel[0] = r; + pixel[1] = g; + pixel[2] = b; + pixel[3] = a; + } + VideoFrame::new(width, height, PixelFormat::Rgba8, data).unwrap() + } + + // ── Unit tests for compositing helpers ─────────────────────────────── + + #[test] + fn test_scale_blit_identity() { + // 2x2 red source blitted onto a 4x4 canvas at (1,1) 2x2 rect. + let src = vec![255, 0, 0, 255, 0, 255, 0, 255, 0, 0, 255, 255, 128, 128, 128, 255]; + let mut dst = vec![0u8; 4 * 4 * 4]; // 4x4 RGBA, all transparent black + + scale_blit_rgba( + &mut dst, + 4, + 4, + &src, + 2, + 2, + &Rect { x: 1, y: 1, width: 2, height: 2 }, + 1.0, + false, + false, + false, + ); + + // Pixel at (1,1) should be red. 
+ let x = 1usize; + let y = 1usize; + let idx = (y * 4 + x) * 4; + assert_eq!(dst[idx], 255); + assert_eq!(dst[idx + 1], 0); + assert_eq!(dst[idx + 2], 0); + assert_eq!(dst[idx + 3], 255); + + // Pixel at (0,0) should remain transparent black. + assert_eq!(dst[0], 0); + assert_eq!(dst[3], 0); + } + + #[test] + fn test_scale_blit_with_opacity() { + // White source at 50% opacity over black background. + let src = vec![255, 255, 255, 255]; // 1x1 white + let mut dst = vec![0, 0, 0, 255, 0, 0, 0, 255, 0, 0, 0, 255, 0, 0, 0, 255]; // 2x2 black + + scale_blit_rgba( + &mut dst, + 2, + 2, + &src, + 1, + 1, + &Rect { x: 0, y: 0, width: 1, height: 1 }, + 0.5, + false, + false, + false, + ); + + // Pixel (0,0): white at 50% over opaque black -> ~128 grey. + let r = dst[0]; + assert!(r > 120 && r < 135, "Expected ~128, got {r}"); + } + + #[test] + fn test_scale_blit_scaling() { + // 1x1 red source scaled to 4x4 rect on an 8x8 canvas. + let src = vec![255, 0, 0, 255]; + let mut dst = vec![0u8; 8 * 8 * 4]; + + scale_blit_rgba( + &mut dst, + 8, + 8, + &src, + 1, + 1, + &Rect { x: 2, y: 2, width: 4, height: 4 }, + 1.0, + false, + false, + false, + ); + + // All pixels in the 4x4 destination rect should be red. + for y in 2..6u32 { + for x in 2..6u32 { + let idx = ((y * 8 + x) * 4) as usize; + assert_eq!(dst[idx], 255, "Red at ({x},{y})"); + assert_eq!(dst[idx + 1], 0, "Green at ({x},{y})"); + } + } + // Outside should remain black. + assert_eq!(dst[0], 0); + } + + #[test] + fn test_rotated_blit_stretch_to_fill() { + // A wide 4×2 red source blitted into a square 20×20 rect with 45° + // rotation on a 40×40 canvas. + // + // The source is stretched to fill the 20×20 rect (no aspect-ratio + // fit), then rotated 45°. The centre of the rect (canvas pixel + // 20,20) should be covered by red source pixels, while the rect + // corner (10,10) — outside the rotated area — should remain + // transparent. 
+ let src = [255u8, 0, 0, 255].repeat(4 * 2); // 4×2 solid red + let mut dst = vec![0u8; 40 * 40 * 4]; + + scale_blit_rgba_rotated( + &mut dst, + 40, + 40, + &src, + 4, + 2, + &Rect { x: 10, y: 10, width: 20, height: 20 }, + 1.0, + 45.0, + false, + false, + false, + ); + + // The centre of the rect (canvas pixel 20,20) should be covered + // by source content (red). + let cx = 20usize; + let cy = 20usize; + let idx = (cy * 40 + cx) * 4; + assert_eq!(dst[idx], 255, "Centre R"); + assert_eq!(dst[idx + 1], 0, "Centre G"); + assert_eq!(dst[idx + 2], 0, "Centre B"); + assert!(dst[idx + 3] > 200, "Centre A should be mostly opaque"); + + // The rect corner (10,10) is outside the rotated content area + // and should remain transparent. + let corner_idx = (10usize * 40 + 10) * 4; + assert_eq!(dst[corner_idx + 3], 0, "Rect corner should be transparent"); + } + + #[test] + fn test_composite_frame_empty_layers() { + // No layers, no overlays -> transparent black canvas. + let mut cache = ConversionCache::new(); + let result = composite_frame(4, 4, &[], &[], &[], None, &mut cache); + let buf = result.as_slice(); + assert_eq!(buf.len(), 4 * 4 * 4); + assert!(buf.iter().all(|&b| b == 0)); + } + + #[test] + fn test_composite_frame_single_layer() { + let data = make_rgba_frame(2, 2, 255, 0, 0, 255); + let layer = LayerSnapshot { + data: data.data, + width: 2, + height: 2, + pixel_format: PixelFormat::Rgba8, + rect: Some(Rect { x: 0, y: 0, width: 4, height: 4 }), + opacity: 1.0, + z_index: 0, + rotation_degrees: 0.0, + mirror_horizontal: false, + mirror_vertical: false, + }; + + let mut cache = ConversionCache::new(); + let result = composite_frame(4, 4, &[Some(layer)], &[], &[], None, &mut cache); + let buf = result.as_slice(); + + // Entire canvas should be red (scaled from 2x2 to 4x4). 
+ for pixel in buf.chunks_exact(4) { + assert_eq!(pixel[0], 255, "Red channel"); + assert_eq!(pixel[1], 0, "Green channel"); + assert_eq!(pixel[2], 0, "Blue channel"); + assert_eq!(pixel[3], 255, "Alpha channel"); + } + } + + #[test] + fn test_composite_frame_two_layers() { + // Bottom: full-canvas red. Top: small green square at (1,1) 2x2. + let red = make_rgba_frame(4, 4, 255, 0, 0, 255); + let green = make_rgba_frame(2, 2, 0, 255, 0, 255); + + let layer0 = LayerSnapshot { + data: red.data, + width: 4, + height: 4, + pixel_format: PixelFormat::Rgba8, + rect: None, + opacity: 1.0, + z_index: 0, + rotation_degrees: 0.0, + mirror_horizontal: false, + mirror_vertical: false, + }; + let layer1 = LayerSnapshot { + data: green.data, + width: 2, + height: 2, + pixel_format: PixelFormat::Rgba8, + rect: Some(Rect { x: 1, y: 1, width: 2, height: 2 }), + opacity: 1.0, + z_index: 1, + rotation_degrees: 0.0, + mirror_horizontal: false, + mirror_vertical: false, + }; + + let mut cache = ConversionCache::new(); + let result = + composite_frame(4, 4, &[Some(layer0), Some(layer1)], &[], &[], None, &mut cache); + let buf = result.as_slice(); + + // (0,0) should be red. + assert_eq!(buf[0], 255); + assert_eq!(buf[1], 0); + + // (1,1) should be green (overwritten by top layer). 
+ let x = 1usize; + let y = 1usize; + let idx = (y * 4 + x) * 4; + assert_eq!(buf[idx], 0); + assert_eq!(buf[idx + 1], 255); + assert_eq!(buf[idx + 2], 0); + } + + #[test] + fn test_rasterize_text_overlay_produces_pixels() { + let cfg = config::TextOverlayConfig { + text: "Hi".to_string(), + transform: config::OverlayTransform { + rect: Rect { x: 0, y: 0, width: 64, height: 32 }, + opacity: 1.0, + rotation_degrees: 0.0, + z_index: 0, + mirror_horizontal: false, + mirror_vertical: false, + }, + color: [255, 255, 0, 255], + font_size: 24, + font_path: None, + font_data_base64: None, + font_name: None, + }; + let overlay = rasterize_text_overlay(&cfg); + // Width and height should be at least the original rect dimensions. + assert!(overlay.width >= 64); + assert!(overlay.height >= 32); + // The rect in the returned overlay should match the bitmap dimensions. + assert_eq!(overlay.rect.width, overlay.width); + assert_eq!(overlay.rect.height, overlay.height); + // Should have some non-zero pixels (text was drawn). 
+ assert!(overlay.rgba_data.iter().any(|&b| b > 0)); + } + + #[test] + fn test_fit_rect_preserving_aspect() { + // 4:3 source into 16:9 bounds → pillarboxed (width-limited) + let bounds = Rect { x: 100, y: 50, width: 426, height: 240 }; + let fitted = fit_rect_preserving_aspect(640, 480, &bounds); + // Scale = min(426/640, 240/480) = min(0.666, 0.5) = 0.5 + // Fitted: 320×240, centred within 426×240 + assert_eq!(fitted.width, 320); + assert_eq!(fitted.height, 240); + assert_eq!(fitted.x, 100 + (426 - 320) / 2); + assert_eq!(fitted.y, 50); + + // 16:9 source into 4:3 bounds → letterboxed (height-limited) + let bounds = Rect { x: 0, y: 0, width: 400, height: 400 }; + let fitted = fit_rect_preserving_aspect(1280, 720, &bounds); + // Scale = min(400/1280, 400/720) = min(0.3125, 0.555) = 0.3125 + // Fitted: 400×225, centred within 400×400 + assert_eq!(fitted.width, 400); + assert_eq!(fitted.height, 225); + assert_eq!(fitted.x, 0); + assert_eq!(fitted.y, (400 - 225) / 2); + + // Exact match → no change + let bounds = Rect { x: 10, y: 20, width: 640, height: 480 }; + let fitted = fit_rect_preserving_aspect(640, 480, &bounds); + assert_eq!(fitted.width, 640); + assert_eq!(fitted.height, 480); + assert_eq!(fitted.x, 10); + assert_eq!(fitted.y, 20); + } + + #[test] + fn test_config_validate_ok() { + let cfg = CompositorConfig::default(); + assert!(cfg.validate().is_ok()); + } + + #[test] + fn test_config_validate_zero_dimensions() { + let cfg = CompositorConfig { width: 0, height: 720, ..Default::default() }; + assert!(cfg.validate().is_err()); + } + + #[test] + fn test_config_validate_bad_opacity() { + let mut cfg = CompositorConfig::default(); + cfg.layers.insert("in_0".to_string(), LayerConfig { opacity: 1.5, ..Default::default() }); + assert!(cfg.validate().is_err()); + } + + // ── Integration test: node run() with mock context ────────────────── + + #[tokio::test] + async fn test_compositor_node_run_main_only() { + let (input_tx, input_rx) = mpsc::channel(10); + let mut 
inputs = HashMap::new(); + inputs.insert("in_0".to_string(), input_rx); + + let (context, mock_sender, mut state_rx) = create_test_context(inputs, 10); + + let config = CompositorConfig { width: 4, height: 4, ..Default::default() }; + let node = CompositorNode::new(config); + + let node_handle = tokio::spawn(async move { Box::new(node).run(context).await }); + + assert_state_initializing(&mut state_rx).await; + assert_state_running(&mut state_rx).await; + + // Send a red frame. + let frame = make_rgba_frame(2, 2, 255, 0, 0, 255); + input_tx.send(Packet::Video(frame)).await.unwrap(); + + // Give time for processing. + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Close input. + drop(input_tx); + + assert_state_stopped(&mut state_rx).await; + node_handle.await.unwrap().unwrap(); + + let output_packets = mock_sender.get_packets_for_pin("out").await; + assert!(!output_packets.is_empty(), "Expected at least 1 output frame"); + + // Verify output is 4x4 RGBA. + if let Packet::Video(ref out_frame) = output_packets[0] { + assert_eq!(out_frame.width, 4); + assert_eq!(out_frame.height, 4); + assert_eq!(out_frame.pixel_format, PixelFormat::Rgba8); + // Should be red (2x2 scaled to fill 4x4). 
+ assert_eq!(out_frame.data()[0], 255); // R + assert_eq!(out_frame.data()[1], 0); // G + } else { + panic!("Expected video packet"); + } + } + + #[tokio::test] + async fn test_compositor_node_preserves_metadata() { + let (input_tx, input_rx) = mpsc::channel(10); + let mut inputs = HashMap::new(); + inputs.insert("in_0".to_string(), input_rx); + + let (context, mock_sender, mut state_rx) = create_test_context(inputs, 10); + + let config = CompositorConfig { width: 2, height: 2, ..Default::default() }; + let node = CompositorNode::new(config); + + let node_handle = tokio::spawn(async move { Box::new(node).run(context).await }); + + assert_state_initializing(&mut state_rx).await; + assert_state_running(&mut state_rx).await; + + let mut frame = make_rgba_frame(2, 2, 100, 100, 100, 255); + frame.metadata = Some(PacketMetadata { + timestamp_us: Some(42_000), + duration_us: Some(33_333), + sequence: Some(7), + keyframe: Some(true), + }); + input_tx.send(Packet::Video(frame)).await.unwrap(); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + drop(input_tx); + + assert_state_stopped(&mut state_rx).await; + node_handle.await.unwrap().unwrap(); + + let output_packets = mock_sender.get_packets_for_pin("out").await; + assert!(!output_packets.is_empty()); + + if let Packet::Video(ref out_frame) = output_packets[0] { + let meta = out_frame.metadata.as_ref().expect("metadata should be preserved"); + assert_eq!(meta.timestamp_us, Some(42_000)); + assert_eq!(meta.duration_us, Some(33_333)); + assert_eq!(meta.sequence, Some(0)); // output sequence starts at 0 + } else { + panic!("Expected video packet"); + } + } + + #[test] + fn test_compositor_definition_pins() { + let (inputs, outputs) = CompositorNode::definition_pins(); + assert_eq!(inputs.len(), 1); + assert_eq!(inputs[0].name, "in"); + assert!(matches!(inputs[0].cardinality, PinCardinality::Dynamic { .. 
})); + assert_eq!(outputs.len(), 1); + assert_eq!(outputs[0].name, "out"); + } + + #[test] + fn test_compositor_pool_usage() { + use streamkit_core::frame_pool::FramePool; + + let canvas_w = 4u32; + let canvas_h = 4u32; + let total = (canvas_w as usize) * (canvas_h as usize) * 4; // 64 bytes + + let pool = FramePool::::preallocated(&[total], 2); + assert_eq!(pool.stats().buckets[0].available, 2); + + let mut cache = ConversionCache::new(); + let result = composite_frame(canvas_w, canvas_h, &[], &[], &[], Some(&pool), &mut cache); + assert_eq!(result.as_slice().len(), total); + // One buffer was taken from the pool. + assert_eq!(pool.stats().buckets[0].available, 1); + + // Drop returns to pool. + drop(result); + assert_eq!(pool.stats().buckets[0].available, 2); + } + + // ── SIMD vs scalar equivalence tests ──────────────────────────────── + + /// Helper: scalar I420→RGBA8 conversion for a single pixel (reference). + #[allow(clippy::many_single_char_names)] + fn scalar_i420_to_rgba8(y: u8, u: u8, v: u8) -> [u8; 4] { + let c = i32::from(y) - 16; + let d = i32::from(u) - 128; + let e = i32::from(v) - 128; + let r = ((298 * c + 409 * e + 128) >> 8).clamp(0, 255) as u8; + let g = ((298 * c - 100 * d - 208 * e + 128) >> 8).clamp(0, 255) as u8; + let b = ((298 * c + 516 * d + 128) >> 8).clamp(0, 255) as u8; + [r, g, b, 255] + } + + /// Helper: scalar RGBA8→Y for a single pixel (reference). + fn scalar_rgba8_to_y(r: u8, g: u8, b: u8) -> u8 { + let y = ((66 * i32::from(r) + 129 * i32::from(g) + 25 * i32::from(b) + 128) >> 8) + 16; + y.clamp(0, 255) as u8 + } + + #[test] + fn test_i420_to_rgba8_simd_matches_scalar() { + // Test a variety of YUV values, including edge cases that trigger + // i16 overflow with the BT.601 coefficients. 
+ let test_cases: Vec<(u8, u8, u8)> = vec![ + (16, 128, 128), // black + (235, 128, 128), // white + (81, 90, 240), // pure red + (145, 54, 34), // pure green + (41, 240, 110), // pure blue + (255, 128, 128), // max Y + (0, 0, 0), // min everything + (255, 255, 255), // max everything + (16, 0, 255), // extreme chroma + (235, 255, 0), // extreme chroma + ]; + + let width = test_cases.len() as u32; + // Build I420 buffer. + let mut y_plane = Vec::new(); + let mut u_plane = Vec::new(); + let mut v_plane = Vec::new(); + for &(y, u, v) in &test_cases { + y_plane.push(y); + // Each chroma sample covers 2 luma pixels horizontally. + if y_plane.len() % 2 == 1 { + u_plane.push(u); + v_plane.push(v); + } + } + let chroma_w = (width as usize).div_ceil(2); + // Pad if needed. + while u_plane.len() < chroma_w { + u_plane.push(128); + v_plane.push(128); + } + + let mut i420_data = Vec::new(); + i420_data.extend_from_slice(&y_plane); + i420_data.extend_from_slice(&u_plane); + i420_data.extend_from_slice(&v_plane); + + // Convert using the public function (which uses SIMD internally). + let mut simd_out = vec![0u8; width as usize * 4]; + pixel_ops::i420_to_rgba8_buf(&i420_data, width, 1, &mut simd_out); + + // Compare with scalar reference. + for (i, &(y, _u, _v)) in test_cases.iter().enumerate() { + // For chroma, each sample covers 2 pixels, so use the chroma + // value from the corresponding pair. + let chroma_idx = i / 2; + let actual_u = u_plane[chroma_idx]; + let actual_v = v_plane[chroma_idx]; + let expected = scalar_i420_to_rgba8(y, actual_u, actual_v); + let got = &simd_out[i * 4..(i + 1) * 4]; + assert_eq!( + got, &expected, + "pixel {i}: Y={y} U={actual_u} V={actual_v} → expected {expected:?}, got {got:?}" + ); + } + } + + #[test] + fn test_rgba8_to_i420_simd_matches_scalar() { + // Test RGBA→Y conversion with values that trigger i16 overflow + // (129 * 255 = 32895 > i16::MAX). 
+ let test_pixels: Vec<(u8, u8, u8)> = vec![ + (0, 0, 0), // black + (255, 255, 255), // white + (255, 0, 0), // red + (0, 255, 0), // green + (0, 0, 255), // blue + (128, 128, 128), // mid grey + (0, 254, 0), // just below overflow threshold + (0, 255, 0), // at overflow threshold + ]; + + let width = test_pixels.len() as u32; + let mut rgba_data = Vec::with_capacity(width as usize * 4); + for &(r, g, b) in &test_pixels { + rgba_data.extend_from_slice(&[r, g, b, 255]); + } + + // Convert using the public function (SIMD internally). + let i420_size = width as usize + 2 * (width as usize).div_ceil(2); + let mut i420_out = vec![0u8; i420_size]; + pixel_ops::rgba8_to_i420_buf(&rgba_data, width, 1, &mut i420_out); + + // Check Y plane matches scalar. + for (i, &(r, g, b)) in test_pixels.iter().enumerate() { + let expected_y = scalar_rgba8_to_y(r, g, b); + let got_y = i420_out[i]; + assert_eq!( + got_y, expected_y, + "pixel {i}: R={r} G={g} B={b} → expected Y={expected_y}, got Y={got_y}" + ); + } + } + + #[test] + fn test_i420_rgba8_roundtrip_preserves_values() { + // A full I420→RGBA8→I420 round-trip should produce values close + // to the originals (within ±2 due to integer rounding). + let width: u32 = 8; + let height: u32 = 2; + let w = width as usize; + let h = height as usize; + let chroma_w = w.div_ceil(2); + + // Build a simple I420 test pattern. + let mut i420_data = vec![0u8; w * h + 2 * chroma_w * (h / 2)]; + // Y plane: gradient. + for (i, val) in i420_data[..w * h].iter_mut().enumerate() { + *val = (16 + (i * 219 / (w * h))) as u8; + } + // U/V planes: mid-range. 
+ let u_offset = w * h; + let v_offset = u_offset + chroma_w * (h / 2); + for i in 0..chroma_w * (h / 2) { + i420_data[u_offset + i] = 128; + i420_data[v_offset + i] = 128; + } + + // I420 → RGBA8 → I420 + let mut rgba = vec![0u8; w * h * 4]; + pixel_ops::i420_to_rgba8_buf(&i420_data, width, height, &mut rgba); + let mut i420_roundtrip = vec![0u8; i420_data.len()]; + pixel_ops::rgba8_to_i420_buf(&rgba, width, height, &mut i420_roundtrip); + + // Y values should be close (within ±2 of originals due to rounding). + for (idx, orig_val) in i420_data[..w * h].iter().enumerate() { + let orig = i32::from(*orig_val); + let rt = i32::from(i420_roundtrip[idx]); + assert!( + (orig - rt).abs() <= 2, + "Y[{idx}]: original={orig}, roundtrip={rt}, diff={}", + (orig - rt).abs() + ); + } + } + + /// Test that `scale_blit_rgba` with opacity < 1.0 writes all rows correctly + /// on a buffer wide enough to exercise the AVX2 blend path (32 pixels). + /// This verifies the AVX2 → SSE2 → scalar cascade in `blit_row_alpha`. + #[test] + fn test_scale_blit_opacity_all_rows_written() { + let w = 32usize; + let h = 32usize; + // Fully opaque red source. + let src: Vec = [200, 50, 30, 255].repeat(w * h); + // All-black destination (simulates cleared canvas). + let mut dst = vec![0u8; w * h * 4]; + + scale_blit_rgba( + &mut dst, + w as u32, + h as u32, + &src, + w as u32, + h as u32, + &Rect { x: 0, y: 0, width: w as u32, height: h as u32 }, + 0.9, + false, + false, + false, + ); + + // Every single row should have been written to (non-zero pixels). + for row in 0..h { + let row_start = row * w * 4; + let row_slice = &dst[row_start..row_start + w * 4]; + let any_written = row_slice.iter().any(|&b| b != 0); + assert!(any_written, "Row {row} was not written to (all zeros)"); + + // Verify each pixel matches the expected scalar blend. + // opacity_u16 = (0.9 * 255 + 0.5) as u16 = 230 + // sa_eff = (255 * 230 + 128) >> 8 = 229 + // Dst is black (0), so blended = src * sa_eff / 255. 
+ let opacity_u16: u16 = 230; + let sa_eff = ((255u16 * opacity_u16 + 128) >> 8).min(255); + let expected_r = { + let blend = 200u16 * sa_eff + 128; + ((blend + (blend >> 8)) >> 8) as u8 + }; + let expected_g = { + let blend = 50u16 * sa_eff + 128; + ((blend + (blend >> 8)) >> 8) as u8 + }; + let expected_b = { + let blend = 30u16 * sa_eff + 128; + ((blend + (blend >> 8)) >> 8) as u8 + }; + for col in 0..w { + let idx = row_start + col * 4; + let got_r = dst[idx]; + let got_g = dst[idx + 1]; + let got_b = dst[idx + 2]; + let got_a = dst[idx + 3]; + + // Allow ±1 for rounding differences between SIMD and scalar paths. + assert!( + (i16::from(got_r) - i16::from(expected_r)).abs() <= 1, + "Row {row}, Col {col}: R={got_r}, expected ~{expected_r}" + ); + assert!( + (i16::from(got_g) - i16::from(expected_g)).abs() <= 1, + "Row {row}, Col {col}: G={got_g}, expected ~{expected_g}" + ); + assert!( + (i16::from(got_b) - i16::from(expected_b)).abs() <= 1, + "Row {row}, Col {col}: B={got_b}, expected ~{expected_b}" + ); + assert!(got_a > 200, "Row {row}, Col {col}: A={got_a}, expected >200"); + } + } + } + + /// Test I420→RGBA8 AVX2 kernel correctness with a multi-row buffer wide + /// enough to exercise the 8-pixel AVX2 path plus scalar remainder. + /// Verifies the OOB-safe scalar chroma reads produce identical output to + /// the scalar reference for every pixel. + #[test] + fn test_i420_to_rgba8_avx2_wide_multirow() { + // 24 pixels wide = 3 AVX2 iterations (8px each) with 0 remainder. + // 4 rows to exercise multi-row chroma subsampling. + let width: u32 = 24; + let height: u32 = 4; + let w = width as usize; + let h = height as usize; + let chroma_w = w / 2; + + // Build a varied I420 test pattern. + let mut i420_data = vec![0u8; w * h + 2 * chroma_w * (h / 2)]; + // Y plane: gradient across rows and columns. + for row in 0..h { + for col in 0..w { + i420_data[row * w + col] = (16 + ((row * w + col) * 219) / (w * h)) as u8; + } + } + // U/V planes: varying chroma values. 
+ let u_offset = w * h; + let v_offset = u_offset + chroma_w * (h / 2); + for i in 0..chroma_w * (h / 2) { + i420_data[u_offset + i] = (64 + (i * 3) % 192) as u8; + i420_data[v_offset + i] = (32 + (i * 7) % 224) as u8; + } + + // Convert using the public function (dispatches to AVX2 on this machine). + let mut simd_out = vec![0u8; w * h * 4]; + pixel_ops::i420_to_rgba8_buf(&i420_data, width, height, &mut simd_out); + + // Compare every pixel against the scalar reference. + for row in 0..h { + for col in 0..w { + let luma = i420_data[row * w + col]; + let chroma_r = row / 2; + let chroma_c = col / 2; + let u_val = i420_data[u_offset + chroma_r * chroma_w + chroma_c]; + let v_val = i420_data[v_offset + chroma_r * chroma_w + chroma_c]; + let expected = scalar_i420_to_rgba8(luma, u_val, v_val); + let got_idx = (row * w + col) * 4; + let got = &simd_out[got_idx..got_idx + 4]; + assert_eq!( + got, &expected, + "row={row} col={col}: Y={luma} U={u_val} V={v_val} → expected {expected:?}, got {got:?}" + ); + } + } + } + + /// Test that opacity < 1.0 through `composite_frame` produces correct + /// output with no black borders when source matches canvas dimensions. + #[test] + fn test_composite_frame_opacity_no_black_borders() { + let w = 32u32; + let h = 32u32; + let frame = make_rgba_frame(w, h, 200, 100, 50, 255); + + let layer = LayerSnapshot { + data: frame.data, + width: w, + height: h, + pixel_format: PixelFormat::Rgba8, + rect: Some(Rect { x: 0, y: 0, width: w, height: h }), + opacity: 0.8, + z_index: 0, + rotation_degrees: 0.0, + mirror_horizontal: false, + mirror_vertical: false, + }; + + let mut cache = ConversionCache::new(); + let result = composite_frame(w, h, &[Some(layer)], &[], &[], None, &mut cache); + let buf = result.as_slice(); + + // Every row should have non-zero content (no black borders). 
+ for row in 0..h as usize { + let row_start = row * w as usize * 4; + let row_end = row_start + w as usize * 4; + let any_nonzero = buf[row_start..row_end].iter().any(|&b| b != 0); + assert!(any_nonzero, "Row {row} is all zeros — black border detected"); + } + } + + /// Full-pipeline test at real dimensions (640×480): compositor blit with + /// opacity < 1.0, then RGBA→NV12→RGBA roundtrip, checking for black bands. + /// This exercises the exact pipeline the VP9 encoder sees. + #[test] + #[allow(clippy::many_single_char_names)] // Standard image-processing shorthand (w, h, r, g, b, etc.) + fn test_full_pipeline_opacity_nv12_roundtrip_no_black_bands() { + let w = 640u32; + let h = 480u32; + let wu = w as usize; + let hu = h as usize; + + // Create a colorbars-like pattern: 7 vertical bars of different colors. + let colors: [(u8, u8, u8); 7] = [ + (255, 255, 255), // white + (255, 255, 0), // yellow + (0, 255, 255), // cyan + (0, 255, 0), // green + (255, 0, 255), // magenta + (255, 0, 0), // red + (0, 0, 255), // blue + ]; + let mut src_rgba = vec![0u8; wu * hu * 4]; + for row in 0..hu { + for col in 0..wu { + let bar_idx = (col * 7) / wu; + let (r, g, b) = colors[bar_idx]; + let off = (row * wu + col) * 4; + src_rgba[off] = r; + src_rgba[off + 1] = g; + src_rgba[off + 2] = b; + src_rgba[off + 3] = 255; + } + } + + // Step 1: Blit onto canvas with opacity 0.9 (through scale_blit_rgba_rotated, + // exactly as the compositor does). + let mut canvas = vec![0u8; wu * hu * 4]; + pixel_ops::scale_blit_rgba_rotated( + &mut canvas, + w, + h, + &src_rgba, + w, + h, + &Rect { x: 0, y: 0, width: w, height: h }, + 0.9, + 0.0, + false, + false, + false, + ); + + // Verify compositor output: every row should have non-zero pixels. 
+ for row in 0..hu { + let row_start = row * wu * 4; + let any_nonzero = canvas[row_start..row_start + wu * 4].iter().any(|&b| b != 0); + assert!(any_nonzero, "Compositor output row {row} is all zeros (black band)"); + } + + // Step 2: Convert RGBA → NV12 (exactly as the VP9 encoder does). + let chroma_w = wu.div_ceil(2); + let chroma_h = hu.div_ceil(2); + let nv12_size = wu * hu + chroma_w * 2 * chroma_h; + let mut nv12 = vec![0u8; nv12_size]; + pixel_ops::rgba8_to_nv12_buf(&canvas, w, h, &mut nv12); + + // Verify Y plane: no rows should be all-zero (Y=0 is below black level). + // With opacity 0.9 on colored bars, Y values should be well above 0. + for row in 0..hu { + let y_row = &nv12[row * wu..(row + 1) * wu]; + let max_y = *y_row.iter().max().unwrap(); + assert!(max_y > 16, "NV12 Y-plane row {row}: max Y={max_y}, expected >16 (not black)"); + } + + // Step 3: Convert NV12 → RGBA (simulates decoder display). + let mut decoded_rgba = vec![0u8; wu * hu * 4]; + pixel_ops::nv12_to_rgba8_buf(&nv12, w, h, &mut decoded_rgba); + + // Verify decoded output: every row should have non-black pixels. + for row in 0..hu { + let row_start = row * wu * 4; + let row_slice = &decoded_rgba[row_start..row_start + wu * 4]; + // Check that at least some pixels have R, G, or B > 10 (not near-black). + let has_visible = + row_slice.chunks_exact(4).any(|px| px[0] > 10 || px[1] > 10 || px[2] > 10); + assert!(has_visible, "Decoded row {row} has no visible pixels (all near-black)"); + } + } + + /// Regression test: a 4:3 source blitted onto a 16:9 canvas with opacity < 1.0 + /// must cover the entire canvas (stretch-to-fill) with no black bars. + /// Previously the near-zero rotation fast path applied an aspect-ratio-preserving + /// fit that left letterbox gaps visible as black bands when opacity < 1.0. 
+    #[test]
+    fn test_mismatched_aspect_ratio_opacity_no_black_bars() {
+        let src_w = 640u32;
+        let src_h = 480u32; // 4:3
+        let canvas_w = 1280u32;
+        let canvas_h = 720u32; // 16:9
+
+        // Solid green source.
+        let src = [0u8, 255, 0, 255].repeat((src_w * src_h) as usize);
+        let mut canvas = vec![0u8; (canvas_w * canvas_h * 4) as usize];
+
+        pixel_ops::scale_blit_rgba_rotated(
+            &mut canvas,
+            canvas_w,
+            canvas_h,
+            &src,
+            src_w,
+            src_h,
+            &Rect { x: 0, y: 0, width: canvas_w, height: canvas_h },
+            0.9,
+            0.0, // no rotation — exercises the near-zero fast path
+            false,
+            false,
+            false,
+        );
+
+        // Every row should have non-zero pixels (no horizontal black bars on top/bottom).
+        for row in 0..canvas_h as usize {
+            let row_start = row * canvas_w as usize * 4;
+            let row_end = row_start + canvas_w as usize * 4;
+            let any_nonzero = canvas[row_start..row_end].iter().any(|&b| b != 0);
+            assert!(any_nonzero, "Row {row} is all zeros — black bar detected");
+        }
+
+        // Every column should have non-zero pixels (no vertical black bars on left/right).
+        for col in 0..canvas_w as usize {
+            let any_nonzero = (0..canvas_h as usize).any(|row| {
+                let idx = (row * canvas_w as usize + col) * 4;
+                canvas[idx] != 0 || canvas[idx + 1] != 0 || canvas[idx + 2] != 0
+            });
+            assert!(any_nonzero, "Column {col} is all zeros — black bar detected");
+        }
+    }
+
+    /// Regression test: a 4:3 source blitted into a non-square rect with 15°
+    /// rotation must cover the centre of the rect (stretch-to-fill, not
+    /// aspect-ratio fit). Exercises the rotated path's per-axis inverse
+    /// scaling (`inv_scale_x` / `inv_scale_y`).
+    #[test]
+    fn test_rotated_blit_mismatched_aspect_ratio_covers_centre() {
+        // 4×2 red source into a 40×20 rect (2:1 aspect mismatch) at 15° on
+        // a 60×40 canvas. The centre of the rect (canvas pixel 30,20) must
+        // be covered by red source content.
+ let src = [255u8, 0, 0, 255].repeat(4 * 2); // 4×2 solid red + let mut dst = vec![0u8; 60 * 40 * 4]; + + scale_blit_rgba_rotated( + &mut dst, + 60, + 40, + &src, + 4, + 2, + &Rect { x: 10, y: 10, width: 40, height: 20 }, + 1.0, + 15.0, + false, + false, + false, + ); + + // Centre of the rect (canvas pixel 30, 20) should be red. + let cx = 30usize; + let cy = 20usize; + let idx = (cy * 60 + cx) * 4; + assert_eq!(dst[idx], 255, "Centre R"); + assert_eq!(dst[idx + 1], 0, "Centre G"); + assert_eq!(dst[idx + 2], 0, "Centre B"); + assert!(dst[idx + 3] > 200, "Centre A should be mostly opaque"); + } + + /// Test RGBA→NV12 AVX2 chroma conversion matches scalar reference. + /// Uses a 640-wide frame to fully exercise the AVX2 path (8 chroma samples/iter). + #[test] + #[allow(clippy::many_single_char_names)] // Standard image-processing shorthand (w, h, r, g, b, etc.) + fn test_rgba8_to_nv12_avx2_chroma_matches_scalar() { + let w = 640u32; + let h = 4u32; + let wu = w as usize; + let hu = h as usize; + let chroma_w = wu / 2; + let chroma_h = hu / 2; + + // Create a varied RGBA pattern. + let mut rgba = vec![0u8; wu * hu * 4]; + for row in 0..hu { + for col in 0..wu { + let off = (row * wu + col) * 4; + rgba[off] = ((col * 3 + row * 7) % 256) as u8; // R + rgba[off + 1] = ((col * 5 + row * 11) % 256) as u8; // G + rgba[off + 2] = ((col * 7 + row * 13) % 256) as u8; // B + rgba[off + 3] = 255; // A + } + } + + // Convert using the public function (dispatches to AVX2). + let nv12_size = wu * hu + chroma_w * 2 * chroma_h; + let mut nv12_simd = vec![0u8; nv12_size]; + pixel_ops::rgba8_to_nv12_buf(&rgba, w, h, &mut nv12_simd); + + // Compute scalar reference for the chroma plane. 
+ let y_size = wu * hu; + for crow in 0..chroma_h { + let r0 = crow * 2; + for ccol in 0..chroma_w { + let c0 = ccol * 2; + let mut sr = 0i32; + let mut sg = 0i32; + let mut sb = 0i32; + let mut count = 0i32; + for dr in 0..2u32 { + let rr = r0 + dr as usize; + if rr >= hu { + continue; + } + for dc in 0..2u32 { + let cc = c0 + dc as usize; + if cc < wu { + let off = (rr * wu + cc) * 4; + sr += i32::from(rgba[off]); + sg += i32::from(rgba[off + 1]); + sb += i32::from(rgba[off + 2]); + count += 1; + } + } + } + let r_avg = sr / count; + let g_avg = sg / count; + let b_avg = sb / count; + let expected_u = ((-38 * r_avg - 74 * g_avg + 112 * b_avg + 128) >> 8) + 128; + let expected_v = ((112 * r_avg - 94 * g_avg - 18 * b_avg + 128) >> 8) + 128; + let expected_u = expected_u.clamp(0, 255) as u8; + let expected_v = expected_v.clamp(0, 255) as u8; + + let uv_off = y_size + crow * chroma_w * 2 + ccol * 2; + let got_u = nv12_simd[uv_off]; + let got_v = nv12_simd[uv_off + 1]; + + // Allow ±2 for rounding differences between SIMD and scalar. + assert!( + (i16::from(got_u) - i16::from(expected_u)).abs() <= 2, + "crow={crow} ccol={ccol}: U got={got_u}, expected={expected_u}" + ); + assert!( + (i16::from(got_v) - i16::from(expected_v)).abs() <= 2, + "crow={crow} ccol={ccol}: V got={got_v}, expected={expected_v}" + ); + } + } + + // Also verify Y plane matches scalar reference. 
+ for row in 0..hu { + for col in 0..wu { + let off = (row * wu + col) * 4; + let r = i32::from(rgba[off]); + let g = i32::from(rgba[off + 1]); + let b = i32::from(rgba[off + 2]); + let expected_y = + (((66 * r + 129 * g + 25 * b + 128) >> 8) + 16).clamp(0, 255) as u8; + let got_y = nv12_simd[row * wu + col]; + assert!( + (i16::from(got_y) - i16::from(expected_y)).abs() <= 1, + "row={row} col={col}: Y got={got_y}, expected={expected_y}" + ); + } + } + } +} diff --git a/crates/nodes/src/video/compositor/overlay.rs b/crates/nodes/src/video/compositor/overlay.rs new file mode 100644 index 00000000..8049e748 --- /dev/null +++ b/crates/nodes/src/video/compositor/overlay.rs @@ -0,0 +1,373 @@ +// SPDX-FileCopyrightText: © 2025 StreamKit Contributors +// +// SPDX-License-Identifier: MPL-2.0 + +//! Overlay decoding and rasterization for the video compositor. + +use super::config::{ImageOverlayConfig, Rect, TextOverlayConfig}; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; +use std::sync::{Arc, LazyLock, Mutex}; +use streamkit_core::StreamKitError; + +// ── Decoded overlay bitmap ────────────────────────────────────────────────── + +/// A pre-decoded RGBA bitmap overlay ready for per-frame blitting. +#[derive(Clone)] +pub struct DecodedOverlay { + pub rgba_data: Vec, + pub width: u32, + pub height: u32, + pub rect: Rect, + pub opacity: f32, + /// Clockwise rotation in degrees around the rect centre. + pub rotation_degrees: f32, + /// Visual stacking order for unified z-sorting with video layers. + pub z_index: i32, + /// Mirror horizontally (flip left ↔ right). + pub mirror_horizontal: bool, + /// Mirror vertically (flip top ↔ bottom). + pub mirror_vertical: bool, +} + +/// Decode a base64-encoded image (PNG/JPEG) into an RGBA8 bitmap. +/// +/// # Errors +/// +/// Returns an error if the base64 data is invalid or the image cannot be decoded. 
+pub fn decode_image_overlay(config: &ImageOverlayConfig) -> Result { + use image::GenericImageView; + + use base64::Engine; + let bytes = + base64::engine::general_purpose::STANDARD.decode(&config.data_base64).map_err(|e| { + StreamKitError::Configuration(format!("Invalid base64 in image overlay: {e}")) + })?; + + let img = image::load_from_memory(&bytes).map_err(|e| { + StreamKitError::Configuration(format!("Failed to decode image overlay: {e}")) + })?; + + let rgba = img.to_rgba8(); + let (w, h) = img.dimensions(); + + let target_w = config.transform.rect.width; + let target_h = config.transform.rect.height; + + // Pre-scale the decoded image to fit within the target rect while + // preserving the source aspect ratio. This ensures the per-frame + // `scale_blit_rgba_rotated` call hits the identity-scale fast path + // (direct memcpy) and the image is never stretched. + if target_w > 0 && target_h > 0 && (w != target_w || h != target_h) { + // Aspect-ratio-preserving fit: scale so the image fits inside + // the target box without distortion. + #[allow(clippy::cast_precision_loss)] + let scale = { + let sw = w as f32; + let sh = h as f32; + (target_w as f32 / sw).min(target_h as f32 / sh) + }; + #[allow( + clippy::cast_possible_truncation, + clippy::cast_sign_loss, + clippy::cast_precision_loss + )] + let fit_w = ((w as f32 * scale).round() as u32).max(1); + #[allow( + clippy::cast_possible_truncation, + clippy::cast_sign_loss, + clippy::cast_precision_loss + )] + let fit_h = ((h as f32 * scale).round() as u32).max(1); + + let raw = rgba.into_raw(); + let scaled = prescale_rgba(&raw, w, h, fit_w, fit_h); + + // Adjust the rect to match the fitted dimensions so the blit + // stays on the identity-scale path and the image is centred + // within the originally requested area. 
+ let mut rect = config.transform.rect.clone(); + rect.x += (target_w.cast_signed() - fit_w.cast_signed()) / 2; + rect.y += (target_h.cast_signed() - fit_h.cast_signed()) / 2; + rect.width = fit_w; + rect.height = fit_h; + + Ok(DecodedOverlay { + rgba_data: scaled, + width: fit_w, + height: fit_h, + rect, + opacity: config.transform.opacity, + rotation_degrees: config.transform.rotation_degrees, + z_index: config.transform.z_index, + mirror_horizontal: config.transform.mirror_horizontal, + mirror_vertical: config.transform.mirror_vertical, + }) + } else { + Ok(DecodedOverlay { + rgba_data: rgba.into_raw(), + width: w, + height: h, + rect: config.transform.rect.clone(), + opacity: config.transform.opacity, + rotation_degrees: config.transform.rotation_degrees, + z_index: config.transform.z_index, + mirror_horizontal: config.transform.mirror_horizontal, + mirror_vertical: config.transform.mirror_vertical, + }) + } +} + +/// Bilinear-filtered scale of an RGBA8 buffer from `(sw, sh)` to `(dw, dh)`. +/// Uses the `image` crate's `resize` with `Triangle` (bilinear) filter for +/// high-quality prescaling — much better than nearest-neighbor for images +/// containing text or fine detail. Called once at config time so the +/// per-frame blit is a 1:1 copy. +fn prescale_rgba(src: &[u8], sw: u32, sh: u32, dw: u32, dh: u32) -> Vec { + // Invariant: caller guarantees src.len() == sw * sh * 4. 
+ #[allow(clippy::expect_used)] + // from_raw only fails if buffer length != w*h*4; caller guarantees this + let src_img = image::RgbaImage::from_raw(sw, sh, src.to_vec()) + .expect("prescale_rgba: source dimensions do not match buffer length"); + let resized = image::imageops::resize(&src_img, dw, dh, image::imageops::FilterType::Triangle); + resized.into_raw() +} + +// ── Bundled font data ──────────────────────────────────────────────────────── + +use crate::video::fonts; + +// ── Parsed-font cache ─────────────────────────────────────────────────────── + +/// Cache key identifying a font source. +/// +/// Bundled fonts are keyed by their static name. User-provided `font_path` +/// sources use the filesystem path string. Inline base64 data is keyed by +/// a hash of the base64 string so the cache map does not retain what may be +/// a several-hundred-KiB string per font. +#[derive(Hash, Eq, PartialEq)] +enum FontKey { + /// A font from the compile-time bundled set (keyed by name). + Bundled(&'static str), + /// A user-provided font loaded from a filesystem path. + Path(String), + /// Inline base64-encoded font data (keyed by content hash). + InlineHash(u64), +} + +/// Process-wide cache of parsed `fontdue::Font` objects. +/// +/// `fontdue::Font::from_bytes` parses the full TTF/OTF table set and is +/// expensive (~3.5 s cumulative in profiling when overlay parameters update +/// frequently). Caching the parsed result keyed by font identity means the +/// parse happens once per distinct font for the lifetime of the process; +/// subsequent `load_font` calls for the same source are an `Arc::clone`. +/// +/// The set of distinct fonts in any reasonable pipeline is tiny (bounded by +/// the bundled set + whatever the user injects), so unbounded growth is not a +/// concern. The lock is held only for the map lookup / insert, never across +/// the parse itself. 
+static FONT_CACHE: LazyLock>>> = + LazyLock::new(|| Mutex::new(HashMap::new())); + +/// Lazy loader for raw font bytes. Constructed cheaply by +/// [`resolve_font_source`] so that file I/O and base64 decoding are deferred +/// until after a cache miss is confirmed. +type FontBytesLoader<'a> = Box Result, String> + 'a>; + +/// Resolve a [`TextOverlayConfig`]'s font-source fields to a [`FontKey`] and a +/// lazy byte loader, following the same precedence as [`load_font`]: +/// `font_data_base64` > `font_name` > `font_path` > bundled default. +/// +/// Bundled fonts (via `font_name` or the default) are compiled into the binary +/// and always available — no filesystem dependency. `font_path` still supports +/// loading arbitrary external fonts from the filesystem. +/// +/// Returning a boxed closure lets the caller skip base64 decode / file I/O +/// entirely on a cache hit. +fn resolve_font_source(config: &TextOverlayConfig) -> (FontKey, FontBytesLoader<'_>) { + if let Some(ref b64) = config.font_data_base64 { + let mut h = std::collections::hash_map::DefaultHasher::new(); + b64.hash(&mut h); + let key = FontKey::InlineHash(h.finish()); + let loader = move || { + use base64::Engine; + base64::engine::general_purpose::STANDARD + .decode(b64) + .map_err(|e| format!("Invalid base64 in font_data_base64: {e}")) + }; + return (key, Box::new(loader)); + } + + if let Some(ref name) = config.font_name { + if let Some(data) = fonts::bundled_font_by_name(name) { + let bundled = fonts::BUNDLED_FONTS + .iter() + .find(|f| f.name == name.as_str()) + .map_or("dejavu-sans", |f| f.name); + let key = FontKey::Bundled(bundled); + let loader = move || Ok(data.to_vec()); + return (key, Box::new(loader)); + } + // Unknown font name — fall back to the default with a warning rather + // than erroring out, so overlays remain readable when legacy or + // unrecognised names are passed (e.g. after removing Liberation/FreeFont). 
+ tracing::warn!( + "Unknown font name '{name}', falling back to default (dejavu-sans). \ + Available: {}", + fonts::bundled_font_names() + ); + let key = FontKey::Bundled("dejavu-sans"); + let loader = || Ok(fonts::DEFAULT_FONT_DATA.to_vec()); + return (key, Box::new(loader)); + } + + if let Some(ref path) = config.font_path { + let key = FontKey::Path(path.clone()); + let path = path.clone(); + let loader = move || { + std::fs::read(&path).map_err(|e| format!("Failed to read font file '{path}': {e}")) + }; + return (key, Box::new(loader)); + } + + // Default: embedded DejaVu Sans. + let key = FontKey::Bundled("dejavu-sans"); + let loader = || Ok(fonts::DEFAULT_FONT_DATA.to_vec()); + (key, Box::new(loader)) +} + +/// Load font data, trying (in order): +/// 1. `font_data_base64` (inline base64-encoded TTF/OTF) +/// 2. `font_name` (named font from the bundled set) +/// 3. `font_path` (filesystem path for external/custom fonts) +/// 4. Bundled default (DejaVu Sans, embedded at compile time) +/// +/// Parsed fonts are cached in [`FONT_CACHE`] keyed by the resolved source +/// identity, so repeated calls for the same font are an `Arc::clone` rather +/// than a fresh parse. +fn load_font(config: &TextOverlayConfig) -> Result, String> { + let (key, load_bytes) = resolve_font_source(config); + + // Fast path: cache hit. Lock scope limited to the lookup. + if let Ok(cache) = FONT_CACHE.lock() { + if let Some(font) = cache.get(&key) { + return Ok(Arc::clone(font)); + } + } + + // Miss: do the expensive work (I/O + parse) *outside* the lock. + let font_bytes = load_bytes()?; + let font = Arc::new( + fontdue::Font::from_bytes(font_bytes, fontdue::FontSettings::default()) + .map_err(|e| format!("Failed to parse font: {e}"))?, + ); + + // Insert. If the mutex is poisoned we simply skip caching — the caller + // still gets a valid font, just without the memoisation benefit. 
+ if let Ok(mut cache) = FONT_CACHE.lock() { + cache.entry(key).or_insert_with(|| Arc::clone(&font)); + } + + Ok(font) +} + +/// Rasterize a text overlay into an RGBA8 bitmap using `fontdue` for real +/// font glyph rendering. Falls back to solid-rectangle placeholders when +/// font loading fails so the node keeps running. +/// +/// Supports explicit newlines (`\n`) and automatic word-wrapping when the +/// overlay rect has a non-zero width. The bitmap dimensions are expanded +/// to fit the measured (possibly multi-line) text so that nothing is +/// clipped. +#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss, clippy::cast_precision_loss)] +pub fn rasterize_text_overlay(config: &TextOverlayConfig) -> DecodedOverlay { + // Attempt to load the font; fall back to rectangle placeholders on error. + let font = match load_font(config) { + Ok(f) => Some(f), + Err(e) => { + tracing::warn!("Font loading failed, using placeholder rectangles: {e}"); + None + }, + }; + + let font_size = config.font_size.max(1) as f32; + let wrap_width = config.transform.rect.width; + + // Measure actual text dimensions so the bitmap is large enough to hold + // the full rendered string without clipping. When a wrap width is set + // the text is word-wrapped and may span multiple lines. + let (measured_w, measured_h) = font.as_ref().map_or_else( + || { + // Fallback estimate for placeholder rectangles. 
+ let glyph_w = config.font_size.max(1) * 3 / 5; + let est_w = glyph_w * config.text.chars().count() as u32; + let est_h = (font_size * 1.4).ceil() as u32; + (est_w, est_h) + }, + |f| crate::video::measure_text_wrapped(f, font_size, &config.text, wrap_width), + ); + + let w = config.transform.rect.width.max(measured_w).max(1); + let h = config.transform.rect.height.max(measured_h).max(1); + + let total_bytes = (w as usize) * (h as usize) * 4; + let mut rgba_data = vec![0u8; total_bytes]; + + if let Some(font) = font { + // ── Real font rendering via shared utility (multi-line aware) ──── + crate::video::blit_text_wrapped( + &mut rgba_data, + w, + h, + &font, + config.font_size.max(1) as f32, + &config.text, + 0, + 0, + config.color, + wrap_width, + ); + } else { + // ── Fallback: filled rectangle per glyph (placeholder) ────────── + let [cr, cg, cb, ca] = config.color; + let stride = w as usize * 4; + let glyph_w = (config.font_size.max(1) * 3 / 5) as usize; + let glyph_h = config.font_size.max(1) as usize; + + for (i, _ch) in config.text.chars().enumerate() { + let x = i * glyph_w; + if x + glyph_w > w as usize { + break; + } + for row in 0..glyph_h.min(h as usize) { + for col in x..x + glyph_w { + let off = row * stride + col * 4; + rgba_data[off] = cr; + rgba_data[off + 1] = cg; + rgba_data[off + 2] = cb; + rgba_data[off + 3] = ca; + } + } + } + } + + DecodedOverlay { + rgba_data, + width: w, + height: h, + rect: { + // Use the expanded dimensions so the blit renders the full bitmap + // without clipping text that exceeds the original rect. 
+ let mut r = config.transform.rect.clone(); + r.width = w; + r.height = h; + r + }, + opacity: config.transform.opacity, + rotation_degrees: config.transform.rotation_degrees, + z_index: config.transform.z_index, + mirror_horizontal: config.transform.mirror_horizontal, + mirror_vertical: config.transform.mirror_vertical, + } +} diff --git a/crates/nodes/src/video/compositor/pixel_ops/blit.rs b/crates/nodes/src/video/compositor/pixel_ops/blit.rs new file mode 100644 index 00000000..a7b430f2 --- /dev/null +++ b/crates/nodes/src/video/compositor/pixel_ops/blit.rs @@ -0,0 +1,1086 @@ +// SPDX-FileCopyrightText: © 2025 StreamKit Contributors +// +// SPDX-License-Identifier: MPL-2.0 + +//! RGBA8 blitting operations for the video compositor. +//! +//! Contains: +//! - [`scale_blit_rgba`]: axis-aligned scale + blit with nearest-neighbor sampling. +//! - [`scale_blit_rgba_rotated`]: rotated scale + blit with anti-aliased edges. +//! +//! Both functions use row-level parallelism via `rayon` when the blit region +//! is large enough to amortise the thread-pool dispatch overhead. + +use super::{blend_u8, rayon_chunk_rows, RAYON_ROW_THRESHOLD}; +use crate::video::compositor::config::Rect; + +#[cfg(target_arch = "x86_64")] +use super::simd::{ + all_alpha_opaque_avx2, all_alpha_opaque_sse2, blend_4px_alpha_sse2, blend_4px_opaque_sse2, + blend_8px_alpha_avx2, blend_8px_opaque_avx2, read_rgba_u32, +}; + +// ── Scalar blend helper ───────────────────────────────────────────────────── + +/// Blend a single source pixel onto a destination row slice at `dst_off`. +/// +/// Handles fully-opaque, semi-transparent, and fully-transparent cases. +/// `opacity_u16` is a 0..255 multiplier applied to the source alpha, or 256 +/// as a sentinel meaning "fully opaque, skip per-pixel opacity multiply". +/// +/// This is the shared scalar path used by both the x86_64 remainder loop and +/// the non-x86_64 fallback in [`scale_blit_rgba_rotated`]. 
+#[allow(clippy::inline_always)] +#[inline(always)] +fn blend_pixel_scalar( + row_slice: &mut [u8], + dst_off: usize, + src: &[u8], + src_idx: usize, + opacity_u16: u16, +) { + let ir = src[src_idx]; + let ig = src[src_idx + 1]; + let ib = src[src_idx + 2]; + let mut ia = src[src_idx + 3]; + + if opacity_u16 < 256 { + ia = ((u16::from(ia) * opacity_u16 + 128) >> 8).min(255) as u8; + } + if ia > 0 && dst_off + 3 < row_slice.len() { + if ia == 255 { + row_slice[dst_off] = ir; + row_slice[dst_off + 1] = ig; + row_slice[dst_off + 2] = ib; + row_slice[dst_off + 3] = 255; + } else { + let a16 = u16::from(ia); + row_slice[dst_off] = blend_u8(ir, row_slice[dst_off], a16); + row_slice[dst_off + 1] = blend_u8(ig, row_slice[dst_off + 1], a16); + row_slice[dst_off + 2] = blend_u8(ib, row_slice[dst_off + 2], a16); + let da = u16::from(row_slice[dst_off + 3]); + row_slice[dst_off + 3] = (a16 + ((da * (255 - a16) + 128) >> 8)).min(255) as u8; + } + } +} + +// ── Axis-aligned blit ─────────────────────────────────────────────────────── + +/// Scale and blit a source RGBA8 buffer onto a destination RGBA8 buffer at the +/// given destination rectangle. Uses nearest-neighbor sampling and clips to +/// canvas bounds. +/// +/// Rows are processed in parallel via `rayon` when the blit region is large +/// enough to benefit from multi-core dispatch. 
+#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss, clippy::too_many_arguments)] +pub fn scale_blit_rgba( + dst: &mut [u8], + dst_width: u32, + dst_height: u32, + src: &[u8], + src_width: u32, + src_height: u32, + dst_rect: &Rect, + opacity: f32, + #[allow(unused_variables)] src_opaque: bool, + mirror_h: bool, + mirror_v: bool, +) { + use rayon::prelude::*; + + if src_width == 0 || src_height == 0 || dst_rect.width == 0 || dst_rect.height == 0 { + return; + } + + let dw = dst_width as usize; + let dh = dst_height as usize; + let sw = src_width as usize; + let sh = src_height as usize; + let rw = dst_rect.width as usize; + let rh = dst_rect.height as usize; + + // Compute the visible region after clipping the (possibly negative) + // rect position to the canvas bounds. + let (rx, src_col_skip) = if dst_rect.x < 0 { + (0usize, (-dst_rect.x) as usize) + } else { + (dst_rect.x as usize, 0usize) + }; + let (ry, src_row_skip) = if dst_rect.y < 0 { + (0usize, (-dst_rect.y) as usize) + } else { + (dst_rect.y as usize, 0usize) + }; + + // Clamp the number of rows we actually process to the canvas height. + let effective_rh = rh.saturating_sub(src_row_skip).min(dh.saturating_sub(ry)); + if effective_rh == 0 { + return; + } + + // Clamp the number of columns to the canvas width. + let effective_rect_w = rw.saturating_sub(src_col_skip).min(dw.saturating_sub(rx)); + if effective_rect_w == 0 { + return; + } + + // Split the destination buffer into per-row slices so that each row can + // be processed independently (and therefore in parallel). + let row_stride = dw * 4; + + // We need to give each row its own mutable slice. Split the dst buffer + // at the first output row. 
+ let first_row_byte = ry * row_stride; + let dst_rows = &mut dst[first_row_byte..]; + + // ── Identity-scale fast path ─────────────────────────────────────── + // When source dimensions exactly match the destination rect and opacity + // is fully opaque, we can avoid per-pixel scaling entirely and use + // direct row copies (memcpy) for fully-opaque source rows. + // + // Rows are processed in parallel via `rayon` when the blit region is + // large enough to benefit from multi-core dispatch. + if rw == sw && rh == sh && opacity >= 1.0 && src_col_skip == 0 && src_row_skip == 0 { + let src_row_bytes = sw * 4; + let copy_bytes = effective_rect_w * 4; + // Pre-validate that the source buffer can satisfy all rows, + // so the inner closure doesn't need per-row bounds checks. + let max_src_end = (effective_rh.saturating_sub(1)) * src_row_bytes + copy_bytes; + if max_src_end > src.len() { + // Fall through to the scaled path for safety. + } else { + let dst_start = rx * 4; + + let blit_identity_row = |dy: usize, row_slice: &mut [u8]| { + let src_start = dy * src_row_bytes; + let src_row = &src[src_start..src_start + copy_bytes]; + let dst_end = dst_start + copy_bytes; + if dst_end > row_slice.len() { + return; + } + // When the caller guarantees all source pixels are opaque + // (e.g. YUV→RGBA conversion always writes alpha = 255), + // skip the per-row alpha scan entirely. + let all_opaque = if src_opaque { + true + } else { + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx2") { + unsafe { all_alpha_opaque_avx2(src_row) } + } else { + unsafe { all_alpha_opaque_sse2(src_row) } + } + } + #[cfg(not(target_arch = "x86_64"))] + { + src_row.chunks_exact(4).all(|px| px[3] == 255) + } + }; + + if all_opaque { + row_slice[dst_start..dst_end].copy_from_slice(src_row); + } else { + // Per-pixel alpha blend (identity scale, so sx == dx). 
+ for dx in 0..effective_rect_w { + let si = dx * 4; + let sa = src_row[si + 3]; + if sa == 255 { + row_slice[dst_start + dx * 4..dst_start + dx * 4 + 4] + .copy_from_slice(&src_row[si..si + 4]); + } else if sa > 0 { + let di = dst_start + dx * 4; + let a16 = u16::from(sa); + row_slice[di] = blend_u8(src_row[si], row_slice[di], a16); + row_slice[di + 1] = blend_u8(src_row[si + 1], row_slice[di + 1], a16); + row_slice[di + 2] = blend_u8(src_row[si + 2], row_slice[di + 2], a16); + let da = u16::from(row_slice[di + 3]); + row_slice[di + 3] = + (a16 + ((da * (255 - a16) + 128) >> 8)).min(255) as u8; + } + } + } + }; + + if effective_rh >= RAYON_ROW_THRESHOLD { + dst_rows.par_chunks_mut(row_stride).take(effective_rh).enumerate().for_each( + |(dy, row_slice)| { + blit_identity_row(dy, row_slice); + }, + ); + } else { + for (dy, row_slice) in + dst_rows.chunks_mut(row_stride).take(effective_rh).enumerate() + { + blit_identity_row(dy, row_slice); + } + } + return; + } + } + + // ── Scaled blit path ─────────────────────────────────────────────── + // Precompute the source-X lookup table once. This replaces the per-pixel + // `(dx + src_col_skip) * sw / rw` integer division with a single table + // lookup in the inner blit loops. 
+ let x_map: Vec = (0..effective_rect_w) + .map(|dx| { + let sx = (dx + src_col_skip) * sw / rw; + if mirror_h { + sw.saturating_sub(1).saturating_sub(sx) + } else { + sx + } + }) + .collect(); + + if effective_rh >= RAYON_ROW_THRESHOLD { + dst_rows.par_chunks_mut(row_stride).take(effective_rh).enumerate().for_each( + |(dy, row_slice)| { + let sy_raw = (dy + src_row_skip) * sh / rh; + let sy = + if mirror_v { sh.saturating_sub(1).saturating_sub(sy_raw) } else { sy_raw }; + blit_row(row_slice, rx, effective_rect_w, src, sw, sy, opacity, &x_map); + }, + ); + } else { + for (dy, row_slice) in dst_rows.chunks_mut(row_stride).take(effective_rh).enumerate() { + let sy_raw = (dy + src_row_skip) * sh / rh; + let sy = if mirror_v { sh.saturating_sub(1).saturating_sub(sy_raw) } else { sy_raw }; + blit_row(row_slice, rx, effective_rect_w, src, sw, sy, opacity, &x_map); + } + } +} + +/// Blit a single row of the source onto a destination row slice. +/// +/// This is the inner kernel extracted so that `scale_blit_rgba` can dispatch +/// rows in parallel. The `row_slice` covers exactly one destination row +/// starting at pixel column 0 (i.e. byte offset `rx * 4` is the first column +/// we write to). +/// +/// `x_map` is a precomputed table mapping each destination column to the +/// corresponding source column, eliminating per-pixel integer division. +#[allow( + clippy::cast_possible_truncation, + clippy::cast_sign_loss, + clippy::too_many_arguments, + clippy::inline_always +)] +#[inline(always)] +fn blit_row( + row_slice: &mut [u8], + rx: usize, + effective_rw: usize, + src: &[u8], + sw: usize, + sy: usize, + opacity: f32, + x_map: &[usize], +) { + // Fast path: when opacity is 1.0, we can skip the f32 multiply on alpha + // and branch more cheaply. 
+ if opacity >= 1.0 { + blit_row_opaque(row_slice, rx, effective_rw, src, sw, sy, x_map); + } else { + blit_row_alpha(row_slice, rx, effective_rw, src, sw, sy, opacity, x_map); + } +} + +/// Inner blit for fully-opaque layers (`opacity >= 1.0`). Skips the +/// per-pixel f32 multiply on the source alpha channel. +/// +/// Uses integer-only alpha blending for semi-transparent source pixels. +/// `x_map` provides precomputed source-X indices (one per destination column). +/// +/// On x86-64, processes 4 pixels at a time using SSE2 SIMD when the row is +/// wide enough and bounds can be pre-validated. +/// AVX2 inner loop for opaque blitting — processes 8 pixels at a time. +/// +/// Extracted into its own `#[target_feature]` function so LLVM can inline the +/// per-8px SIMD helpers without the target-feature mismatch barrier that would +/// exist if they were called from a non-AVX2 caller. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +unsafe fn blit_row_opaque_avx2_loop( + row_slice: &mut [u8], + rx: usize, + effective_rw: usize, + src: &[u8], + src_row_base: usize, + x_map: &[usize], +) -> usize { + let chunks8 = effective_rw / 8; + for c in 0..chunks8 { + let dx = c * 8; + let pixels = [ + read_rgba_u32(src, src_row_base + x_map[dx] * 4), + read_rgba_u32(src, src_row_base + x_map[dx + 1] * 4), + read_rgba_u32(src, src_row_base + x_map[dx + 2] * 4), + read_rgba_u32(src, src_row_base + x_map[dx + 3] * 4), + read_rgba_u32(src, src_row_base + x_map[dx + 4] * 4), + read_rgba_u32(src, src_row_base + x_map[dx + 5] * 4), + read_rgba_u32(src, src_row_base + x_map[dx + 6] * 4), + read_rgba_u32(src, src_row_base + x_map[dx + 7] * 4), + ]; + blend_8px_opaque_avx2(row_slice.as_mut_ptr().add((rx + dx) * 4), pixels); + } + chunks8 * 8 +} + +#[allow( + clippy::cast_possible_truncation, + clippy::cast_sign_loss, + clippy::too_many_arguments, + clippy::suboptimal_flops, + clippy::inline_always, + // dx is used as both x_map index and dst offset, so an iterator 
is non-trivial. + clippy::needless_range_loop, + // AVX2 block has side-effects (SIMD writes) before assigning dx_start. + clippy::useless_let_if_seq +)] +#[inline(always)] +fn blit_row_opaque( + row_slice: &mut [u8], + rx: usize, + effective_rw: usize, + src: &[u8], + sw: usize, + sy: usize, + x_map: &[usize], +) { + let src_row_base = sy * sw * 4; + + // ── SIMD fast path: AVX2 (8px) → SSE2 (4px) → scalar tail ───────── + #[cfg(target_arch = "x86_64")] + { + // Pre-validate bounds so the inner SIMD loop is branch-free. + let src_row_end = src_row_base + sw * 4; + let dst_end = (rx + effective_rw) * 4; + if src_row_end <= src.len() && dst_end <= row_slice.len() { + let mut dx_start = 0usize; + + // AVX2: process 8 pixels at a time. + if is_x86_feature_detected!("avx2") { + dx_start = unsafe { + blit_row_opaque_avx2_loop(row_slice, rx, effective_rw, src, src_row_base, x_map) + }; + } + + // SSE2: process remaining pixels in 4-pixel chunks. + let chunks4 = (effective_rw - dx_start) / 4; + for c in 0..chunks4 { + let dx = dx_start + c * 4; + unsafe { + let pixels = [ + read_rgba_u32(src, src_row_base + x_map[dx] * 4), + read_rgba_u32(src, src_row_base + x_map[dx + 1] * 4), + read_rgba_u32(src, src_row_base + x_map[dx + 2] * 4), + read_rgba_u32(src, src_row_base + x_map[dx + 3] * 4), + ]; + blend_4px_opaque_sse2(row_slice.as_mut_ptr().add((rx + dx) * 4), pixels); + } + } + + // Scalar tail for remaining 0-3 pixels. 
+ let tail_start = dx_start + chunks4 * 4; + for dx in tail_start..effective_rw { + let sx = x_map[dx]; + let src_idx = src_row_base + sx * 4; + let sr = src[src_idx]; + let sg = src[src_idx + 1]; + let sb = src[src_idx + 2]; + let sa = src[src_idx + 3]; + let dst_idx = (rx + dx) * 4; + if sa == 255 { + row_slice[dst_idx] = sr; + row_slice[dst_idx + 1] = sg; + row_slice[dst_idx + 2] = sb; + row_slice[dst_idx + 3] = 255; + } else if sa > 0 { + let a16 = u16::from(sa); + row_slice[dst_idx] = blend_u8(sr, row_slice[dst_idx], a16); + row_slice[dst_idx + 1] = blend_u8(sg, row_slice[dst_idx + 1], a16); + row_slice[dst_idx + 2] = blend_u8(sb, row_slice[dst_idx + 2], a16); + let da = u16::from(row_slice[dst_idx + 3]); + row_slice[dst_idx + 3] = (a16 + ((da * (255 - a16) + 128) >> 8)).min(255) as u8; + } + } + return; + } + } + + // ── Scalar fallback (bounds-checked per pixel) ───────────────────── + for dx in 0..effective_rw { + let sx = x_map[dx]; + let src_idx = src_row_base + sx * 4; + if src_idx + 3 >= src.len() { + continue; + } + + let sr = src[src_idx]; + let sg = src[src_idx + 1]; + let sb = src[src_idx + 2]; + let sa = src[src_idx + 3]; + + let dst_idx = (rx + dx) * 4; + if dst_idx + 3 >= row_slice.len() { + continue; + } + + if sa == 255 { + row_slice[dst_idx] = sr; + row_slice[dst_idx + 1] = sg; + row_slice[dst_idx + 2] = sb; + row_slice[dst_idx + 3] = 255; + } else if sa > 0 { + let a16 = u16::from(sa); + row_slice[dst_idx] = blend_u8(sr, row_slice[dst_idx], a16); + row_slice[dst_idx + 1] = blend_u8(sg, row_slice[dst_idx + 1], a16); + row_slice[dst_idx + 2] = blend_u8(sb, row_slice[dst_idx + 2], a16); + let da = u16::from(row_slice[dst_idx + 3]); + row_slice[dst_idx + 3] = (a16 + ((da * (255 - a16) + 128) >> 8)).min(255) as u8; + } + } +} + +/// Inner blit for layers with fractional opacity (`opacity < 1.0`). +/// Applies the opacity multiplier to every source pixel's alpha channel. +/// +/// Uses integer-only alpha blending. 
+/// `x_map` provides precomputed source-X indices (one per destination column). +/// +/// On x86-64, processes 4 pixels at a time using SSE2 SIMD when the row is +/// wide enough and bounds can be pre-validated. +/// AVX2 inner loop for alpha blitting — processes 8 pixels at a time. +/// +/// Same rationale as [`blit_row_opaque_avx2_loop`]: keeps the entire loop inside +/// a `#[target_feature(enable = "avx2")]` scope so LLVM can inline the SIMD +/// helpers without a target-feature mismatch barrier. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +unsafe fn blit_row_alpha_avx2_loop( + row_slice: &mut [u8], + rx: usize, + effective_rw: usize, + src: &[u8], + src_row_base: usize, + x_map: &[usize], + opacity_u16: u16, +) -> usize { + let chunks8 = effective_rw / 8; + for c in 0..chunks8 { + let dx = c * 8; + let pixels = [ + read_rgba_u32(src, src_row_base + x_map[dx] * 4), + read_rgba_u32(src, src_row_base + x_map[dx + 1] * 4), + read_rgba_u32(src, src_row_base + x_map[dx + 2] * 4), + read_rgba_u32(src, src_row_base + x_map[dx + 3] * 4), + read_rgba_u32(src, src_row_base + x_map[dx + 4] * 4), + read_rgba_u32(src, src_row_base + x_map[dx + 5] * 4), + read_rgba_u32(src, src_row_base + x_map[dx + 6] * 4), + read_rgba_u32(src, src_row_base + x_map[dx + 7] * 4), + ]; + blend_8px_alpha_avx2(row_slice.as_mut_ptr().add((rx + dx) * 4), pixels, opacity_u16); + } + chunks8 * 8 +} + +#[allow( + clippy::cast_possible_truncation, + clippy::cast_sign_loss, + clippy::too_many_arguments, + clippy::suboptimal_flops, + clippy::inline_always, + // dx is used as both x_map index and dst offset, so an iterator is non-trivial. + clippy::needless_range_loop, + // AVX2 block has side-effects (SIMD writes) before assigning dx_start. 
+ clippy::useless_let_if_seq +)] +#[inline(always)] +fn blit_row_alpha( + row_slice: &mut [u8], + rx: usize, + effective_rw: usize, + src: &[u8], + sw: usize, + sy: usize, + opacity: f32, + x_map: &[usize], +) { + // Pre-compute opacity as a 0..255 integer multiplier. + let opacity_u16 = (opacity * 255.0 + 0.5) as u16; + let src_row_base = sy * sw * 4; + + // ── SIMD fast path: AVX2 (8px) → SSE2 (4px) → scalar tail ───────── + #[cfg(target_arch = "x86_64")] + { + let src_row_end = src_row_base + sw * 4; + let dst_end = (rx + effective_rw) * 4; + if src_row_end <= src.len() && dst_end <= row_slice.len() { + let mut dx_start = 0usize; + + // AVX2: process 8 pixels at a time. + if is_x86_feature_detected!("avx2") { + dx_start = unsafe { + blit_row_alpha_avx2_loop( + row_slice, + rx, + effective_rw, + src, + src_row_base, + x_map, + opacity_u16, + ) + }; + } + + // SSE2: process remaining pixels in 4-pixel chunks. + let chunks4 = (effective_rw - dx_start) / 4; + for c in 0..chunks4 { + let dx = dx_start + c * 4; + unsafe { + let pixels = [ + read_rgba_u32(src, src_row_base + x_map[dx] * 4), + read_rgba_u32(src, src_row_base + x_map[dx + 1] * 4), + read_rgba_u32(src, src_row_base + x_map[dx + 2] * 4), + read_rgba_u32(src, src_row_base + x_map[dx + 3] * 4), + ]; + blend_4px_alpha_sse2( + row_slice.as_mut_ptr().add((rx + dx) * 4), + pixels, + opacity_u16, + ); + } + } + + // Scalar tail. 
+ let tail_start = dx_start + chunks4 * 4; + for dx in tail_start..effective_rw { + let sx = x_map[dx]; + let src_idx = src_row_base + sx * 4; + let sr = src[src_idx]; + let sg = src[src_idx + 1]; + let sb = src[src_idx + 2]; + let sa = src[src_idx + 3]; + let dst_idx = (rx + dx) * 4; + let sa_eff = ((u16::from(sa) * opacity_u16 + 128) >> 8).min(255); + if sa_eff == 255 { + row_slice[dst_idx] = sr; + row_slice[dst_idx + 1] = sg; + row_slice[dst_idx + 2] = sb; + row_slice[dst_idx + 3] = 255; + } else if sa_eff > 0 { + row_slice[dst_idx] = blend_u8(sr, row_slice[dst_idx], sa_eff); + row_slice[dst_idx + 1] = blend_u8(sg, row_slice[dst_idx + 1], sa_eff); + row_slice[dst_idx + 2] = blend_u8(sb, row_slice[dst_idx + 2], sa_eff); + let da = u16::from(row_slice[dst_idx + 3]); + row_slice[dst_idx + 3] = + (sa_eff + ((da * (255 - sa_eff) + 128) >> 8)).min(255) as u8; + } + } + return; + } + } + + // ── Scalar fallback ──────────────────────────────────────────────── + for dx in 0..effective_rw { + let sx = x_map[dx]; + let src_idx = src_row_base + sx * 4; + if src_idx + 3 >= src.len() { + continue; + } + + let sr = src[src_idx]; + let sg = src[src_idx + 1]; + let sb = src[src_idx + 2]; + let sa = src[src_idx + 3]; + + let dst_idx = (rx + dx) * 4; + if dst_idx + 3 >= row_slice.len() { + continue; + } + + let sa_eff = ((u16::from(sa) * opacity_u16 + 128) >> 8).min(255); + if sa_eff == 255 { + row_slice[dst_idx] = sr; + row_slice[dst_idx + 1] = sg; + row_slice[dst_idx + 2] = sb; + row_slice[dst_idx + 3] = 255; + } else if sa_eff > 0 { + row_slice[dst_idx] = blend_u8(sr, row_slice[dst_idx], sa_eff); + row_slice[dst_idx + 1] = blend_u8(sg, row_slice[dst_idx + 1], sa_eff); + row_slice[dst_idx + 2] = blend_u8(sb, row_slice[dst_idx + 2], sa_eff); + let da = u16::from(row_slice[dst_idx + 3]); + row_slice[dst_idx + 3] = (sa_eff + ((da * (255 - sa_eff) + 128) >> 8)).min(255) as u8; + } + } +} + +// ── Rotated blitting ──────────────────────────────────────────────────────── + +/// Scale 
and blit a source RGBA8 buffer onto a destination RGBA8 buffer at the
+/// given destination rectangle with clockwise rotation around the rect centre.
+///
+/// The source is stretched to fill the destination rect (no aspect-ratio-
+/// preserving fit). Aspect ratio handling is the responsibility of the
+/// caller / presentation layer. The stretched content is then rotated
+/// around the rect centre.
+///
+/// Uses inverse-affine mapping with nearest-neighbor sampling. Edge pixels
+/// receive fractional alpha coverage computed from the signed distance to each
+/// of the four rect edges in the un-rotated local coordinate system. This
+/// eliminates the staircase aliasing that a hard binary inside/outside test
+/// would produce.
+///
+/// (The paragraphs above describe the overall behaviour of
+/// [`scale_blit_rgba_rotated`], the public entry point defined further below;
+/// the item documented here is its AVX2 inner loop.)
+///
+/// AVX2 inner loop for rotated blitting — processes 8 interior pixels at a time.
+///
+/// Gathers source pixels by stepping through rotated coordinates, then blends
+/// with the appropriate opaque/alpha SIMD path. Returns the number of pixels
+/// processed; `local_x`/`local_y` are updated in-place via `&mut`.
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx2")]
+#[allow(
+    clippy::too_many_arguments,
+    clippy::cast_possible_truncation,
+    clippy::cast_sign_loss,
+    clippy::similar_names
+)]
+unsafe fn rotated_blit_avx2_loop(
+    row_slice: &mut [u8],
+    src: &[u8],
+    px: i32,
+    skip_u: usize,
+    local_x: &mut f32,
+    local_y: &mut f32,
+    cos_a: f32,
+    sin_a: f32,
+    half_cw: f32,
+    half_ch: f32,
+    inv_scale_x: f32,
+    inv_scale_y: f32,
+    sw: usize,
+    sh: usize,
+    opacity_u16: u16,
+    mirror_h: bool,
+    mirror_v: bool,
+) -> usize {
+    // Number of destination pixels consumed so far; also the return value.
+    let mut done = 0usize;
+    while done + 8 <= skip_u {
+        let mut src_pixels = [0u32; 8];
+        let mut all_valid = true;
+        // Snapshot the rotation stepper so it can be rolled back if any of
+        // the 8 source fetches falls outside `src` — the caller's scalar
+        // remainder loop then re-steps from the same position.
+        let snap_local_x = *local_x;
+        let snap_local_y = *local_y;
+        for sp in &mut src_pixels {
+            // Incremental inverse rotation: one column step changes the
+            // local coords by (+cos_a, -sin_a) instead of a full transform.
+            *local_x += cos_a;
+            *local_y -= sin_a;
+            let isx_raw = (((*local_x + half_cw) * inv_scale_x) as usize).min(sw - 1);
+            let isy_raw = (((*local_y + half_ch) * inv_scale_y) as usize).min(sh - 1);
+            let isx = if mirror_h { sw - 1 - isx_raw } else { isx_raw };
+            let isy = if mirror_v { sh - 1 - isy_raw } else { isy_raw };
+            let si = (isy * sw + isx) * 4;
+            if si + 3 < src.len() {
+                *sp = read_rgba_u32(src, si);
+            } else {
+                all_valid = false;
+                break;
+            }
+        }
+
+        if !all_valid {
+            // Roll the stepper back to the start of this 8-pixel group and
+            // let the caller handle the remainder scalar-wise.
+            *local_x = snap_local_x;
+            *local_y = snap_local_y;
+            break;
+        }
+
+        // `px` is the column *before* the batch (caller semantics), hence +1.
+        let dst_off = (px as usize + 1 + done) * 4;
+        // Defensive bounds check: 8 RGBA pixels = 32 bytes written at dst_off.
+        if dst_off + 31 < row_slice.len() {
+            let dst_ptr = row_slice.as_mut_ptr().add(dst_off);
+            // 256 is the caller's "fully opaque" sentinel (see opacity_u16
+            // computation in scale_blit_rgba_rotated); < 256 means per-pixel
+            // alpha multiply is required.
+            if opacity_u16 >= 256 {
+                blend_8px_opaque_avx2(dst_ptr, src_pixels);
+            } else {
+                blend_8px_alpha_avx2(dst_ptr, src_pixels, opacity_u16);
+            }
+        }
+
+        done += 8;
+    }
+    done
+}
+
+/// For near-zero rotation angles (< 0.01°), a fast path delegates directly
+/// to [`scale_blit_rgba`] which performs the same stretch-to-fill without
+/// the rotation overhead.
+#[allow( + clippy::cast_possible_truncation, + clippy::cast_sign_loss, + clippy::too_many_arguments, + clippy::cast_precision_loss, + clippy::similar_names, + clippy::cast_possible_wrap, + // AVX2 block has side-effects (SIMD writes) before assigning done. + clippy::useless_let_if_seq +)] +pub fn scale_blit_rgba_rotated( + dst: &mut [u8], + dst_width: u32, + dst_height: u32, + src: &[u8], + src_width: u32, + src_height: u32, + dst_rect: &Rect, + opacity: f32, + rotation_deg: f32, + src_opaque: bool, + mirror_h: bool, + mirror_v: bool, +) { + if src_width == 0 || src_height == 0 || dst_rect.width == 0 || dst_rect.height == 0 { + return; + } + + let rw = dst_rect.width as f32; + let rh = dst_rect.height as f32; + + // ── Near-zero rotation fast path ────────────────────────────────── + // Delegate to the optimised non-rotated blit which stretches the + // source to fill the destination rect (no aspect-ratio fitting). + if rotation_deg.abs() < 0.01 { + scale_blit_rgba( + dst, dst_width, dst_height, src, src_width, src_height, dst_rect, opacity, src_opaque, + mirror_h, mirror_v, + ); + return; + } + + let dw = dst_width.cast_signed(); + let dh = dst_height.cast_signed(); + let sw = src_width as usize; + let sh = src_height as usize; + + // Pre-compute sin/cos for the rotation (needed for the bounding-box + // computation and for the per-pixel inverse mapping). + let angle_rad = rotation_deg.to_radians(); + let cos_a = angle_rad.cos(); + let sin_a = angle_rad.sin(); + + // ── Stretch-to-fill scaling ────────────────────────────────────── + // The source is stretched to fill the destination rect (no + // aspect-ratio-preserving fit). Aspect ratio handling is the + // responsibility of the client / presentation layer. + let half_cw = rw * 0.5; + let half_ch = rh * 0.5; + let inv_scale_x = src_width as f32 / rw; + let inv_scale_y = src_height as f32 / rh; + + // Rotation centre = centre of the destination rect. 
+ let cx = rw.mul_add(0.5, dst_rect.x as f32); + let cy = rh.mul_add(0.5, dst_rect.y as f32); + + // Compute the axis-aligned bounding box of the rotated *content* area + // (not the full rect) so we only iterate over pixels that could + // possibly be covered by actual source content. + let corners = + [(-half_cw, -half_ch), (half_cw, -half_ch), (half_cw, half_ch), (-half_cw, half_ch)]; + let mut min_x = f32::MAX; + let mut max_x = f32::MIN; + let mut min_y = f32::MAX; + let mut max_y = f32::MIN; + for (lx, ly) in &corners { + let rx = lx * cos_a - ly * sin_a + cx; + let ry = lx * sin_a + ly * cos_a + cy; + min_x = min_x.min(rx); + max_x = max_x.max(rx); + min_y = min_y.min(ry); + max_y = max_y.max(ry); + } + + // Expand bounding box by 1px on each side so the AA fringe is included. + let bb_x0 = ((min_x.floor() as i32) - 1).max(0); + let bb_y0 = ((min_y.floor() as i32) - 1).max(0); + let bb_x1 = ((max_x.ceil() as i32) + 1).min(dw); + let bb_y1 = ((max_y.ceil() as i32) + 1).min(dh); + + let row_stride = dst_width as usize * 4; + + // Pre-compute opacity as a 0..255 integer multiplier. + let opacity_u16 = if opacity < 1.0 { + opacity.mul_add(255.0, 0.5) as u16 + } else { + 256 // sentinel: means "fully opaque, skip per-pixel multiply" + }; + + // Per-row closure that processes all columns in a single row of the + // bounding box. Uses an incremental stepper: since `dx_f` increments + // by 1.0 each column, `local_x` and `local_y` change by `+cos_a` and + // `-sin_a` respectively — replacing 2 multiplies with 2 adds per pixel. + // + // Edge anti-aliasing distances are computed against the rect boundary + // (`half_cw` × `half_ch`), so the visible edge matches the rect. + // + // For interior pixels where `min_dist >= 1.0` the edge-coverage clamp + // is a no-op, so we skip the coverage math entirely for the bulk of + // each span. 
+ let process_row = |py: i32, row_slice: &mut [u8]| { + let dy = py as f32 - cy; + + // Seed the stepper at the first column of the bounding box. + let dx_f0 = bb_x0 as f32 - cx; + let mut local_x = dx_f0 * cos_a + dy * sin_a; + let mut local_y = (-dx_f0).mul_add(sin_a, dy * cos_a); + + let mut px = bb_x0; + while px < bb_x1 { + // ── Edge anti-aliasing via signed distance ────────────── + // Distances are relative to the content boundary, not the + // full destination rect. + let d_left = local_x + half_cw; + let d_right = half_cw - local_x; + let d_top = local_y + half_ch; + let d_bottom = half_ch - local_y; + let min_dist = d_left.min(d_right).min(d_top).min(d_bottom); + + if min_dist <= 0.0 { + // Fully outside content area — step and continue. + local_x += cos_a; + local_y -= sin_a; + px += 1; + continue; + } + + // Map from rect-local coords to source pixel coords. + // `local_x/y` ∈ [-half_cw, half_cw] × [-half_ch, half_ch] + // for points inside the rect. Convert to source pixel + // space via the per-axis inverse scale. + let src_fx = (local_x + half_cw) * inv_scale_x; + let src_fy = (local_y + half_ch) * inv_scale_y; + + let sxi_raw = (src_fx as usize).min(sw - 1); + let syi_raw = (src_fy as usize).min(sh - 1); + let sxi = if mirror_h { sw - 1 - sxi_raw } else { sxi_raw }; + let syi = if mirror_v { sh - 1 - syi_raw } else { syi_raw }; + + let src_idx = (syi * sw + sxi) * 4; + if src_idx + 3 >= src.len() { + local_x += cos_a; + local_y -= sin_a; + px += 1; + continue; + } + + let sr = src[src_idx]; + let sg = src[src_idx + 1]; + let sb = src[src_idx + 2]; + let mut sa = src[src_idx + 3]; + + // Apply layer opacity. + if opacity_u16 < 256 { + sa = ((u16::from(sa) * opacity_u16 + 128) >> 8).min(255) as u8; + } + + // Apply edge coverage only when near a border. 
+ if min_dist < 1.0 { + sa = f32::from(sa).mul_add(min_dist, 0.5) as u8; + } + + if sa > 0 { + let dst_off = px as usize * 4; + if dst_off + 3 < row_slice.len() { + if sa == 255 { + row_slice[dst_off] = sr; + row_slice[dst_off + 1] = sg; + row_slice[dst_off + 2] = sb; + row_slice[dst_off + 3] = 255; + } else { + let a16 = u16::from(sa); + row_slice[dst_off] = blend_u8(sr, row_slice[dst_off], a16); + row_slice[dst_off + 1] = blend_u8(sg, row_slice[dst_off + 1], a16); + row_slice[dst_off + 2] = blend_u8(sb, row_slice[dst_off + 2], a16); + let da = u16::from(row_slice[dst_off + 3]); + row_slice[dst_off + 3] = + (a16 + ((da * (255 - a16) + 128) >> 8)).min(255) as u8; + } + } + } + + // Interior fast-forward: if we are well inside the content + // area (min_dist >= 2.0), subsequent pixels will also be + // interior until we approach an edge. + // + // The minimum distance decreases by at most 1.0 per column + // step (the directional derivative of each edge distance w.r.t. + // the column step is at most ±1 since |cos_a|, |sin_a| ≤ 1). + // So if min_dist >= 2.0, at least the next pixel is also fully + // interior (min_dist ≥ 1.0). We use this to batch interior + // pixels with a tighter loop that skips the coverage branch. + if min_dist >= 2.0 { + // Number of pixels we can safely process without AA. + // Conservative: (min_dist - 1.0).floor() guarantees + // min_dist stays >= 1.0 for all skipped pixels. + let skip = ((min_dist - 1.0).floor() as i32).min(bb_x1 - px - 1); + if skip > 0 { + let skip_u = skip as usize; + + // ── SIMD batched path: AVX2 (8px) → SSE2 (4px) → scalar ── + #[cfg(target_arch = "x86_64")] + { + let mut done = 0usize; + + // AVX2: process groups of 8 interior pixels. 
+ if is_x86_feature_detected!("avx2") { + done = unsafe { + rotated_blit_avx2_loop( + row_slice, + src, + px, + skip_u, + &mut local_x, + &mut local_y, + cos_a, + sin_a, + half_cw, + half_ch, + inv_scale_x, + inv_scale_y, + sw, + sh, + opacity_u16, + mirror_h, + mirror_v, + ) + }; + } + + // SSE2: process remaining pixels in groups of 4. + while done + 4 <= skip_u { + let mut src_pixels = [0u32; 4]; + let mut all_valid = true; + let snap_local_x = local_x; + let snap_local_y = local_y; + for sp in &mut src_pixels { + local_x += cos_a; + local_y -= sin_a; + let isx_raw = + (((local_x + half_cw) * inv_scale_x) as usize).min(sw - 1); + let isy_raw = + (((local_y + half_ch) * inv_scale_y) as usize).min(sh - 1); + let isx = if mirror_h { sw - 1 - isx_raw } else { isx_raw }; + let isy = if mirror_v { sh - 1 - isy_raw } else { isy_raw }; + let si = (isy * sw + isx) * 4; + if si + 3 < src.len() { + *sp = unsafe { read_rgba_u32(src, si) }; + } else { + all_valid = false; + break; + } + } + + if !all_valid { + local_x = snap_local_x; + local_y = snap_local_y; + break; + } + + let dst_off = (px as usize + 1 + done) * 4; + if dst_off + 15 < row_slice.len() { + unsafe { + let dst_ptr = row_slice.as_mut_ptr().add(dst_off); + if opacity_u16 >= 256 { + blend_4px_opaque_sse2(dst_ptr, src_pixels); + } else { + blend_4px_alpha_sse2(dst_ptr, src_pixels, opacity_u16); + } + } + } + + done += 4; + } + + // Advance px by the number of pixels handled above. + #[allow(clippy::cast_possible_wrap)] + { + px += done as i32; + } + + // Scalar remainder for leftover pixels. 
+ for _ in done..skip_u { + local_x += cos_a; + local_y -= sin_a; + px += 1; + + let isx_raw = + (((local_x + half_cw) * inv_scale_x) as usize).min(sw - 1); + let isy_raw = + (((local_y + half_ch) * inv_scale_y) as usize).min(sh - 1); + let isx = if mirror_h { sw - 1 - isx_raw } else { isx_raw }; + let isy = if mirror_v { sh - 1 - isy_raw } else { isy_raw }; + let si = (isy * sw + isx) * 4; + if si + 3 >= src.len() { + continue; + } + + blend_pixel_scalar(row_slice, px as usize * 4, src, si, opacity_u16); + } + } + + // ── Non-x86_64 fallback: scalar loop ── + #[cfg(not(target_arch = "x86_64"))] + { + for _ in 0..skip_u { + local_x += cos_a; + local_y -= sin_a; + px += 1; + + let isx_raw = + (((local_x + half_cw) * inv_scale_x) as usize).min(sw - 1); + let isy_raw = + (((local_y + half_ch) * inv_scale_y) as usize).min(sh - 1); + let isx = if mirror_h { sw - 1 - isx_raw } else { isx_raw }; + let isy = if mirror_v { sh - 1 - isy_raw } else { isy_raw }; + let si = (isy * sw + isx) * 4; + if si + 3 >= src.len() { + continue; + } + + blend_pixel_scalar(row_slice, px as usize * 4, src, si, opacity_u16); + } + } + } + } + + local_x += cos_a; + local_y -= sin_a; + px += 1; + } + }; + + // Early-out when the bounding box is empty (rect entirely off-screen). + if bb_y1 <= bb_y0 || bb_x1 <= bb_x0 { + return; + } + + let bb_rows = (bb_y1 - bb_y0) as usize; + let first_row_byte = bb_y0 as usize * row_stride; + let dst_region = &mut dst[first_row_byte..first_row_byte + bb_rows * row_stride]; + + if bb_rows >= RAYON_ROW_THRESHOLD { + use rayon::prelude::*; + let chunk_rows = rayon_chunk_rows(bb_rows); + let chunk_bytes = row_stride * chunk_rows; + dst_region.par_chunks_mut(chunk_bytes).enumerate().for_each(|(chunk_idx, chunk)| { + let base_row = chunk_idx * chunk_rows; + for (j, row_slice) in chunk.chunks_mut(row_stride).enumerate() { + let row = base_row + j; + // `row` is bounded by `bb_rows` which derives from i32 bounding-box coords. 
+ #[allow(clippy::cast_possible_wrap)] + process_row(bb_y0 + row as i32, row_slice); + } + }); + } else { + for (i, row_slice) in dst_region.chunks_mut(row_stride).enumerate() { + // `i` is bounded by `bb_rows` which derives from i32 bounding-box coords. + #[allow(clippy::cast_possible_wrap)] + process_row(bb_y0 + i as i32, row_slice); + } + } +} diff --git a/crates/nodes/src/video/compositor/pixel_ops/convert.rs b/crates/nodes/src/video/compositor/pixel_ops/convert.rs new file mode 100644 index 00000000..0ef8403d --- /dev/null +++ b/crates/nodes/src/video/compositor/pixel_ops/convert.rs @@ -0,0 +1,685 @@ +// SPDX-FileCopyrightText: © 2025 StreamKit Contributors +// +// SPDX-License-Identifier: MPL-2.0 + +//! Colour-space conversion between RGBA8 and YUV 4:2:0 formats (I420, NV12). +//! +//! All conversions use BT.601 coefficients and row-level parallelism via +//! `rayon` when the image is large enough. On x86-64, SIMD-accelerated +//! kernels are dispatched at runtime via `is_x86_feature_detected!()`. + +use super::{rayon_chunk_rows, RAYON_ROW_THRESHOLD}; + +#[cfg(target_arch = "x86_64")] +use super::simd; + +// ── Shared rayon parallelization helper ───────────────────────────────────── + +/// Process `total_rows` of a buffer in parallel (or sequentially for small +/// images), invoking `process_row(row_index, row_slice)` for each row. +/// +/// This eliminates the ~20-line rayon boilerplate that was previously +/// duplicated across every public conversion function. 
+fn parallel_rows(
+    buf: &mut [u8],
+    row_stride: usize,
+    total_rows: usize,
+    process_row: impl Fn(usize, &mut [u8]) + Send + Sync,
+) {
+    use rayon::prelude::*;
+
+    if total_rows >= RAYON_ROW_THRESHOLD {
+        // Hand rayon multi-row chunks rather than single rows so the
+        // task-scheduling overhead is amortised across `chunk_rows` rows.
+        let chunk_rows = rayon_chunk_rows(total_rows);
+        let chunk_bytes = row_stride * chunk_rows;
+        buf.par_chunks_mut(chunk_bytes).enumerate().for_each(|(chunk_idx, chunk)| {
+            let base_row = chunk_idx * chunk_rows;
+            for (j, row) in chunk.chunks_mut(row_stride).enumerate() {
+                let row_idx = base_row + j;
+                // Defensive: never invoke the callback past `total_rows`,
+                // in case `buf` is longer than `total_rows * row_stride`.
+                if row_idx >= total_rows {
+                    break;
+                }
+                process_row(row_idx, row);
+            }
+        });
+    } else {
+        // Small images: a plain sequential loop avoids rayon dispatch cost.
+        for (row_idx, row) in buf.chunks_mut(row_stride).take(total_rows).enumerate() {
+            process_row(row_idx, row);
+        }
+    }
+}
+
+// ── Shared Y-row conversion helper ──────────────────────────────────────────
+
+/// Convert a single luma (Y) row from packed RGBA8 source data using BT.601
+/// coefficients, with a SIMD dispatch cascade (AVX2 → SSE4.1 → SSE2 → scalar).
+///
+/// This is the inner kernel shared by [`rgba8_to_i420_buf`] and
+/// [`rgba8_to_nv12_buf`]. It is `#[inline(always)]` so that the CPU feature
+/// flags — which are hoisted out of the per-row loop in each caller — are
+/// propagated as constants and the SIMD branches fold away at compile time.
+#[inline(always)]
+#[allow(clippy::inline_always)] // Required: CPU feature flags must be constant-folded for SIMD branch elimination
+#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss, clippy::many_single_char_names)]
+fn convert_y_row(
+    data: &[u8],
+    row: usize,
+    w: usize,
+    y_row: &mut [u8],
+    // CPU feature flags are detected once per frame by the caller and passed
+    // in so the branches below constant-fold after inlining (x86-64 only).
+    #[cfg(target_arch = "x86_64")] use_avx2: bool,
+    #[cfg(target_arch = "x86_64")] use_sse41: bool,
+    #[cfg(target_arch = "x86_64")] use_sse2: bool,
+) {
+    // Byte offset of this row in the packed RGBA source (stride = w * 4).
+    let rgba_base = row * w * 4;
+    // First column NOT yet converted by a SIMD kernel; the scalar loop at
+    // the bottom finishes the row from here.
+    let mut start_col = 0usize;
+
+    #[cfg(target_arch = "x86_64")]
+    {
+        if use_avx2 {
+            start_col =
+                unsafe { simd::rgba8_to_y_row_avx2(&data[rgba_base..rgba_base + w * 4], y_row, w) };
+            // Let SSE4.1 chew through whatever tail AVX2 left (< one AVX2
+            // vector width) before falling back to scalar.
+            if start_col < w && use_sse41 {
+                let tail = unsafe {
+                    simd::rgba8_to_y_row_sse41(
+                        &data[rgba_base + start_col * 4..rgba_base + w * 4],
+                        &mut y_row[start_col..],
+                        w - start_col,
+                    )
+                };
+                start_col += tail;
+            }
+        } else if use_sse41 {
+            start_col = unsafe {
+                simd::rgba8_to_y_row_sse41(&data[rgba_base..rgba_base + w * 4], y_row, w)
+            };
+        } else if use_sse2 {
+            start_col =
+                unsafe { simd::rgba8_to_y_row_sse2(&data[rgba_base..rgba_base + w * 4], y_row, w) };
+        }
+    }
+
+    // Scalar tail (or the full row on non-x86-64 / without SSE2).
+    // BT.601 limited-range luma in 8-bit fixed point:
+    //   Y = 16 + (66·R + 129·G + 25·B + 128) >> 8
+    for (col, y_out) in y_row.iter_mut().enumerate().take(w).skip(start_col) {
+        let off = rgba_base + col * 4;
+        let r = i32::from(data[off]);
+        let g = i32::from(data[off + 1]);
+        let b = i32::from(data[off + 2]);
+        let y = ((66 * r + 129 * g + 25 * b + 128) >> 8) + 16;
+        *y_out = y.clamp(0, 255) as u8;
+    }
+}
+
+// ── I420 → RGBA8 ──────────────────────────────────────────────────────────── 
+
+/// Convert an I420 (YUV 4:2:0 planar) buffer to RGBA8, writing into `out`.
+///
+/// The caller must ensure `out` has length >= `width * height * 4`.
+/// Rows are processed in parallel via `rayon`.
+///
+/// On x86-64 with SSE2 support the inner per-row loop is vectorised to
+/// process 8 pixels per iteration, falling back to scalar for tail pixels.
+/// +/// **Note:** This function assumes a *packed* I420 layout (luma stride = width, +/// chroma stride = ceil(width/2)). If non-packed / aligned layouts are introduced +/// in the future, a stride-aware variant should be added. +#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss, clippy::many_single_char_names)] +pub fn i420_to_rgba8_buf(data: &[u8], width: u32, height: u32, out: &mut [u8]) { + let w = width as usize; + let h = height as usize; + let y_stride = w; + let chroma_w = w.div_ceil(2); + let chroma_h = h.div_ceil(2); + let u_offset = y_stride * h; + let v_offset = u_offset + chroma_w * chroma_h; + let rgba_row_stride = w * 4; + + #[cfg(target_arch = "x86_64")] + let use_avx2 = is_x86_feature_detected!("avx2"); + #[cfg(target_arch = "x86_64")] + let use_sse41 = is_x86_feature_detected!("sse4.1"); + #[cfg(target_arch = "x86_64")] + let use_sse2 = is_x86_feature_detected!("sse2"); + + let convert_row = |row: usize, rgba_row: &mut [u8]| { + let y_base = row * y_stride; + let chroma_row = row / 2; + let u_base = u_offset + chroma_row * chroma_w; + let v_base = v_offset + chroma_row * chroma_w; + + let mut start_col = 0usize; + + #[cfg(target_arch = "x86_64")] + { + if use_avx2 { + start_col = unsafe { + simd::i420_to_rgba8_row_avx2( + &data[y_base..y_base + w], + &data[u_base..u_base + chroma_w], + &data[v_base..v_base + chroma_w], + rgba_row, + w, + ) + }; + if start_col < w && use_sse41 { + let tail = unsafe { + simd::i420_to_rgba8_row_sse41( + &data[y_base + start_col..y_base + w], + &data[u_base + start_col / 2..u_base + chroma_w], + &data[v_base + start_col / 2..v_base + chroma_w], + &mut rgba_row[start_col * 4..], + w - start_col, + ) + }; + start_col += tail; + } + } else if use_sse41 { + start_col = unsafe { + simd::i420_to_rgba8_row_sse41( + &data[y_base..y_base + w], + &data[u_base..u_base + chroma_w], + &data[v_base..v_base + chroma_w], + rgba_row, + w, + ) + }; + } else if use_sse2 { + start_col = unsafe { + 
simd::i420_to_rgba8_row_sse2( + &data[y_base..y_base + w], + &data[u_base..u_base + chroma_w], + &data[v_base..v_base + chroma_w], + rgba_row, + w, + ) + }; + } + } + + // Scalar tail (or full row on non-x86-64 / without SSE2). + for col in start_col..w { + let y_val = i32::from(data[y_base + col]); + let u_val = i32::from(data[u_base + col / 2]); + let v_val = i32::from(data[v_base + col / 2]); + + let c = y_val - 16; + let d = u_val - 128; + let e = v_val - 128; + + let off = col * 4; + rgba_row[off] = ((298 * c + 409 * e + 128) >> 8).clamp(0, 255) as u8; + rgba_row[off + 1] = ((298 * c - 100 * d - 208 * e + 128) >> 8).clamp(0, 255) as u8; + rgba_row[off + 2] = ((298 * c + 516 * d + 128) >> 8).clamp(0, 255) as u8; + rgba_row[off + 3] = 255; + } + }; + + parallel_rows(&mut out[..w * h * 4], rgba_row_stride, h, convert_row); +} + +// ── NV12 → RGBA8 ──────────────────────────────────────────────────────────── + +/// Convert an NV12 (Y + interleaved UV) buffer to RGBA8, writing into `out`. +/// +/// Same BT.601 math as [`i420_to_rgba8_buf`], but reads U and V from a single +/// interleaved UV plane instead of two separate planes. Uses a dedicated +/// NV12 SSE2 kernel that reads the interleaved UV data in-place — no +/// scratch-buffer deinterleaving or thread-local storage required. +/// +/// The caller must ensure `out` has length >= `width * height * 4`. +/// Input `data` must be a packed NV12 buffer: `width * height` luma bytes +/// followed by `ceil(width/2) * 2 * ceil(height/2)` interleaved UV bytes. 
+#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss, clippy::many_single_char_names)] +pub fn nv12_to_rgba8_buf(data: &[u8], width: u32, height: u32, out: &mut [u8]) { + let w = width as usize; + let h = height as usize; + let y_stride = w; + let chroma_w = w.div_ceil(2); + let uv_stride = chroma_w * 2; // interleaved UV pairs + let uv_offset = y_stride * h; + let rgba_row_stride = w * 4; + + #[cfg(target_arch = "x86_64")] + let use_avx2 = is_x86_feature_detected!("avx2"); + #[cfg(target_arch = "x86_64")] + let use_sse41 = is_x86_feature_detected!("sse4.1"); + #[cfg(target_arch = "x86_64")] + let use_sse2 = is_x86_feature_detected!("sse2"); + + let convert_row = |row: usize, rgba_row: &mut [u8]| { + let y_base = row * y_stride; + let chroma_row = row / 2; + let uv_base = uv_offset + chroma_row * uv_stride; + + let mut start_col = 0usize; + + #[cfg(target_arch = "x86_64")] + { + if use_avx2 { + start_col = unsafe { + simd::nv12_to_rgba8_row_avx2( + &data[y_base..y_base + w], + &data[uv_base..uv_base + uv_stride], + rgba_row, + w, + ) + }; + if start_col < w && use_sse41 { + start_col += unsafe { + simd::nv12_to_rgba8_row_sse41( + &data[y_base + start_col..y_base + w], + &data[uv_base + (start_col / 2) * 2..uv_base + uv_stride], + &mut rgba_row[start_col * 4..], + w - start_col, + ) + }; + } + } else if use_sse41 { + start_col = unsafe { + simd::nv12_to_rgba8_row_sse41( + &data[y_base..y_base + w], + &data[uv_base..uv_base + uv_stride], + rgba_row, + w, + ) + }; + } else if use_sse2 { + start_col = unsafe { + simd::nv12_to_rgba8_row_sse2( + &data[y_base..y_base + w], + &data[uv_base..uv_base + uv_stride], + rgba_row, + w, + ) + }; + } + } + + // Scalar tail (or full row on non-x86-64 / without SSE2). 
+ for col in start_col..w { + let y_val = i32::from(data[y_base + col]); + let u_val = i32::from(data[uv_base + (col / 2) * 2]); + let v_val = i32::from(data[uv_base + (col / 2) * 2 + 1]); + + let c = y_val - 16; + let d = u_val - 128; + let e = v_val - 128; + + let off = col * 4; + rgba_row[off] = ((298 * c + 409 * e + 128) >> 8).clamp(0, 255) as u8; + rgba_row[off + 1] = ((298 * c - 100 * d - 208 * e + 128) >> 8).clamp(0, 255) as u8; + rgba_row[off + 2] = ((298 * c + 516 * d + 128) >> 8).clamp(0, 255) as u8; + rgba_row[off + 3] = 255; + } + }; + + parallel_rows(&mut out[..w * h * 4], rgba_row_stride, h, convert_row); +} + +// ── RGBA8 → I420 ──────────────────────────────────────────────────────────── + +/// Convert an RGBA8 buffer to I420 (YUV 4:2:0 planar), writing into `out`. +/// +/// The caller must ensure `out` has length >= `w * h + 2 * ((w+1)/2) * ((h+1)/2)`. +/// +/// Uses a **single fused pass** over chroma-row pairs: each iteration converts +/// Y for both luma rows AND chroma for the pair while the RGBA data is still +/// hot in L1/L2 cache. This halves the RGBA memory reads compared to the +/// two-pass approach (separate Y-plane + chroma passes). +/// +/// **Note:** This function assumes a *packed* RGBA8 layout (stride = width * 4) +/// and writes a packed I420 output (luma stride = width, chroma stride = ceil(width/2)). +/// If non-packed / aligned layouts are introduced in the future, a stride-aware +/// variant should be added. +#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss, clippy::many_single_char_names)] +pub fn rgba8_to_i420_buf(data: &[u8], width: u32, height: u32, out: &mut [u8]) { + use rayon::prelude::*; + + let w = width as usize; + let h = height as usize; + let y_stride = w; + let chroma_w = w.div_ceil(2); + let chroma_h = h.div_ceil(2); + let y_size = y_stride * h; + let chroma_size = chroma_w * chroma_h; + + // Split output into Y and chroma planes. 
+ let (y_plane, chroma_planes) = out[..y_size + 2 * chroma_size].split_at_mut(y_size); + let (u_plane, v_plane) = chroma_planes.split_at_mut(chroma_size); + + // Hoist CPU feature detection once, outside the per-row closures. + #[cfg(target_arch = "x86_64")] + let use_avx2 = is_x86_feature_detected!("avx2"); + #[cfg(target_arch = "x86_64")] + let use_sse41 = is_x86_feature_detected!("sse4.1"); + #[cfg(target_arch = "x86_64")] + let use_sse2 = is_x86_feature_detected!("sse2"); + + // Chroma-row conversion closure. + let convert_chroma_row = |crow: usize, u_row: &mut [u8], v_row: &mut [u8]| { + let r0 = crow * 2; + let mut start_ccol = 0usize; + + #[cfg(target_arch = "x86_64")] + { + if r0 + 1 < h { + let row0_start = r0 * w * 4; + let row1_start = (r0 + 1) * w * 4; + let rgba_row0 = &data[row0_start..row0_start + w * 4]; + let rgba_row1 = &data[row1_start..row1_start + w * 4]; + + if use_avx2 { + start_ccol = unsafe { + simd::rgba8_to_chroma_row_avx2( + rgba_row0, rgba_row1, u_row, v_row, chroma_w, w, + ) + }; + } + if start_ccol < chroma_w && use_sse2 { + start_ccol += unsafe { + simd::rgba8_to_chroma_row_sse2( + &rgba_row0[start_ccol * 2 * 4..], + &rgba_row1[start_ccol * 2 * 4..], + &mut u_row[start_ccol..], + &mut v_row[start_ccol..], + chroma_w - start_ccol, + w - start_ccol * 2, + ) + }; + } + } + } + + for ccol in start_ccol..chroma_w { + let c0 = ccol * 2; + let mut sr = 0i32; + let mut sg = 0i32; + let mut sb = 0i32; + let mut count = 0i32; + for dr in 0..2 { + let rr = r0 + dr; + if rr >= h { + continue; + } + for dc in 0..2 { + let cc = c0 + dc; + if cc < w { + let off = (rr * w + cc) * 4; + sr += i32::from(data[off]); + sg += i32::from(data[off + 1]); + sb += i32::from(data[off + 2]); + count += 1; + } + } + } + let r = sr / count; + let g = sg / count; + let b = sb / count; + let u = ((-38 * r - 74 * g + 112 * b + 128) >> 8) + 128; + let v = ((112 * r - 94 * g - 18 * b + 128) >> 8) + 128; + u_row[ccol] = u.clamp(0, 255) as u8; + v_row[ccol] = v.clamp(0, 
255) as u8; + } + }; + + // Raw address of the Y plane for concurrent access from parallel tasks. + // See the NV12 variant for the full safety argument. + let y_base_addr = y_plane.as_mut_ptr() as usize; + let y_len = y_plane.len(); + + // Fused row-pair closure: convert Y for both luma rows AND chroma for + // the pair in a single pass, keeping RGBA data hot in cache. + let process_row_pair = |crow: usize, u_row: &mut [u8], v_row: &mut [u8]| { + let r0 = crow * 2; + let r1 = r0 + 1; + + // Convert Y for row r0. + let y_offset_0 = r0 * y_stride; + if y_offset_0 < y_len { + let row_len = y_stride.min(y_len - y_offset_0); + let y_row_0 = unsafe { + std::slice::from_raw_parts_mut((y_base_addr + y_offset_0) as *mut u8, row_len) + }; + convert_y_row( + data, + r0, + w, + y_row_0, + #[cfg(target_arch = "x86_64")] + use_avx2, + #[cfg(target_arch = "x86_64")] + use_sse41, + #[cfg(target_arch = "x86_64")] + use_sse2, + ); + } + + // Convert Y for row r1 (if it exists — handles odd heights). + if r1 < h { + let y_offset_1 = r1 * y_stride; + if y_offset_1 < y_len { + let row_len = y_stride.min(y_len - y_offset_1); + let y_row_1 = unsafe { + std::slice::from_raw_parts_mut((y_base_addr + y_offset_1) as *mut u8, row_len) + }; + convert_y_row( + data, + r1, + w, + y_row_1, + #[cfg(target_arch = "x86_64")] + use_avx2, + #[cfg(target_arch = "x86_64")] + use_sse41, + #[cfg(target_arch = "x86_64")] + use_sse2, + ); + } + } + + // Convert chroma for the row pair. + convert_chroma_row(crow, u_row, v_row); + }; + + // Parallelise by chroma row — each task processes one row-pair. 
+ let u_rows: Vec<&mut [u8]> = u_plane.chunks_mut(chroma_w).collect(); + let v_rows: Vec<&mut [u8]> = v_plane.chunks_mut(chroma_w).collect(); + + if chroma_h >= RAYON_ROW_THRESHOLD / 2 { + u_rows.into_par_iter().zip(v_rows).enumerate().for_each(|(crow, (u_row, v_row))| { + process_row_pair(crow, u_row, v_row); + }); + } else { + for (crow, (u_row, v_row)) in u_rows.into_iter().zip(v_rows).enumerate() { + process_row_pair(crow, u_row, v_row); + } + } +} + +// ── RGBA8 → NV12 ──────────────────────────────────────────────────────────── + +/// Convert an RGBA8 buffer to NV12 (Y + interleaved UV), writing into `out`. +/// +/// The caller must ensure `out` has length >= `w * h + ceil(w/2) * 2 * ceil(h/2)`. +/// +/// Uses a **single fused pass** over chroma-row pairs: each iteration converts +/// Y for both luma rows AND chroma for the pair while the RGBA data is still +/// hot in L1/L2 cache. This halves the RGBA memory reads compared to the +/// two-pass approach (separate Y-plane + chroma passes). +/// +/// **Note:** Assumes packed RGBA8 input (stride = width * 4) and writes a +/// packed NV12 output (luma stride = width, chroma stride = ceil(width/2) * 2). +#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss, clippy::many_single_char_names)] +pub fn rgba8_to_nv12_buf(data: &[u8], width: u32, height: u32, out: &mut [u8]) { + use rayon::prelude::*; + + let w = width as usize; + let h = height as usize; + let y_stride = w; + let chroma_w = w.div_ceil(2); + let chroma_h = h.div_ceil(2); + let y_size = y_stride * h; + let uv_stride = chroma_w * 2; + + // Split output into Y plane and UV plane. + let (y_plane, uv_plane) = out[..y_size + uv_stride * chroma_h].split_at_mut(y_size); + + // Hoist CPU feature detection once, outside the per-row closures. 
+ #[cfg(target_arch = "x86_64")] + let use_avx2 = is_x86_feature_detected!("avx2"); + #[cfg(target_arch = "x86_64")] + let use_sse41 = is_x86_feature_detected!("sse4.1"); + #[cfg(target_arch = "x86_64")] + let use_sse2 = is_x86_feature_detected!("sse2"); + + // Chroma-row conversion closure. + let convert_chroma_row = |crow: usize, uv_row: &mut [u8]| { + let r0 = crow * 2; + let mut start_ccol = 0usize; + + #[cfg(target_arch = "x86_64")] + { + if r0 + 1 < h { + let row0_start = r0 * w * 4; + let row1_start = (r0 + 1) * w * 4; + let rgba_row0 = &data[row0_start..row0_start + w * 4]; + let rgba_row1 = &data[row1_start..row1_start + w * 4]; + + if use_avx2 { + start_ccol = unsafe { + simd::rgba8_to_chroma_row_nv12_avx2( + rgba_row0, rgba_row1, uv_row, chroma_w, w, + ) + }; + } + if start_ccol < chroma_w && use_sse2 { + start_ccol += unsafe { + simd::rgba8_to_chroma_row_nv12_sse2( + &rgba_row0[start_ccol * 2 * 4..], + &rgba_row1[start_ccol * 2 * 4..], + &mut uv_row[start_ccol * 2..], + chroma_w - start_ccol, + w - start_ccol * 2, + ) + }; + } + } + } + + for ccol in start_ccol..chroma_w { + let c0 = ccol * 2; + let mut sr = 0i32; + let mut sg = 0i32; + let mut sb = 0i32; + let mut count = 0i32; + for dr in 0..2 { + let rr = r0 + dr; + if rr >= h { + continue; + } + for dc in 0..2 { + let cc = c0 + dc; + if cc < w { + let off = (rr * w + cc) * 4; + sr += i32::from(data[off]); + sg += i32::from(data[off + 1]); + sb += i32::from(data[off + 2]); + count += 1; + } + } + } + let r = sr / count; + let g = sg / count; + let b = sb / count; + let u = ((-38 * r - 74 * g + 112 * b + 128) >> 8) + 128; + let v = ((112 * r - 94 * g - 18 * b + 128) >> 8) + 128; + uv_row[ccol * 2] = u.clamp(0, 255) as u8; + uv_row[ccol * 2 + 1] = v.clamp(0, 255) as u8; + } + }; + + // Raw address of the Y plane for concurrent access from parallel tasks. 
+ // + // SAFETY: each chroma-row task writes to disjoint Y-plane regions: + // task `crow` writes rows [2*crow] and [2*crow+1], which never overlap + // with another task's rows. `y_plane` and `uv_plane` are disjoint + // (from `split_at_mut` above). + // + // We store the pointer as `usize` (which is `Send + Sync`) and + // reconstruct slices inside the closure. The safety invariant + // (non-overlapping writes) is upheld by the row-pair index mapping. + let y_base_addr = y_plane.as_mut_ptr() as usize; + let y_len = y_plane.len(); + + // Fused row-pair closure: convert Y for both luma rows AND chroma for + // the pair in a single pass, keeping RGBA data hot in cache. + let process_row_pair = |crow: usize, uv_row: &mut [u8]| { + let r0 = crow * 2; + let r1 = r0 + 1; + + // Convert Y for row r0. + let y_offset_0 = r0 * y_stride; + if y_offset_0 < y_len { + let row_len = y_stride.min(y_len - y_offset_0); + // SAFETY: non-overlapping slice — see safety comment above. + let y_row_0 = unsafe { + std::slice::from_raw_parts_mut((y_base_addr + y_offset_0) as *mut u8, row_len) + }; + convert_y_row( + data, + r0, + w, + y_row_0, + #[cfg(target_arch = "x86_64")] + use_avx2, + #[cfg(target_arch = "x86_64")] + use_sse41, + #[cfg(target_arch = "x86_64")] + use_sse2, + ); + } + + // Convert Y for row r1 (if it exists — handles odd heights). + if r1 < h { + let y_offset_1 = r1 * y_stride; + if y_offset_1 < y_len { + let row_len = y_stride.min(y_len - y_offset_1); + let y_row_1 = unsafe { + std::slice::from_raw_parts_mut((y_base_addr + y_offset_1) as *mut u8, row_len) + }; + convert_y_row( + data, + r1, + w, + y_row_1, + #[cfg(target_arch = "x86_64")] + use_avx2, + #[cfg(target_arch = "x86_64")] + use_sse41, + #[cfg(target_arch = "x86_64")] + use_sse2, + ); + } + } + + // Convert chroma for the row pair (reads same RGBA rows as Y above). + convert_chroma_row(crow, uv_row); + }; + + // Parallelise by chroma row — each task processes one row-pair. 
+ if chroma_h >= RAYON_ROW_THRESHOLD / 2 { + let chunk_rows = rayon_chunk_rows(chroma_h); + let chunk_bytes = uv_stride * chunk_rows; + uv_plane.par_chunks_mut(chunk_bytes).enumerate().for_each(|(chunk_idx, chunk)| { + let base_crow = chunk_idx * chunk_rows; + for (j, uv_row) in chunk.chunks_mut(uv_stride).enumerate() { + let crow = base_crow + j; + if crow >= chroma_h { + break; + } + process_row_pair(crow, uv_row); + } + }); + } else { + for (crow, uv_row) in uv_plane.chunks_mut(uv_stride).take(chroma_h).enumerate() { + process_row_pair(crow, uv_row); + } + } +} diff --git a/crates/nodes/src/video/compositor/pixel_ops/mod.rs b/crates/nodes/src/video/compositor/pixel_ops/mod.rs new file mode 100644 index 00000000..9f45c2df --- /dev/null +++ b/crates/nodes/src/video/compositor/pixel_ops/mod.rs @@ -0,0 +1,103 @@ +// SPDX-FileCopyrightText: © 2025 StreamKit Contributors +// +// SPDX-License-Identifier: MPL-2.0 + +//! Pixel-level operations for the video compositor. +//! +//! Contains RGBA8 blitting (with nearest-neighbor scaling), alpha blending, +//! overlay compositing, and I420 / NV12 ↔ RGBA8 colour-space conversion. +//! +//! All hot loops use row-level parallelism via `rayon` when the region is +//! large enough to amortise the thread-pool dispatch overhead. Below the +//! threshold the same per-row closures run sequentially. +//! +//! # Module structure +//! +//! - [`blit`] — axis-aligned and rotated scale + blit operations. +//! - [`convert`] — colour-space conversion (I420, NV12 ↔ RGBA8). +//! - [`simd`] (x86-64 only) — SIMD kernels for both blitting and conversion. + +mod blit; +mod convert; + +#[cfg(target_arch = "x86_64")] +mod simd_x86_64; + +/// Re-export the x86-64 SIMD module under a shorter name for internal use. +#[cfg(target_arch = "x86_64")] +use simd_x86_64 as simd; + +// ── Shared constants and helpers ──────────────────────────────────────────── + +/// Minimum number of output rows before we dispatch to rayon. 
Below this
+/// threshold the per-row work is small enough that the rayon scheduling
+/// overhead (work-stealing queue push/pop, thread wake-up) dominates.
+/// 64 rows at 1280-wide RGBA8 ≈ 320 KiB — a reasonable crossover point
+/// on modern x86-64 cores.
+const RAYON_ROW_THRESHOLD: usize = 64;
+
+/// Number of rows to bundle into a single rayon task once parallel mode is
+/// entered. Reduces work-stealing overhead from ~1 task/row to
+/// ~rows/chunk tasks.
+///
+/// [`rayon_chunk_rows`] auto-tunes the chunk size based on workload:
+/// wider or taller frames produce fewer, larger chunks, keeping
+/// scheduling cost proportional to the actual parallelism available.
+///
+/// Formula: `max(8, total_rows / (num_cpus * 4))`, clamped to `[8, 64]`.
+/// This keeps chunk counts proportional to hardware parallelism while
+/// avoiding both excessive scheduling overhead (too many tiny chunks)
+/// and poor load-balancing (too few large chunks).
+///
+/// The CPU count is cached in a `LazyLock` so we avoid a `sysconf` syscall
+/// (~40 µs on Linux) on every call.
+fn rayon_chunk_rows(total_rows: usize) -> usize {
+    static CPUS: std::sync::LazyLock<usize> = std::sync::LazyLock::new(|| {
+        std::thread::available_parallelism().map(std::num::NonZero::get).unwrap_or(1)
+    });
+    let ideal = total_rows.div_ceil(*CPUS * 4);
+    ideal.clamp(8, 64)
+}
+
+/// Fixed-point alpha blend: `(src * alpha + dst * (255 - alpha) + 128) / 255`
+/// using the well-known `((x + (x >> 8)) >> 8)` fast approximation of `x / 255`.
+#[allow(clippy::inline_always)]
+#[inline(always)]
+const fn blend_u8(src: u8, dst: u8, alpha: u16) -> u8 {
+    let inv = 255 - alpha;
+    let val = src as u16 * alpha + dst as u16 * inv + 128;
+    ((val + (val >> 8)) >> 8) as u8
+}
+
+/// Check whether every pixel's alpha byte in an RGBA8 buffer is `0xFF`.
+///
+/// Dispatches to AVX2 / SSE2 kernels on x86-64 (8 / 4 pixels per iteration)
+/// and falls back to a scalar scan on other architectures.
Safe wrapper +/// around the `target_feature`-gated SIMD helpers so callers outside this +/// module don't need their own `unsafe` + `cfg` scaffolding. +/// +/// Assumes `rgba.len()` is a multiple of 4 — always true for valid RGBA8 data. +pub fn all_alpha_opaque(rgba: &[u8]) -> bool { + #[cfg(target_arch = "x86_64")] + { + // SAFETY: `all_alpha_opaque_{avx2,sse2}` require the input length to be + // a multiple of 4 bytes, which holds for any valid RGBA8 buffer. Both + // helpers handle arbitrary tails internally (scalar fall-through for + // trailing bytes past the last full SIMD chunk). Feature availability + // is checked at runtime via `is_x86_feature_detected!`. + if is_x86_feature_detected!("avx2") { + unsafe { simd::all_alpha_opaque_avx2(rgba) } + } else { + unsafe { simd::all_alpha_opaque_sse2(rgba) } + } + } + #[cfg(not(target_arch = "x86_64"))] + { + rgba.chunks_exact(4).all(|px| px[3] == 255) + } +} + +// ── Public API re-exports ─────────────────────────────────────────────────── + +pub use blit::{scale_blit_rgba, scale_blit_rgba_rotated}; +pub use convert::{i420_to_rgba8_buf, nv12_to_rgba8_buf, rgba8_to_i420_buf, rgba8_to_nv12_buf}; diff --git a/crates/nodes/src/video/compositor/pixel_ops/simd_x86_64.rs b/crates/nodes/src/video/compositor/pixel_ops/simd_x86_64.rs new file mode 100644 index 00000000..7fa48671 --- /dev/null +++ b/crates/nodes/src/video/compositor/pixel_ops/simd_x86_64.rs @@ -0,0 +1,1734 @@ +// SPDX-FileCopyrightText: © 2025 StreamKit Contributors +// +// SPDX-License-Identifier: MPL-2.0 + +//! x86-64 SIMD kernels for pixel-level operations. +//! +//! This module is only compiled on `x86_64` targets (gated at the module level +//! in `mod.rs`). It contains SSE2, SSE4.1 and AVX2 kernels for: +//! +//! - RGBA8 alpha blending (used by the blit functions) +//! - I420 / NV12 → RGBA8 colour-space conversion +//! - RGBA8 → I420 / NV12 colour-space conversion (Y-plane and chroma) +//! +//! 
SSE2 and SSE4.1 variants share the same algorithmic structure, differing
+//! only in how they perform 32-bit integer multiplies. The `impl_yuv_to_rgba`
+//! and `impl_rgba_to_y` macros generate both variants from a single body,
+//! parameterised on the multiply strategy.
+
+// ── SSE2 alpha-blend helpers ────────────────────────────────────────────────
+//
+// Process 4 RGBA pixels at a time using SSE2 integer arithmetic.
+// Source pixels are gathered (non-contiguous via x_map), destination pixels
+// are contiguous. The blend formula is identical to the scalar `blend_u8`:
+// result = ((src*alpha + dst*(255-alpha) + 128) + ((…) >> 8)) >> 8
+//
+// For the alpha channel we set source-alpha to 255 before blending so that
+// `blend_u8(255, dst_alpha, src_alpha)` naturally computes the standard
+// over-composite alpha `a_src + a_dst*(1-a_src)` (within ±1 of the scalar
+// approximation — both are approximate divisions by 255).
+
+/// Read 4 bytes from `src` at `offset` as a native-endian `u32`.
+///
+/// # Safety
+///
+/// Caller must ensure `offset + 3 < src.len()`.
+#[inline]
+pub(super) const unsafe fn read_rgba_u32(src: &[u8], offset: usize) -> u32 {
+    std::ptr::read_unaligned(src.as_ptr().add(offset).cast::<u32>())
+}
+
+/// Blend 4 gathered source RGBA pixels onto 4 contiguous destination pixels
+/// using SSE2 "over" compositing (no opacity modifier).
+///
+/// # Safety
+///
+/// `dst_ptr` must point to at least 16 writable bytes. Source pixel values
+/// in `src_pixels` must be valid RGBA `u32` values.
+#[inline(always)] +#[allow(clippy::cast_ptr_alignment)] // _mm_storeu/loadu_si128 do not require alignment +pub(super) unsafe fn blend_4px_opaque_sse2(dst_ptr: *mut u8, src_pixels: [u32; 4]) { + use std::arch::x86_64::{ + __m128i, _mm_add_epi16, _mm_and_si128, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, + _mm_mullo_epi16, _mm_or_si128, _mm_packus_epi16, _mm_set1_epi16, _mm_set1_epi32, + _mm_set_epi32, _mm_setzero_si128, _mm_shufflehi_epi16, _mm_shufflelo_epi16, _mm_srli_epi16, + _mm_storeu_si128, _mm_sub_epi16, _mm_unpackhi_epi8, _mm_unpacklo_epi8, + }; + + let zero = _mm_setzero_si128(); + let c255 = _mm_set1_epi16(255); + let c128 = _mm_set1_epi16(128); + + // Assemble 4 gathered source pixels into one register. + let src4 = _mm_set_epi32( + src_pixels[3].cast_signed(), + src_pixels[2].cast_signed(), + src_pixels[1].cast_signed(), + src_pixels[0].cast_signed(), + ); + + // Mask with 0xFF at each pixel's alpha-byte position (bytes 3,7,11,15). + let alpha_byte_mask = _mm_set1_epi32(0xFF00_0000_u32.cast_signed()); + + // Fast path: all 4 source pixels fully opaque → direct copy. + let alpha_bytes = _mm_and_si128(src4, alpha_byte_mask); + if _mm_movemask_epi8(_mm_cmpeq_epi8(alpha_bytes, alpha_byte_mask)) == 0xFFFF { + _mm_storeu_si128(dst_ptr.cast::<__m128i>(), src4); + return; + } + + // Fast path: all 4 source pixels fully transparent → nothing to do. + if _mm_movemask_epi8(_mm_cmpeq_epi8(alpha_bytes, zero)) == 0xFFFF { + return; + } + + let dst4 = _mm_loadu_si128(dst_ptr.cast::<__m128i>().cast_const()); + + // Replace source alpha channel with 255 for correct composite-alpha + // via blend_u8(255, dst_alpha, src_alpha). + let src_blend = _mm_or_si128(src4, alpha_byte_mask); + + // --- Low 2 pixels (u16 arithmetic) --- + let src_lo = _mm_unpacklo_epi8(src_blend, zero); + let dst_lo = _mm_unpacklo_epi8(dst4, zero); + + // Extract original source alpha and broadcast within each 4-u16 pixel group. 
+ let src_orig_lo = _mm_unpacklo_epi8(src4, zero); + // _MM_SHUFFLE(3,3,3,3) = 0xFF → replicate element 3 (alpha) to all 4 positions. + let alpha_lo = _mm_shufflehi_epi16(_mm_shufflelo_epi16(src_orig_lo, 0xFF), 0xFF); + + let inv_alpha_lo = _mm_sub_epi16(c255, alpha_lo); + let val_lo = _mm_add_epi16( + _mm_add_epi16(_mm_mullo_epi16(src_lo, alpha_lo), _mm_mullo_epi16(dst_lo, inv_alpha_lo)), + c128, + ); + let result_lo = _mm_srli_epi16(_mm_add_epi16(val_lo, _mm_srli_epi16(val_lo, 8)), 8); + + // --- High 2 pixels --- + let src_hi = _mm_unpackhi_epi8(src_blend, zero); + let dst_hi = _mm_unpackhi_epi8(dst4, zero); + let src_orig_hi = _mm_unpackhi_epi8(src4, zero); + let alpha_hi = _mm_shufflehi_epi16(_mm_shufflelo_epi16(src_orig_hi, 0xFF), 0xFF); + + let inv_alpha_hi = _mm_sub_epi16(c255, alpha_hi); + let val_hi = _mm_add_epi16( + _mm_add_epi16(_mm_mullo_epi16(src_hi, alpha_hi), _mm_mullo_epi16(dst_hi, inv_alpha_hi)), + c128, + ); + let result_hi = _mm_srli_epi16(_mm_add_epi16(val_hi, _mm_srli_epi16(val_hi, 8)), 8); + + // Pack back to u8 and store. + _mm_storeu_si128(dst_ptr.cast::<__m128i>(), _mm_packus_epi16(result_lo, result_hi)); +} + +/// Blend 4 gathered source RGBA pixels onto 4 contiguous destination pixels +/// using SSE2 "over" compositing **with** an opacity multiplier applied to +/// each pixel's source alpha. +/// +/// # Safety +/// +/// `dst_ptr` must point to at least 16 writable bytes. 
+#[inline(always)] +#[allow(clippy::cast_ptr_alignment)] // _mm_storeu/loadu_si128 do not require alignment +pub(super) unsafe fn blend_4px_alpha_sse2(dst_ptr: *mut u8, src_pixels: [u32; 4], opacity: u16) { + use std::arch::x86_64::{ + __m128i, _mm_add_epi16, _mm_loadu_si128, _mm_mullo_epi16, _mm_or_si128, _mm_packus_epi16, + _mm_set1_epi16, _mm_set1_epi32, _mm_set_epi32, _mm_setzero_si128, _mm_shufflehi_epi16, + _mm_shufflelo_epi16, _mm_srli_epi16, _mm_storeu_si128, _mm_sub_epi16, _mm_unpackhi_epi8, + _mm_unpacklo_epi8, + }; + + let zero = _mm_setzero_si128(); + let c255 = _mm_set1_epi16(255); + let c128 = _mm_set1_epi16(128); + let opacity_v = _mm_set1_epi16(opacity.cast_signed()); + + let src4 = _mm_set_epi32( + src_pixels[3].cast_signed(), + src_pixels[2].cast_signed(), + src_pixels[1].cast_signed(), + src_pixels[0].cast_signed(), + ); + + let dst4 = _mm_loadu_si128(dst_ptr.cast::<__m128i>().cast_const()); + let alpha_byte_mask = _mm_set1_epi32(0xFF00_0000_u32.cast_signed()); + let src_blend = _mm_or_si128(src4, alpha_byte_mask); + + // --- Low 2 pixels --- + let src_lo = _mm_unpacklo_epi8(src_blend, zero); + let dst_lo = _mm_unpacklo_epi8(dst4, zero); + + // Extract original alpha, apply opacity: sa_eff = (sa * opacity + 128) >> 8. + // Max value: (255*255+128)>>8 = 254, so no clamping needed. 
+ let src_orig_lo = _mm_unpacklo_epi8(src4, zero); + let raw_alpha_lo = _mm_shufflehi_epi16(_mm_shufflelo_epi16(src_orig_lo, 0xFF), 0xFF); + let alpha_lo = _mm_srli_epi16(_mm_add_epi16(_mm_mullo_epi16(raw_alpha_lo, opacity_v), c128), 8); + + let inv_alpha_lo = _mm_sub_epi16(c255, alpha_lo); + let val_lo = _mm_add_epi16( + _mm_add_epi16(_mm_mullo_epi16(src_lo, alpha_lo), _mm_mullo_epi16(dst_lo, inv_alpha_lo)), + c128, + ); + let result_lo = _mm_srli_epi16(_mm_add_epi16(val_lo, _mm_srli_epi16(val_lo, 8)), 8); + + // --- High 2 pixels --- + let src_hi = _mm_unpackhi_epi8(src_blend, zero); + let dst_hi = _mm_unpackhi_epi8(dst4, zero); + let src_orig_hi = _mm_unpackhi_epi8(src4, zero); + let raw_alpha_hi = _mm_shufflehi_epi16(_mm_shufflelo_epi16(src_orig_hi, 0xFF), 0xFF); + let alpha_hi = _mm_srli_epi16(_mm_add_epi16(_mm_mullo_epi16(raw_alpha_hi, opacity_v), c128), 8); + + let inv_alpha_hi = _mm_sub_epi16(c255, alpha_hi); + let val_hi = _mm_add_epi16( + _mm_add_epi16(_mm_mullo_epi16(src_hi, alpha_hi), _mm_mullo_epi16(dst_hi, inv_alpha_hi)), + c128, + ); + let result_hi = _mm_srli_epi16(_mm_add_epi16(val_hi, _mm_srli_epi16(val_hi, 8)), 8); + + _mm_storeu_si128(dst_ptr.cast::<__m128i>(), _mm_packus_epi16(result_lo, result_hi)); +} + +// ── AVX2 alpha-blend helpers ───────────────────────────────────────────────── +// +// Process 8 RGBA pixels at a time using AVX2 integer arithmetic. +// Same algorithm as the SSE2 helpers above, widened to 256-bit registers. + +/// Blend 8 gathered source RGBA pixels onto 8 contiguous destination pixels +/// using AVX2 "over" compositing (no opacity modifier). +/// +/// # Safety +/// +/// `dst_ptr` must point to at least 32 writable bytes. Source pixel values +/// in `src_pixels` must be valid RGBA `u32` values. 
+#[target_feature(enable = "avx2")] +#[inline] +#[allow(clippy::cast_ptr_alignment)] +pub(super) unsafe fn blend_8px_opaque_avx2(dst_ptr: *mut u8, src_pixels: [u32; 8]) { + use std::arch::x86_64::{ + __m256i, _mm256_add_epi16, _mm256_and_si256, _mm256_cmpeq_epi8, _mm256_loadu_si256, + _mm256_movemask_epi8, _mm256_mullo_epi16, _mm256_or_si256, _mm256_packus_epi16, + _mm256_set1_epi16, _mm256_set1_epi32, _mm256_set_epi32, _mm256_setzero_si256, + _mm256_shufflehi_epi16, _mm256_shufflelo_epi16, _mm256_srli_epi16, _mm256_storeu_si256, + _mm256_sub_epi16, _mm256_unpackhi_epi8, _mm256_unpacklo_epi8, + }; + + let zero = _mm256_setzero_si256(); + let c255 = _mm256_set1_epi16(255); + let c128 = _mm256_set1_epi16(128); + + // Assemble 8 gathered source pixels into one 256-bit register. + let src8 = _mm256_set_epi32( + src_pixels[7].cast_signed(), + src_pixels[6].cast_signed(), + src_pixels[5].cast_signed(), + src_pixels[4].cast_signed(), + src_pixels[3].cast_signed(), + src_pixels[2].cast_signed(), + src_pixels[1].cast_signed(), + src_pixels[0].cast_signed(), + ); + + let alpha_byte_mask = _mm256_set1_epi32(0xFF00_0000_u32.cast_signed()); + + // Fast path: all 8 source pixels fully opaque → direct copy. + let alpha_bytes = _mm256_and_si256(src8, alpha_byte_mask); + if _mm256_movemask_epi8(_mm256_cmpeq_epi8(alpha_bytes, alpha_byte_mask)) == -1i32 { + _mm256_storeu_si256(dst_ptr.cast::<__m256i>(), src8); + return; + } + + // Fast path: all 8 source pixels fully transparent → nothing to do. 
+ if _mm256_movemask_epi8(_mm256_cmpeq_epi8(alpha_bytes, zero)) == -1i32 { + return; + } + + let dst8 = _mm256_loadu_si256(dst_ptr.cast::<__m256i>().cast_const()); + let src_blend = _mm256_or_si256(src8, alpha_byte_mask); + + // --- Low 4 pixels (within each 128-bit lane) --- + let src_lo = _mm256_unpacklo_epi8(src_blend, zero); + let dst_lo = _mm256_unpacklo_epi8(dst8, zero); + let src_orig_lo = _mm256_unpacklo_epi8(src8, zero); + let alpha_lo = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(src_orig_lo, 0xFF), 0xFF); + + let inv_alpha_lo = _mm256_sub_epi16(c255, alpha_lo); + let val_lo = _mm256_add_epi16( + _mm256_add_epi16( + _mm256_mullo_epi16(src_lo, alpha_lo), + _mm256_mullo_epi16(dst_lo, inv_alpha_lo), + ), + c128, + ); + let result_lo = _mm256_srli_epi16(_mm256_add_epi16(val_lo, _mm256_srli_epi16(val_lo, 8)), 8); + + // --- High 4 pixels (within each 128-bit lane) --- + let src_hi = _mm256_unpackhi_epi8(src_blend, zero); + let dst_hi = _mm256_unpackhi_epi8(dst8, zero); + let src_orig_hi = _mm256_unpackhi_epi8(src8, zero); + let alpha_hi = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(src_orig_hi, 0xFF), 0xFF); + + let inv_alpha_hi = _mm256_sub_epi16(c255, alpha_hi); + let val_hi = _mm256_add_epi16( + _mm256_add_epi16( + _mm256_mullo_epi16(src_hi, alpha_hi), + _mm256_mullo_epi16(dst_hi, inv_alpha_hi), + ), + c128, + ); + let result_hi = _mm256_srli_epi16(_mm256_add_epi16(val_hi, _mm256_srli_epi16(val_hi, 8)), 8); + + // Pack back to u8. Since result_lo and result_hi both come from + // unpacking the same 256-bit register (src8/dst8), _mm256_packus_epi16 + // already produces correct pixel order [px0..px3 | px4..px7] — no + // cross-lane permute needed. + let packed = _mm256_packus_epi16(result_lo, result_hi); + _mm256_storeu_si256(dst_ptr.cast::<__m256i>(), packed); +} + +/// Blend 8 gathered source RGBA pixels onto 8 contiguous destination pixels +/// using AVX2 "over" compositing **with** an opacity multiplier applied to +/// each pixel's source alpha. 
+/// +/// # Safety +/// +/// `dst_ptr` must point to at least 32 writable bytes. +#[target_feature(enable = "avx2")] +#[inline] +#[allow(clippy::cast_ptr_alignment)] +pub(super) unsafe fn blend_8px_alpha_avx2(dst_ptr: *mut u8, src_pixels: [u32; 8], opacity: u16) { + use std::arch::x86_64::{ + __m256i, _mm256_add_epi16, _mm256_loadu_si256, _mm256_mullo_epi16, _mm256_or_si256, + _mm256_packus_epi16, _mm256_set1_epi16, _mm256_set1_epi32, _mm256_set_epi32, + _mm256_setzero_si256, _mm256_shufflehi_epi16, _mm256_shufflelo_epi16, _mm256_srli_epi16, + _mm256_storeu_si256, _mm256_sub_epi16, _mm256_unpackhi_epi8, _mm256_unpacklo_epi8, + }; + + let zero = _mm256_setzero_si256(); + let c255 = _mm256_set1_epi16(255); + let c128 = _mm256_set1_epi16(128); + let opacity_v = _mm256_set1_epi16(opacity.cast_signed()); + + let src8 = _mm256_set_epi32( + src_pixels[7].cast_signed(), + src_pixels[6].cast_signed(), + src_pixels[5].cast_signed(), + src_pixels[4].cast_signed(), + src_pixels[3].cast_signed(), + src_pixels[2].cast_signed(), + src_pixels[1].cast_signed(), + src_pixels[0].cast_signed(), + ); + + let dst8 = _mm256_loadu_si256(dst_ptr.cast::<__m256i>().cast_const()); + let alpha_byte_mask = _mm256_set1_epi32(0xFF00_0000_u32.cast_signed()); + let src_blend = _mm256_or_si256(src8, alpha_byte_mask); + + // --- Low 4 pixels --- + let src_lo = _mm256_unpacklo_epi8(src_blend, zero); + let dst_lo = _mm256_unpacklo_epi8(dst8, zero); + let src_orig_lo = _mm256_unpacklo_epi8(src8, zero); + let raw_alpha_lo = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(src_orig_lo, 0xFF), 0xFF); + let alpha_lo = + _mm256_srli_epi16(_mm256_add_epi16(_mm256_mullo_epi16(raw_alpha_lo, opacity_v), c128), 8); + + let inv_alpha_lo = _mm256_sub_epi16(c255, alpha_lo); + let val_lo = _mm256_add_epi16( + _mm256_add_epi16( + _mm256_mullo_epi16(src_lo, alpha_lo), + _mm256_mullo_epi16(dst_lo, inv_alpha_lo), + ), + c128, + ); + let result_lo = _mm256_srli_epi16(_mm256_add_epi16(val_lo, _mm256_srli_epi16(val_lo, 8)), 
8); + + // --- High 4 pixels --- + let src_hi = _mm256_unpackhi_epi8(src_blend, zero); + let dst_hi = _mm256_unpackhi_epi8(dst8, zero); + let src_orig_hi = _mm256_unpackhi_epi8(src8, zero); + let raw_alpha_hi = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(src_orig_hi, 0xFF), 0xFF); + let alpha_hi = + _mm256_srli_epi16(_mm256_add_epi16(_mm256_mullo_epi16(raw_alpha_hi, opacity_v), c128), 8); + + let inv_alpha_hi = _mm256_sub_epi16(c255, alpha_hi); + let val_hi = _mm256_add_epi16( + _mm256_add_epi16( + _mm256_mullo_epi16(src_hi, alpha_hi), + _mm256_mullo_epi16(dst_hi, inv_alpha_hi), + ), + c128, + ); + let result_hi = _mm256_srli_epi16(_mm256_add_epi16(val_hi, _mm256_srli_epi16(val_hi, 8)), 8); + + // Same as opaque variant: pack already yields correct order, no permute. + let packed = _mm256_packus_epi16(result_lo, result_hi); + _mm256_storeu_si256(dst_ptr.cast::<__m256i>(), packed); +} + +// ── SIMD alpha-opaqueness check ───────────────────────────────────────────── + +/// Check if all alpha bytes in an RGBA8 row are 0xFF using SSE2. +/// +/// Processes 4 pixels (16 bytes) per iteration. Returns `true` if every +/// pixel's alpha channel is 255. +/// +/// # Safety +/// +/// Caller must ensure `row` length is a multiple of 4 bytes (always true +/// for valid RGBA8 data). +#[target_feature(enable = "sse2")] +#[inline] +pub(super) unsafe fn all_alpha_opaque_sse2(row: &[u8]) -> bool { + use std::arch::x86_64::{ + _mm_and_si128, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_set1_epi32, + }; + + let alpha_mask = _mm_set1_epi32(0xFF00_0000_u32.cast_signed()); + let len = row.len(); + let simd_end = len & !15; // round down to multiple of 16 (4 pixels) + let mut i = 0; + + while i < simd_end { + let chunk = _mm_loadu_si128(row.as_ptr().add(i).cast()); + let alpha_bytes = _mm_and_si128(chunk, alpha_mask); + // Check that all alpha-position bytes equal 0xFF. + // After AND with mask, alpha positions have 0xFF if opaque. 
+        // cmpeq + movemask: if all 16 bytes match, mask == 0xFFFF.
+        // But we only care about bytes 3,7,11,15 (alpha positions).
+        // After AND, non-alpha bytes are 0x00 — equal to the mask's 0x00
+        // bytes — so cmpeq sets those to 0xFF too. We therefore keep only
+        // the alpha-position movemask bits: bits 3,7,11,15 = 0x8888.
+        if _mm_movemask_epi8(_mm_cmpeq_epi8(alpha_bytes, alpha_mask)) & 0x8888 != 0x8888 {
+            return false;
+        }
+        i += 16;
+    }
+
+    // Scalar tail.
+    while i + 3 < len {
+        if row[i + 3] != 255 {
+            return false;
+        }
+        i += 4;
+    }
+    true
+}
+
+/// Check if all alpha bytes in an RGBA8 row are 0xFF using AVX2.
+///
+/// Processes 8 pixels (32 bytes) per iteration.
+///
+/// # Safety
+///
+/// Caller must ensure `row` length is a multiple of 4 bytes.
+#[target_feature(enable = "avx2")]
+#[inline]
+pub(super) unsafe fn all_alpha_opaque_avx2(row: &[u8]) -> bool {
+    use std::arch::x86_64::{
+        _mm256_and_si256, _mm256_cmpeq_epi8, _mm256_loadu_si256, _mm256_movemask_epi8,
+        _mm256_set1_epi32,
+    };
+
+    let alpha_mask = _mm256_set1_epi32(0xFF00_0000_u32.cast_signed());
+    let len = row.len();
+    let simd_end = len & !31; // round down to multiple of 32 (8 pixels)
+    let mut i = 0;
+
+    while i < simd_end {
+        let chunk = _mm256_loadu_si256(row.as_ptr().add(i).cast());
+        let alpha_bytes = _mm256_and_si256(chunk, alpha_mask);
+        // Alpha-position bits: bytes 3,7,11,15,19,23,27,31 → mask bits 0x88888888.
+        let cmp_mask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(alpha_bytes, alpha_mask));
+        if cmp_mask & 0x8888_8888_u32.cast_signed() != 0x8888_8888_u32.cast_signed() {
+            return false;
+        }
+        i += 32;
+    }
+
+    // Scalar tail.
+    while i + 3 < len {
+        if row[i + 3] != 255 {
+            return false;
+        }
+        i += 4;
+    }
+    true
+}
+
+// ── SSE2-compatible i32 multiply helper ─────────────────────────────────────
+
+/// SSE2-compatible signed 32-bit multiply (low 32 bits of each lane).
+///
+/// SSE2 only has `_mm_mul_epu32` which multiplies lanes 0 and 2 as
+/// unsigned 32-bit → 64-bit. We use it twice (even + odd lanes) and
+/// re-interleave to get all four i32 products. The unsigned multiply
+/// gives the correct low-32 result for signed operands (two's complement).
+#[target_feature(enable = "sse2")]
+#[inline]
+pub(super) unsafe fn mul32_sse2(
+    a: std::arch::x86_64::__m128i,
+    b: std::arch::x86_64::__m128i,
+) -> std::arch::x86_64::__m128i {
+    use std::arch::x86_64::{_mm_mul_epu32, _mm_shuffle_epi32, _mm_unpacklo_epi32};
+    // Multiply even lanes (0, 2) → 64-bit results.
+    let even = _mm_mul_epu32(a, b);
+    // Shuffle odd lanes (1, 3) into even positions, then multiply.
+    let odd =
+        _mm_mul_epu32(_mm_shuffle_epi32(a, 0b11_11_01_01), _mm_shuffle_epi32(b, 0b11_11_01_01));
+    // Extract the low 32 bits of each 64-bit product:
+    //   even = [p0_lo, p0_hi, p2_lo, p2_hi]
+    //   odd  = [p1_lo, p1_hi, p3_lo, p3_hi]
+    // shuffle 0b00_00_10_00 picks dwords 0 and 2 → [p_lo0, p_lo2, ?, ?]
+    let even_lo = _mm_shuffle_epi32(even, 0b00_00_10_00);
+    let odd_lo = _mm_shuffle_epi32(odd, 0b00_00_10_00);
+    // Interleave low halves → [p0, p1, p2, p3].
+    _mm_unpacklo_epi32(even_lo, odd_lo)
+}
+
+// ── I420 → RGBA8 SSE2/SSE4.1 (macro-generated) ─────────────────────────────
+
+/// Generate an I420 → RGBA8 row conversion function for a given SIMD tier.
+///
+/// The multiply strategy is injected via `$mul32`: either `mul32_sse2`
+/// (7-instruction emulation) or the native `_mm_mullo_epi32` wrapper.
+macro_rules! impl_i420_to_rgba8_row {
+    ($name:ident, $feature:literal, $mul32:expr) => {
+        #[doc = concat!("Convert up to `width` I420 pixels from one row to RGBA8 using ", $feature, ".")]
+        ///
+        /// Returns the number of pixels converted (always a multiple of 4).
+        /// The caller must handle the remaining `width - returned` tail pixels
+        /// with the scalar path.
+ #[target_feature(enable = $feature)] + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss, clippy::similar_names)] + pub(super) unsafe fn $name( + y_row: &[u8], + u_row: &[u8], + v_row: &[u8], + rgba_out: &mut [u8], + width: usize, + ) -> usize { + use std::arch::x86_64::{ + _mm_add_epi32, _mm_or_si128, _mm_packs_epi32, _mm_packus_epi16, _mm_set1_epi32, + _mm_set1_epi8, _mm_set_epi32, _mm_setzero_si128, _mm_srai_epi32, _mm_storeu_si128, + _mm_sub_epi32, _mm_unpacklo_epi16, _mm_unpacklo_epi8, + }; + + let simd_width = width & !3; + if simd_width == 0 { + return 0; + } + + let coeff_298 = _mm_set1_epi32(298); + let coeff_409 = _mm_set1_epi32(409); + let coeff_n100 = _mm_set1_epi32(-100); + let coeff_n208 = _mm_set1_epi32(-208); + let coeff_516 = _mm_set1_epi32(516); + let bias_16 = _mm_set1_epi32(16); + let bias_128 = _mm_set1_epi32(128); + let rounding = _mm_set1_epi32(128); + let alpha_mask = _mm_set1_epi32(0xFF00_0000_u32.cast_signed()); + let zero = _mm_setzero_si128(); + + let mul32 = $mul32; + + let mut col = 0usize; + while col < simd_width { + let y0 = i32::from(y_row[col]); + let y1 = i32::from(y_row[col + 1]); + let y2 = i32::from(y_row[col + 2]); + let y3 = i32::from(y_row[col + 3]); + let y32 = _mm_set_epi32(y3, y2, y1, y0); + + let chroma_col = col / 2; + let u0 = i32::from(u_row[chroma_col]); + let u1 = i32::from(u_row[chroma_col + 1]); + let v0 = i32::from(v_row[chroma_col]); + let v1 = i32::from(v_row[chroma_col + 1]); + let u32x4 = _mm_set_epi32(u1, u1, u0, u0); + let v32x4 = _mm_set_epi32(v1, v1, v0, v0); + + let c = _mm_sub_epi32(y32, bias_16); + let d = _mm_sub_epi32(u32x4, bias_128); + let e = _mm_sub_epi32(v32x4, bias_128); + + let r32 = _mm_srai_epi32( + _mm_add_epi32( + _mm_add_epi32(mul32(coeff_298, c), mul32(coeff_409, e)), + rounding, + ), + 8, + ); + + let g32 = _mm_srai_epi32( + _mm_add_epi32( + _mm_add_epi32( + _mm_add_epi32(mul32(coeff_298, c), mul32(coeff_n100, d)), + mul32(coeff_n208, e), + ), + rounding, + ), + 8, + 
); + + let b32 = _mm_srai_epi32( + _mm_add_epi32( + _mm_add_epi32(mul32(coeff_298, c), mul32(coeff_516, d)), + rounding, + ), + 8, + ); + + let r16 = _mm_packs_epi32(r32, zero); + let g16 = _mm_packs_epi32(g32, zero); + let b16 = _mm_packs_epi32(b32, zero); + let r8 = _mm_packus_epi16(r16, zero); + let g8 = _mm_packus_epi16(g16, zero); + let b8 = _mm_packus_epi16(b16, zero); + + let rg = _mm_unpacklo_epi8(r8, g8); + let ba = _mm_unpacklo_epi8(b8, _mm_set1_epi8(-1)); + let rgba = _mm_unpacklo_epi16(rg, ba); + let rgba = _mm_or_si128(rgba, alpha_mask); + + let out_ptr = rgba_out.as_mut_ptr().add(col * 4); + _mm_storeu_si128(out_ptr.cast(), rgba); + + col += 4; + } + simd_width + } + }; +} + +// SSE2 wrapper: calls mul32_sse2 (7-instruction emulation). +impl_i420_to_rgba8_row!(i420_to_rgba8_row_sse2, "sse2", |a, b| mul32_sse2(a, b)); + +// SSE4.1 wrapper: uses native _mm_mullo_epi32. +impl_i420_to_rgba8_row!(i420_to_rgba8_row_sse41, "sse4.1", |a, b| { + std::arch::x86_64::_mm_mullo_epi32(a, b) +}); + +// ── NV12 → RGBA8 SSE2/SSE4.1 (macro-generated) ───────────────────────────── + +/// Generate an NV12 → RGBA8 row conversion function for a given SIMD tier. +macro_rules! impl_nv12_to_rgba8_row { + ($name:ident, $feature:literal, $mul32:expr) => { + #[doc = concat!("Convert up to `width` NV12 pixels from one row to RGBA8 using ", $feature, ".")] + /// + /// Returns the number of pixels converted (always a multiple of 4). 
+ #[target_feature(enable = $feature)] + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss, clippy::similar_names)] + pub(super) unsafe fn $name( + y_row: &[u8], + uv_row: &[u8], + rgba_out: &mut [u8], + width: usize, + ) -> usize { + use std::arch::x86_64::{ + _mm_add_epi32, _mm_or_si128, _mm_packs_epi32, _mm_packus_epi16, _mm_set1_epi32, + _mm_set1_epi8, _mm_set_epi32, _mm_setzero_si128, _mm_srai_epi32, _mm_storeu_si128, + _mm_sub_epi32, _mm_unpacklo_epi16, _mm_unpacklo_epi8, + }; + + let simd_width = width & !3; + if simd_width == 0 { + return 0; + } + + let coeff_298 = _mm_set1_epi32(298); + let coeff_409 = _mm_set1_epi32(409); + let coeff_n100 = _mm_set1_epi32(-100); + let coeff_n208 = _mm_set1_epi32(-208); + let coeff_516 = _mm_set1_epi32(516); + let bias_16 = _mm_set1_epi32(16); + let bias_128 = _mm_set1_epi32(128); + let rounding = _mm_set1_epi32(128); + let alpha_mask = _mm_set1_epi32(0xFF00_0000_u32.cast_signed()); + let zero = _mm_setzero_si128(); + + let mul32 = $mul32; + + let mut col = 0usize; + while col < simd_width { + let y0 = i32::from(y_row[col]); + let y1 = i32::from(y_row[col + 1]); + let y2 = i32::from(y_row[col + 2]); + let y3 = i32::from(y_row[col + 3]); + let y32 = _mm_set_epi32(y3, y2, y1, y0); + + let chroma_byte = (col / 2) * 2; + let u0 = i32::from(uv_row[chroma_byte]); + let v0 = i32::from(uv_row[chroma_byte + 1]); + let u1 = i32::from(uv_row[chroma_byte + 2]); + let v1 = i32::from(uv_row[chroma_byte + 3]); + let u32x4 = _mm_set_epi32(u1, u1, u0, u0); + let v32x4 = _mm_set_epi32(v1, v1, v0, v0); + + let c = _mm_sub_epi32(y32, bias_16); + let d = _mm_sub_epi32(u32x4, bias_128); + let e = _mm_sub_epi32(v32x4, bias_128); + + let r32 = _mm_srai_epi32( + _mm_add_epi32( + _mm_add_epi32(mul32(coeff_298, c), mul32(coeff_409, e)), + rounding, + ), + 8, + ); + + let g32 = _mm_srai_epi32( + _mm_add_epi32( + _mm_add_epi32( + _mm_add_epi32(mul32(coeff_298, c), mul32(coeff_n100, d)), + mul32(coeff_n208, e), + ), + rounding, + ), + 8, 
+ ); + + let b32 = _mm_srai_epi32( + _mm_add_epi32( + _mm_add_epi32(mul32(coeff_298, c), mul32(coeff_516, d)), + rounding, + ), + 8, + ); + + let r16 = _mm_packs_epi32(r32, zero); + let g16 = _mm_packs_epi32(g32, zero); + let b16 = _mm_packs_epi32(b32, zero); + let r8 = _mm_packus_epi16(r16, zero); + let g8 = _mm_packus_epi16(g16, zero); + let b8 = _mm_packus_epi16(b16, zero); + + let rg = _mm_unpacklo_epi8(r8, g8); + let ba = _mm_unpacklo_epi8(b8, _mm_set1_epi8(-1)); + let rgba = _mm_unpacklo_epi16(rg, ba); + let rgba = _mm_or_si128(rgba, alpha_mask); + + let out_ptr = rgba_out.as_mut_ptr().add(col * 4); + _mm_storeu_si128(out_ptr.cast(), rgba); + + col += 4; + } + simd_width + } + }; +} + +impl_nv12_to_rgba8_row!(nv12_to_rgba8_row_sse2, "sse2", |a, b| mul32_sse2(a, b)); +impl_nv12_to_rgba8_row!(nv12_to_rgba8_row_sse41, "sse4.1", |a, b| { + std::arch::x86_64::_mm_mullo_epi32(a, b) +}); + +// ── NV12 → RGBA8 AVX2 (8 pixels / iter) ──────────────────────────────────── + +/// Convert up to `width` NV12 pixels from one row to RGBA8 using AVX2. +/// +/// Processes 8 pixels per iteration (256-bit registers) — double the +/// throughput of the SSE4.1 variant. The heavy i32 multiplies run in +/// 256-bit lanes while the final u8 pack + RGBA interleave drops to +/// 128-bit SSE to avoid AVX2 lane-crossing headaches. +/// +/// Returns the number of pixels converted (always a multiple of 8). 
+#[target_feature(enable = "avx2")] +#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss, clippy::similar_names)] +pub(super) unsafe fn nv12_to_rgba8_row_avx2( + y_row: &[u8], + uv_row: &[u8], + rgba_out: &mut [u8], + width: usize, +) -> usize { + use std::arch::x86_64::{ + _mm256_add_epi32, _mm256_castsi256_si128, _mm256_cvtepu8_epi32, _mm256_extracti128_si256, + _mm256_mullo_epi32, _mm256_set1_epi32, _mm256_srai_epi32, _mm256_sub_epi32, + _mm_loadl_epi64, _mm_or_si128, _mm_packs_epi32, _mm_packus_epi16, _mm_set1_epi32, + _mm_set1_epi8, _mm_set_epi8, _mm_setzero_si128, _mm_shuffle_epi8, _mm_storeu_si128, + _mm_unpacklo_epi16, _mm_unpacklo_epi8, + }; + + let simd_width = width & !7; // round down to multiple of 8 + if simd_width == 0 { + return 0; + } + + let coeff_298 = _mm256_set1_epi32(298); + let coeff_409 = _mm256_set1_epi32(409); + let coeff_n100 = _mm256_set1_epi32(-100); + let coeff_n208 = _mm256_set1_epi32(-208); + let coeff_516 = _mm256_set1_epi32(516); + let bias_16 = _mm256_set1_epi32(16); + let bias_128 = _mm256_set1_epi32(128); + let rounding = _mm256_set1_epi32(128); + let alpha_mask = _mm_set1_epi32(0xFF00_0000_u32.cast_signed()); + let zero = _mm_setzero_si128(); + + // Shuffle controls for deinterleaving + duplicating NV12 UV pairs. 
+ let u_shuf = _mm_set_epi8(-1, -1, -1, -1, -1, -1, -1, -1, 6, 6, 4, 4, 2, 2, 0, 0); + let v_shuf = _mm_set_epi8(-1, -1, -1, -1, -1, -1, -1, -1, 7, 7, 5, 5, 3, 3, 1, 1); + + let mut col = 0usize; + while col < simd_width { + let y8 = _mm_loadl_epi64(y_row.as_ptr().add(col).cast()); + let y32 = _mm256_cvtepu8_epi32(y8); + + let chroma_byte = (col / 2) * 2; + let uv8 = _mm_loadl_epi64(uv_row.as_ptr().add(chroma_byte).cast()); + let u32x8 = _mm256_cvtepu8_epi32(_mm_shuffle_epi8(uv8, u_shuf)); + let v32x8 = _mm256_cvtepu8_epi32(_mm_shuffle_epi8(uv8, v_shuf)); + + let c = _mm256_sub_epi32(y32, bias_16); + let d = _mm256_sub_epi32(u32x8, bias_128); + let e = _mm256_sub_epi32(v32x8, bias_128); + + let r32 = _mm256_srai_epi32( + _mm256_add_epi32( + _mm256_add_epi32( + _mm256_mullo_epi32(coeff_298, c), + _mm256_mullo_epi32(coeff_409, e), + ), + rounding, + ), + 8, + ); + + let g32 = _mm256_srai_epi32( + _mm256_add_epi32( + _mm256_add_epi32( + _mm256_add_epi32( + _mm256_mullo_epi32(coeff_298, c), + _mm256_mullo_epi32(coeff_n100, d), + ), + _mm256_mullo_epi32(coeff_n208, e), + ), + rounding, + ), + 8, + ); + + let b32 = _mm256_srai_epi32( + _mm256_add_epi32( + _mm256_add_epi32( + _mm256_mullo_epi32(coeff_298, c), + _mm256_mullo_epi32(coeff_516, d), + ), + rounding, + ), + 8, + ); + + // ── Pack + interleave: split into two 4-pixel halves ────── + let r_lo = _mm256_castsi256_si128(r32); + let r_hi = _mm256_extracti128_si256(r32, 1); + let g_lo = _mm256_castsi256_si128(g32); + let g_hi = _mm256_extracti128_si256(g32, 1); + let b_lo = _mm256_castsi256_si128(b32); + let b_hi = _mm256_extracti128_si256(b32, 1); + + // Pixels 0–3 + let r16 = _mm_packs_epi32(r_lo, zero); + let g16 = _mm_packs_epi32(g_lo, zero); + let b16 = _mm_packs_epi32(b_lo, zero); + let r8 = _mm_packus_epi16(r16, zero); + let g8 = _mm_packus_epi16(g16, zero); + let b8 = _mm_packus_epi16(b16, zero); + + let rg = _mm_unpacklo_epi8(r8, g8); + let ba = _mm_unpacklo_epi8(b8, _mm_set1_epi8(-1)); + let rgba = 
_mm_unpacklo_epi16(rg, ba); + let rgba = _mm_or_si128(rgba, alpha_mask); + _mm_storeu_si128(rgba_out.as_mut_ptr().add(col * 4).cast(), rgba); + + // Pixels 4–7 + let r16 = _mm_packs_epi32(r_hi, zero); + let g16 = _mm_packs_epi32(g_hi, zero); + let b16 = _mm_packs_epi32(b_hi, zero); + let r8 = _mm_packus_epi16(r16, zero); + let g8 = _mm_packus_epi16(g16, zero); + let b8 = _mm_packus_epi16(b16, zero); + + let rg = _mm_unpacklo_epi8(r8, g8); + let ba = _mm_unpacklo_epi8(b8, _mm_set1_epi8(-1)); + let rgba = _mm_unpacklo_epi16(rg, ba); + let rgba = _mm_or_si128(rgba, alpha_mask); + _mm_storeu_si128(rgba_out.as_mut_ptr().add((col + 4) * 4).cast(), rgba); + + col += 8; + } + simd_width +} + +// ── I420 → RGBA8 AVX2 (8 pixels / iter) ───────────────────────────────────── + +/// Convert up to `width` I420 pixels from one row to RGBA8 using AVX2. +/// +/// Processes 8 pixels per iteration (256-bit registers) — double the +/// throughput of the SSE4.1 variant. Same BT.601 math as the NV12 AVX2 +/// variant, but reads U and V from separate planes instead of interleaved. +/// +/// Returns the number of pixels converted (always a multiple of 8). 
+#[target_feature(enable = "avx2")] +#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss, clippy::similar_names)] +pub(super) unsafe fn i420_to_rgba8_row_avx2( + y_row: &[u8], + u_row: &[u8], + v_row: &[u8], + rgba_out: &mut [u8], + width: usize, +) -> usize { + use std::arch::x86_64::{ + _mm256_add_epi32, _mm256_castsi256_si128, _mm256_cvtepu8_epi32, _mm256_extracti128_si256, + _mm256_mullo_epi32, _mm256_set1_epi32, _mm256_set_epi32, _mm256_srai_epi32, + _mm256_sub_epi32, _mm_loadl_epi64, _mm_or_si128, _mm_packs_epi32, _mm_packus_epi16, + _mm_set1_epi32, _mm_set1_epi8, _mm_setzero_si128, _mm_storeu_si128, _mm_unpacklo_epi16, + _mm_unpacklo_epi8, + }; + + let simd_width = width & !7; // round down to multiple of 8 + if simd_width == 0 { + return 0; + } + + let coeff_298 = _mm256_set1_epi32(298); + let coeff_409 = _mm256_set1_epi32(409); + let coeff_n100 = _mm256_set1_epi32(-100); + let coeff_n208 = _mm256_set1_epi32(-208); + let coeff_516 = _mm256_set1_epi32(516); + let bias_16 = _mm256_set1_epi32(16); + let bias_128 = _mm256_set1_epi32(128); + let rounding = _mm256_set1_epi32(128); + let alpha_mask = _mm_set1_epi32(0xFF00_0000_u32.cast_signed()); + let zero = _mm_setzero_si128(); + + let mut col = 0usize; + while col < simd_width { + // Load 8 luma samples. + let y8 = _mm_loadl_epi64(y_row.as_ptr().add(col).cast()); + let y32 = _mm256_cvtepu8_epi32(y8); + + // Load 4 U and 4 V chroma samples via scalar reads, duplicating each + // to pair with 2 luma pixels. We avoid `_mm_loadl_epi64` here because + // the chroma planes may have only 4 bytes remaining at the last + // iteration, and `_mm_loadl_epi64` always reads 8 bytes. 
+ let chroma_col = col / 2; + let u0 = i32::from(u_row[chroma_col]); + let u1 = i32::from(u_row[chroma_col + 1]); + let u2 = i32::from(u_row[chroma_col + 2]); + let u3 = i32::from(u_row[chroma_col + 3]); + let u32x8 = _mm256_set_epi32(u3, u3, u2, u2, u1, u1, u0, u0); + let v0 = i32::from(v_row[chroma_col]); + let v1 = i32::from(v_row[chroma_col + 1]); + let v2 = i32::from(v_row[chroma_col + 2]); + let v3 = i32::from(v_row[chroma_col + 3]); + let v32x8 = _mm256_set_epi32(v3, v3, v2, v2, v1, v1, v0, v0); + + let c = _mm256_sub_epi32(y32, bias_16); + let d = _mm256_sub_epi32(u32x8, bias_128); + let e = _mm256_sub_epi32(v32x8, bias_128); + + let r32 = _mm256_srai_epi32( + _mm256_add_epi32( + _mm256_add_epi32( + _mm256_mullo_epi32(coeff_298, c), + _mm256_mullo_epi32(coeff_409, e), + ), + rounding, + ), + 8, + ); + + let g32 = _mm256_srai_epi32( + _mm256_add_epi32( + _mm256_add_epi32( + _mm256_add_epi32( + _mm256_mullo_epi32(coeff_298, c), + _mm256_mullo_epi32(coeff_n100, d), + ), + _mm256_mullo_epi32(coeff_n208, e), + ), + rounding, + ), + 8, + ); + + let b32 = _mm256_srai_epi32( + _mm256_add_epi32( + _mm256_add_epi32( + _mm256_mullo_epi32(coeff_298, c), + _mm256_mullo_epi32(coeff_516, d), + ), + rounding, + ), + 8, + ); + + // ── Pack + interleave: split into two 4-pixel halves ────── + let r_lo = _mm256_castsi256_si128(r32); + let r_hi = _mm256_extracti128_si256(r32, 1); + let g_lo = _mm256_castsi256_si128(g32); + let g_hi = _mm256_extracti128_si256(g32, 1); + let b_lo = _mm256_castsi256_si128(b32); + let b_hi = _mm256_extracti128_si256(b32, 1); + + // Pixels 0–3 + let r16 = _mm_packs_epi32(r_lo, zero); + let g16 = _mm_packs_epi32(g_lo, zero); + let b16 = _mm_packs_epi32(b_lo, zero); + let r8 = _mm_packus_epi16(r16, zero); + let g8 = _mm_packus_epi16(g16, zero); + let b8 = _mm_packus_epi16(b16, zero); + + let rg = _mm_unpacklo_epi8(r8, g8); + let ba = _mm_unpacklo_epi8(b8, _mm_set1_epi8(-1)); + let rgba = _mm_unpacklo_epi16(rg, ba); + let rgba = _mm_or_si128(rgba, 
alpha_mask); + _mm_storeu_si128(rgba_out.as_mut_ptr().add(col * 4).cast(), rgba); + + // Pixels 4–7 + let r16 = _mm_packs_epi32(r_hi, zero); + let g16 = _mm_packs_epi32(g_hi, zero); + let b16 = _mm_packs_epi32(b_hi, zero); + let r8 = _mm_packus_epi16(r16, zero); + let g8 = _mm_packus_epi16(g16, zero); + let b8 = _mm_packus_epi16(b16, zero); + + let rg = _mm_unpacklo_epi8(r8, g8); + let ba = _mm_unpacklo_epi8(b8, _mm_set1_epi8(-1)); + let rgba = _mm_unpacklo_epi16(rg, ba); + let rgba = _mm_or_si128(rgba, alpha_mask); + _mm_storeu_si128(rgba_out.as_mut_ptr().add((col + 4) * 4).cast(), rgba); + + col += 8; + } + simd_width +} + +// ── RGBA8 → Y-plane SSE2/SSE4.1 (macro-generated) ────────────────────────── + +/// Generate an RGBA8 → Y-plane row conversion function for a given SIMD tier. +macro_rules! impl_rgba8_to_y_row { + ($name:ident, $feature:literal, $mul32:expr) => { + #[doc = concat!("Convert one row of RGBA8 pixels to Y values using ", $feature, ".")] + /// + /// Returns the number of pixels converted (multiple of 4). 
+ #[target_feature(enable = $feature)] + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] + pub(super) unsafe fn $name(rgba_row: &[u8], y_out: &mut [u8], width: usize) -> usize { + use std::arch::x86_64::{ + _mm_add_epi32, _mm_and_si128, _mm_loadu_si128, _mm_packs_epi32, _mm_packus_epi16, + _mm_set1_epi32, _mm_setzero_si128, _mm_srai_epi32, _mm_srli_epi32, _mm_storeu_si32, + }; + + let simd_width = width & !3; + if simd_width == 0 { + return 0; + } + + let coeff_66 = _mm_set1_epi32(66); + let coeff_129 = _mm_set1_epi32(129); + let coeff_25 = _mm_set1_epi32(25); + let rounding = _mm_set1_epi32(128); + let bias_16 = _mm_set1_epi32(16); + let zero = _mm_setzero_si128(); + let channel_mask = _mm_set1_epi32(0xFF); + + let mul32 = $mul32; + + let mut col = 0usize; + while col < simd_width { + let src_ptr = rgba_row.as_ptr().add(col * 4); + let px = _mm_loadu_si128(src_ptr.cast()); + + let r = _mm_and_si128(px, channel_mask); + let g = _mm_and_si128(_mm_srli_epi32(px, 8), channel_mask); + let b = _mm_and_si128(_mm_srli_epi32(px, 16), channel_mask); + + let y32 = _mm_add_epi32( + _mm_srai_epi32( + _mm_add_epi32( + _mm_add_epi32( + _mm_add_epi32(mul32(coeff_66, r), mul32(coeff_129, g)), + mul32(coeff_25, b), + ), + rounding, + ), + 8, + ), + bias_16, + ); + + let y16 = _mm_packs_epi32(y32, zero); + let y8 = _mm_packus_epi16(y16, zero); + _mm_storeu_si32(y_out.as_mut_ptr().add(col).cast(), y8); + + col += 4; + } + simd_width + } + }; +} + +impl_rgba8_to_y_row!(rgba8_to_y_row_sse2, "sse2", |a, b| mul32_sse2(a, b)); +impl_rgba8_to_y_row!(rgba8_to_y_row_sse41, "sse4.1", |a, b| { + std::arch::x86_64::_mm_mullo_epi32(a, b) +}); + +// ── RGBA8 → Y-plane AVX2 (8 pixels / iter) ───────────────────────────────── + +/// Convert one row of RGBA8 pixels to Y values using AVX2. +/// +/// Processes 8 pixels per iteration (256-bit registers) — double the +/// throughput of the SSE4.1 variant. +/// +/// Returns the number of pixels converted (multiple of 8). 
+#[target_feature(enable = "avx2")] +#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] +pub(super) unsafe fn rgba8_to_y_row_avx2(rgba_row: &[u8], y_out: &mut [u8], width: usize) -> usize { + use std::arch::x86_64::{ + _mm256_add_epi32, _mm256_and_si256, _mm256_castsi256_si128, _mm256_extracti128_si256, + _mm256_loadu_si256, _mm256_mullo_epi32, _mm256_packs_epi32, _mm256_packus_epi16, + _mm256_set1_epi32, _mm256_setzero_si256, _mm256_srai_epi32, _mm256_srli_epi32, + _mm_storel_epi64, _mm_unpacklo_epi32, + }; + + let simd_width = width & !7; + if simd_width == 0 { + return 0; + } + + let coeff_66 = _mm256_set1_epi32(66); + let coeff_129 = _mm256_set1_epi32(129); + let coeff_25 = _mm256_set1_epi32(25); + let rounding = _mm256_set1_epi32(128); + let bias_16 = _mm256_set1_epi32(16); + let zero = _mm256_setzero_si256(); + let channel_mask = _mm256_set1_epi32(0xFF); + + let mut col = 0usize; + while col < simd_width { + let src_ptr = rgba_row.as_ptr().add(col * 4); + let px = _mm256_loadu_si256(src_ptr.cast()); + + let r = _mm256_and_si256(px, channel_mask); + let g = _mm256_and_si256(_mm256_srli_epi32(px, 8), channel_mask); + let b = _mm256_and_si256(_mm256_srli_epi32(px, 16), channel_mask); + + let y32 = _mm256_add_epi32( + _mm256_srai_epi32( + _mm256_add_epi32( + _mm256_add_epi32( + _mm256_add_epi32( + _mm256_mullo_epi32(coeff_66, r), + _mm256_mullo_epi32(coeff_129, g), + ), + _mm256_mullo_epi32(coeff_25, b), + ), + rounding, + ), + 8, + ), + bias_16, + ); + + let y16 = _mm256_packs_epi32(y32, zero); + let y8 = _mm256_packus_epi16(y16, zero); + let lo = _mm256_castsi256_si128(y8); + let hi = _mm256_extracti128_si256(y8, 1); + let combined = _mm_unpacklo_epi32(lo, hi); + _mm_storel_epi64(y_out.as_mut_ptr().add(col).cast(), combined); + + col += 8; + } + simd_width +} + +// ── RGBA8 → I420 chroma row (SSE2: 4 chroma samples / iter) ──────────────── + +/// Convert one pair of RGBA8 rows to U and V chroma samples using SSE2. 
+/// +/// Returns the number of chroma samples converted (multiple of 4). +#[target_feature(enable = "sse2")] +#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss, clippy::similar_names)] +pub(super) unsafe fn rgba8_to_chroma_row_sse2( + rgba_row0: &[u8], + rgba_row1: &[u8], + u_out: &mut [u8], + v_out: &mut [u8], + chroma_width: usize, + luma_width: usize, +) -> usize { + use std::arch::x86_64::{ + _mm_add_epi16, _mm_add_epi32, _mm_and_si128, _mm_loadu_si128, _mm_mullo_epi16, + _mm_packs_epi32, _mm_packus_epi16, _mm_set1_epi16, _mm_set1_epi32, _mm_set_epi16, + _mm_setzero_si128, _mm_srai_epi16, _mm_srli_epi32, _mm_storeu_si32, + }; + + let simd_width = chroma_width & !3; + if simd_width == 0 || luma_width < 8 { + return 0; + } + + let coeff_cb_r = _mm_set1_epi16(-38); + let coeff_cb_g = _mm_set1_epi16(-74); + let coeff_cb_b = _mm_set1_epi16(112); + let coeff_cr_r = _mm_set1_epi16(112); + let coeff_cr_g = _mm_set1_epi16(-94); + let coeff_cr_b = _mm_set1_epi16(-18); + let rounding = _mm_set1_epi16(128); + let bias_128 = _mm_set1_epi16(128); + let zero = _mm_setzero_si128(); + let channel_mask = _mm_set1_epi32(0xFF); + + let mut ccol = 0usize; + while ccol < simd_width { + let luma_col = ccol * 2; + if luma_col + 8 > luma_width { + break; + } + + let ptr0 = rgba_row0.as_ptr().add(luma_col * 4); + let ptr1 = rgba_row1.as_ptr().add(luma_col * 4); + let px0_lo = _mm_loadu_si128(ptr0.cast()); + let px0_hi = _mm_loadu_si128(ptr0.add(16).cast()); + let px1_lo = _mm_loadu_si128(ptr1.cast()); + let px1_hi = _mm_loadu_si128(ptr1.add(16).cast()); + + // 2×2 average for R channel. 
+ let r0_lo = _mm_and_si128(px0_lo, channel_mask); + let r0_hi = _mm_and_si128(px0_hi, channel_mask); + let r1_lo = _mm_and_si128(px1_lo, channel_mask); + let r1_hi = _mm_and_si128(px1_hi, channel_mask); + let r_v_lo = _mm_add_epi32(r0_lo, r1_lo); + let r_v_hi = _mm_add_epi32(r0_hi, r1_hi); + let r_v = _mm_packs_epi32(r_v_lo, r_v_hi); + let r_even = _mm_and_si128(r_v, _mm_set_epi16(0, -1, 0, -1, 0, -1, 0, -1)); + let r_odd = _mm_srli_epi32(r_v, 16); + let r_sum = _mm_add_epi16(r_even, r_odd); + let r_avg = + _mm_srai_epi16(_mm_add_epi16(_mm_packs_epi32(r_sum, zero), _mm_set1_epi16(2)), 2); + + // G channel. + let g0_lo = _mm_and_si128(_mm_srli_epi32(px0_lo, 8), channel_mask); + let g0_hi = _mm_and_si128(_mm_srli_epi32(px0_hi, 8), channel_mask); + let g1_lo = _mm_and_si128(_mm_srli_epi32(px1_lo, 8), channel_mask); + let g1_hi = _mm_and_si128(_mm_srli_epi32(px1_hi, 8), channel_mask); + let g_v_lo = _mm_add_epi32(g0_lo, g1_lo); + let g_v_hi = _mm_add_epi32(g0_hi, g1_hi); + let g_v = _mm_packs_epi32(g_v_lo, g_v_hi); + let g_even = _mm_and_si128(g_v, _mm_set_epi16(0, -1, 0, -1, 0, -1, 0, -1)); + let g_odd = _mm_srli_epi32(g_v, 16); + let g_sum = _mm_add_epi16(g_even, g_odd); + let g_avg = + _mm_srai_epi16(_mm_add_epi16(_mm_packs_epi32(g_sum, zero), _mm_set1_epi16(2)), 2); + + // B channel. 
+ let b0_lo = _mm_and_si128(_mm_srli_epi32(px0_lo, 16), channel_mask); + let b0_hi = _mm_and_si128(_mm_srli_epi32(px0_hi, 16), channel_mask); + let b1_lo = _mm_and_si128(_mm_srli_epi32(px1_lo, 16), channel_mask); + let b1_hi = _mm_and_si128(_mm_srli_epi32(px1_hi, 16), channel_mask); + let b_v_lo = _mm_add_epi32(b0_lo, b1_lo); + let b_v_hi = _mm_add_epi32(b0_hi, b1_hi); + let b_v = _mm_packs_epi32(b_v_lo, b_v_hi); + let b_even = _mm_and_si128(b_v, _mm_set_epi16(0, -1, 0, -1, 0, -1, 0, -1)); + let b_odd = _mm_srli_epi32(b_v, 16); + let b_sum = _mm_add_epi16(b_even, b_odd); + let b_avg = + _mm_srai_epi16(_mm_add_epi16(_mm_packs_epi32(b_sum, zero), _mm_set1_epi16(2)), 2); + + // Cb / Cr coefficient multiplies. + let cb_result = _mm_add_epi16( + _mm_srai_epi16( + _mm_add_epi16( + _mm_add_epi16( + _mm_add_epi16( + _mm_mullo_epi16(coeff_cb_r, r_avg), + _mm_mullo_epi16(coeff_cb_g, g_avg), + ), + _mm_mullo_epi16(coeff_cb_b, b_avg), + ), + rounding, + ), + 8, + ), + bias_128, + ); + + let cr_result = _mm_add_epi16( + _mm_srai_epi16( + _mm_add_epi16( + _mm_add_epi16( + _mm_add_epi16( + _mm_mullo_epi16(coeff_cr_r, r_avg), + _mm_mullo_epi16(coeff_cr_g, g_avg), + ), + _mm_mullo_epi16(coeff_cr_b, b_avg), + ), + rounding, + ), + 8, + ), + bias_128, + ); + + let cb_packed = _mm_packus_epi16(cb_result, zero); + let cr_packed = _mm_packus_epi16(cr_result, zero); + + _mm_storeu_si32(u_out.as_mut_ptr().add(ccol).cast(), cb_packed); + _mm_storeu_si32(v_out.as_mut_ptr().add(ccol).cast(), cr_packed); + + ccol += 4; + } + ccol +} + +// ── RGBA8 → NV12 chroma row (SSE2: 4 interleaved UV pairs / iter) ────────── + +/// Convert one pair of RGBA8 rows to interleaved `[U, V, U, V, …]` +/// chroma samples for NV12 output, using SSE2. +/// +/// Returns the number of chroma *pairs* converted (multiple of 4). 
#[target_feature(enable = "sse2")]
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss, clippy::similar_names)]
pub(super) unsafe fn rgba8_to_chroma_row_nv12_sse2(
    rgba_row0: &[u8],
    rgba_row1: &[u8],
    uv_out: &mut [u8],
    chroma_width: usize,
    luma_width: usize,
) -> usize {
    use std::arch::x86_64::{
        _mm_add_epi16, _mm_add_epi32, _mm_and_si128, _mm_loadu_si128, _mm_mullo_epi16,
        _mm_packs_epi32, _mm_packus_epi16, _mm_set1_epi16, _mm_set1_epi32, _mm_set_epi16,
        _mm_setzero_si128, _mm_srai_epi16, _mm_srli_epi32, _mm_storel_epi64, _mm_unpacklo_epi8,
    };

    // Only whole groups of 4 chroma pairs are processed; the returned count
    // tells the caller how far we got so it can finish the tail in scalar
    // code. Rows narrower than 8 luma pixels can't sustain one iteration.
    let simd_width = chroma_width & !3;
    if simd_width == 0 || luma_width < 8 {
        return 0;
    }

    // Fixed-point (Q8) chroma coefficients — the widely used BT.601-style
    // integer approximation:
    //   Cb = ((-38·R −  74·G + 112·B + 128) >> 8) + 128
    //   Cr = (( 112·R −  94·G −  18·B + 128) >> 8) + 128
    // `rounding` is the +128 before the >>8; `bias_128` recentres to
    // unsigned chroma.
    let coeff_cb_r = _mm_set1_epi16(-38);
    let coeff_cb_g = _mm_set1_epi16(-74);
    let coeff_cb_b = _mm_set1_epi16(112);
    let coeff_cr_r = _mm_set1_epi16(112);
    let coeff_cr_g = _mm_set1_epi16(-94);
    let coeff_cr_b = _mm_set1_epi16(-18);
    let rounding = _mm_set1_epi16(128);
    let bias_128 = _mm_set1_epi16(128);
    let zero = _mm_setzero_si128();
    // Isolates the low byte of each 32-bit RGBA pixel.
    let channel_mask = _mm_set1_epi32(0xFF);

    let mut ccol = 0usize;
    while ccol < simd_width {
        let luma_col = ccol * 2;
        // Each iteration reads 8 luma pixels (32 bytes) per row; stop
        // rather than read past the end of a short row.
        if luma_col + 8 > luma_width {
            break;
        }

        let ptr0 = rgba_row0.as_ptr().add(luma_col * 4);
        let ptr1 = rgba_row1.as_ptr().add(luma_col * 4);
        let px0_lo = _mm_loadu_si128(ptr0.cast());
        let px0_hi = _mm_loadu_si128(ptr0.add(16).cast());
        let px1_lo = _mm_loadu_si128(ptr1.cast());
        let px1_hi = _mm_loadu_si128(ptr1.add(16).cast());

        // 2×2 average for R channel.
        // Per-channel pipeline: isolate the byte in each pixel, add the two
        // rows (vertical sum, fits 32-bit), pack 32→16, add even+odd 16-bit
        // lanes (horizontal sum of the 2 columns), then (sum + 2) >> 2 for a
        // rounded mean of the 4 samples.
        let r0_lo = _mm_and_si128(px0_lo, channel_mask);
        let r0_hi = _mm_and_si128(px0_hi, channel_mask);
        let r1_lo = _mm_and_si128(px1_lo, channel_mask);
        let r1_hi = _mm_and_si128(px1_hi, channel_mask);
        let r_v_lo = _mm_add_epi32(r0_lo, r1_lo);
        let r_v_hi = _mm_add_epi32(r0_hi, r1_hi);
        let r_v = _mm_packs_epi32(r_v_lo, r_v_hi);
        // Keep even 16-bit lanes; shift brings the odd lanes down beside them.
        let r_even = _mm_and_si128(r_v, _mm_set_epi16(0, -1, 0, -1, 0, -1, 0, -1));
        let r_odd = _mm_srli_epi32(r_v, 16);
        let r_sum = _mm_add_epi16(r_even, r_odd);
        let r_avg =
            _mm_srai_epi16(_mm_add_epi16(_mm_packs_epi32(r_sum, zero), _mm_set1_epi16(2)), 2);

        // G channel (byte 1 of each RGBA pixel, hence the >>8 first).
        let g0_lo = _mm_and_si128(_mm_srli_epi32(px0_lo, 8), channel_mask);
        let g0_hi = _mm_and_si128(_mm_srli_epi32(px0_hi, 8), channel_mask);
        let g1_lo = _mm_and_si128(_mm_srli_epi32(px1_lo, 8), channel_mask);
        let g1_hi = _mm_and_si128(_mm_srli_epi32(px1_hi, 8), channel_mask);
        let g_v_lo = _mm_add_epi32(g0_lo, g1_lo);
        let g_v_hi = _mm_add_epi32(g0_hi, g1_hi);
        let g_v = _mm_packs_epi32(g_v_lo, g_v_hi);
        let g_even = _mm_and_si128(g_v, _mm_set_epi16(0, -1, 0, -1, 0, -1, 0, -1));
        let g_odd = _mm_srli_epi32(g_v, 16);
        let g_sum = _mm_add_epi16(g_even, g_odd);
        let g_avg =
            _mm_srai_epi16(_mm_add_epi16(_mm_packs_epi32(g_sum, zero), _mm_set1_epi16(2)), 2);

        // B channel (byte 2, hence the >>16 first).
        let b0_lo = _mm_and_si128(_mm_srli_epi32(px0_lo, 16), channel_mask);
        let b0_hi = _mm_and_si128(_mm_srli_epi32(px0_hi, 16), channel_mask);
        let b1_lo = _mm_and_si128(_mm_srli_epi32(px1_lo, 16), channel_mask);
        let b1_hi = _mm_and_si128(_mm_srli_epi32(px1_hi, 16), channel_mask);
        let b_v_lo = _mm_add_epi32(b0_lo, b1_lo);
        let b_v_hi = _mm_add_epi32(b0_hi, b1_hi);
        let b_v = _mm_packs_epi32(b_v_lo, b_v_hi);
        let b_even = _mm_and_si128(b_v, _mm_set_epi16(0, -1, 0, -1, 0, -1, 0, -1));
        let b_odd = _mm_srli_epi32(b_v, 16);
        let b_sum = _mm_add_epi16(b_even, b_odd);
        let b_avg =
            _mm_srai_epi16(_mm_add_epi16(_mm_packs_epi32(b_sum, zero), _mm_set1_epi16(2)), 2);

        // Cb / Cr coefficient multiplies.
        // Averages are 0..=255, so every product fits in i16 and the
        // three-term sum stays within ±28560 + 128 — no overflow.
        let cb_result = _mm_add_epi16(
            _mm_srai_epi16(
                _mm_add_epi16(
                    _mm_add_epi16(
                        _mm_add_epi16(
                            _mm_mullo_epi16(coeff_cb_r, r_avg),
                            _mm_mullo_epi16(coeff_cb_g, g_avg),
                        ),
                        _mm_mullo_epi16(coeff_cb_b, b_avg),
                    ),
                    rounding,
                ),
                8,
            ),
            bias_128,
        );

        let cr_result = _mm_add_epi16(
            _mm_srai_epi16(
                _mm_add_epi16(
                    _mm_add_epi16(
                        _mm_add_epi16(
                            _mm_mullo_epi16(coeff_cr_r, r_avg),
                            _mm_mullo_epi16(coeff_cr_g, g_avg),
                        ),
                        _mm_mullo_epi16(coeff_cr_b, b_avg),
                    ),
                    rounding,
                ),
                8,
            ),
            bias_128,
        );

        // Pack to u8 and interleave: [U0,V0,U1,V1,U2,V2,U3,V3].
        let cb_packed = _mm_packus_epi16(cb_result, zero);
        let cr_packed = _mm_packus_epi16(cr_result, zero);
        let interleaved = _mm_unpacklo_epi8(cb_packed, cr_packed);

        // 8 bytes = 4 UV pairs per iteration.
        _mm_storel_epi64(uv_out.as_mut_ptr().add(ccol * 2).cast(), interleaved);

        ccol += 4;
    }
    ccol
}

// ── RGBA8 → NV12 chroma row (AVX2: 8 interleaved UV pairs / iter) ──────────

/// Convert one pair of RGBA8 rows to interleaved `[U, V, U, V, …]`
/// chroma samples for NV12 output, using AVX2.
///
/// Processes 8 chroma pairs (16 luma pixels) per iteration.
///
/// Returns the number of chroma *pairs* converted (multiple of 8).
#[target_feature(enable = "avx2")]
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss, clippy::similar_names)]
pub(super) unsafe fn rgba8_to_chroma_row_nv12_avx2(
    rgba_row0: &[u8],
    rgba_row1: &[u8],
    uv_out: &mut [u8],
    chroma_width: usize,
    luma_width: usize,
) -> usize {
    use std::arch::x86_64::{
        _mm256_add_epi16, _mm256_add_epi32, _mm256_and_si256, _mm256_castsi256_si128,
        _mm256_extracti128_si256, _mm256_loadu_si256, _mm256_mullo_epi16, _mm256_packs_epi32,
        _mm256_packus_epi16, _mm256_permute4x64_epi64, _mm256_set1_epi16, _mm256_set1_epi32,
        _mm256_set_epi16, _mm256_setzero_si256, _mm256_srai_epi16, _mm256_srli_epi32,
        _mm_storeu_si128, _mm_unpacklo_epi32, _mm_unpacklo_epi8,
    };

    // Whole groups of 8 chroma pairs only; the caller handles the scalar
    // tail using the returned count.
    let simd_width = chroma_width & !7;
    if simd_width == 0 || luma_width < 16 {
        return 0;
    }

    // Same Q8 fixed-point chroma coefficients as the SSE2 path:
    //   Cb = ((-38·R − 74·G + 112·B + 128) >> 8) + 128
    //   Cr = ((112·R − 94·G −  18·B + 128) >> 8) + 128
    let coeff_cb_r = _mm256_set1_epi16(-38);
    let coeff_cb_g = _mm256_set1_epi16(-74);
    let coeff_cb_b = _mm256_set1_epi16(112);
    let coeff_cr_r = _mm256_set1_epi16(112);
    let coeff_cr_g = _mm256_set1_epi16(-94);
    let coeff_cr_b = _mm256_set1_epi16(-18);
    let rounding = _mm256_set1_epi16(128);
    let bias_128 = _mm256_set1_epi16(128);
    let zero = _mm256_setzero_si256();
    let channel_mask = _mm256_set1_epi32(0xFF);
    // Keeps the even 16-bit lanes of each 256-bit vector.
    let even_mask = _mm256_set_epi16(0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1);

    let mut ccol = 0usize;
    while ccol < simd_width {
        let luma_col = ccol * 2;
        // Each iteration reads 16 luma pixels (64 bytes) per row.
        if luma_col + 16 > luma_width {
            break;
        }

        let ptr0 = rgba_row0.as_ptr().add(luma_col * 4);
        let ptr1 = rgba_row1.as_ptr().add(luma_col * 4);
        let px0_a = _mm256_loadu_si256(ptr0.cast());
        let px0_b = _mm256_loadu_si256(ptr0.add(32).cast());
        let px1_a = _mm256_loadu_si256(ptr1.cast());
        let px1_b = _mm256_loadu_si256(ptr1.add(32).cast());

        // 2×2 average for R channel.
        // AVX2 `packs_epi32` packs within each 128-bit lane, so the
        // `permute4x64(…, 0xD8)` swaps the middle 64-bit quarters to restore
        // pixel order before the horizontal (even+odd lane) sum.
        let r0_a = _mm256_and_si256(px0_a, channel_mask);
        let r0_b = _mm256_and_si256(px0_b, channel_mask);
        let r1_a = _mm256_and_si256(px1_a, channel_mask);
        let r1_b = _mm256_and_si256(px1_b, channel_mask);
        let r_v_a = _mm256_add_epi32(r0_a, r1_a);
        let r_v_b = _mm256_add_epi32(r0_b, r1_b);
        let r_v = _mm256_permute4x64_epi64(_mm256_packs_epi32(r_v_a, r_v_b), 0xD8);
        let r_even = _mm256_and_si256(r_v, even_mask);
        let r_odd = _mm256_srli_epi32(r_v, 16);
        let r_sum = _mm256_add_epi16(r_even, r_odd);
        // (sum + 2) >> 2 → rounded mean of the 2×2 block.
        let r_avg = _mm256_srai_epi16(
            _mm256_add_epi16(_mm256_packs_epi32(r_sum, zero), _mm256_set1_epi16(2)),
            2,
        );

        // G channel (byte 1 of each pixel).
        let g0_a = _mm256_and_si256(_mm256_srli_epi32(px0_a, 8), channel_mask);
        let g0_b = _mm256_and_si256(_mm256_srli_epi32(px0_b, 8), channel_mask);
        let g1_a = _mm256_and_si256(_mm256_srli_epi32(px1_a, 8), channel_mask);
        let g1_b = _mm256_and_si256(_mm256_srli_epi32(px1_b, 8), channel_mask);
        let g_v_a = _mm256_add_epi32(g0_a, g1_a);
        let g_v_b = _mm256_add_epi32(g0_b, g1_b);
        let g_v = _mm256_permute4x64_epi64(_mm256_packs_epi32(g_v_a, g_v_b), 0xD8);
        let g_even = _mm256_and_si256(g_v, even_mask);
        let g_odd = _mm256_srli_epi32(g_v, 16);
        let g_sum = _mm256_add_epi16(g_even, g_odd);
        let g_avg = _mm256_srai_epi16(
            _mm256_add_epi16(_mm256_packs_epi32(g_sum, zero), _mm256_set1_epi16(2)),
            2,
        );

        // B channel (byte 2 of each pixel).
        let b0_a = _mm256_and_si256(_mm256_srli_epi32(px0_a, 16), channel_mask);
        let b0_b = _mm256_and_si256(_mm256_srli_epi32(px0_b, 16), channel_mask);
        let b1_a = _mm256_and_si256(_mm256_srli_epi32(px1_a, 16), channel_mask);
        let b1_b = _mm256_and_si256(_mm256_srli_epi32(px1_b, 16), channel_mask);
        let b_v_a = _mm256_add_epi32(b0_a, b1_a);
        let b_v_b = _mm256_add_epi32(b0_b, b1_b);
        let b_v = _mm256_permute4x64_epi64(_mm256_packs_epi32(b_v_a, b_v_b), 0xD8);
        let b_even = _mm256_and_si256(b_v, even_mask);
        let b_odd = _mm256_srli_epi32(b_v, 16);
        let b_sum = _mm256_add_epi16(b_even, b_odd);
        let b_avg = _mm256_srai_epi16(
            _mm256_add_epi16(_mm256_packs_epi32(b_sum, zero), _mm256_set1_epi16(2)),
            2,
        );

        // Cb / Cr coefficient multiplies.
        // Averages are 0..=255 → every i16 product and the three-term sum
        // fit without overflow (max magnitude 28560 + 128).
        let cb_result = _mm256_add_epi16(
            _mm256_srai_epi16(
                _mm256_add_epi16(
                    _mm256_add_epi16(
                        _mm256_add_epi16(
                            _mm256_mullo_epi16(coeff_cb_r, r_avg),
                            _mm256_mullo_epi16(coeff_cb_g, g_avg),
                        ),
                        _mm256_mullo_epi16(coeff_cb_b, b_avg),
                    ),
                    rounding,
                ),
                8,
            ),
            bias_128,
        );

        let cr_result = _mm256_add_epi16(
            _mm256_srai_epi16(
                _mm256_add_epi16(
                    _mm256_add_epi16(
                        _mm256_add_epi16(
                            _mm256_mullo_epi16(coeff_cr_r, r_avg),
                            _mm256_mullo_epi16(coeff_cr_g, g_avg),
                        ),
                        _mm256_mullo_epi16(coeff_cr_b, b_avg),
                    ),
                    rounding,
                ),
                8,
            ),
            bias_128,
        );

        // After the per-lane pack, samples 0..=3 live in the low 128-bit lane
        // and 4..=7 in the high lane; `unpacklo_epi32` on the two halves
        // gathers all 8 bytes contiguously before interleaving U with V.
        let cb_packed = _mm256_packus_epi16(cb_result, zero);
        let cr_packed = _mm256_packus_epi16(cr_result, zero);
        let cb_lo = _mm256_castsi256_si128(cb_packed);
        let cb_hi = _mm256_extracti128_si256(cb_packed, 1);
        let cr_lo = _mm256_castsi256_si128(cr_packed);
        let cr_hi = _mm256_extracti128_si256(cr_packed, 1);
        let cb8 = _mm_unpacklo_epi32(cb_lo, cb_hi);
        let cr8 = _mm_unpacklo_epi32(cr_lo, cr_hi);
        let interleaved = _mm_unpacklo_epi8(cb8, cr8);

        // 16 bytes = 8 UV pairs per iteration.
        _mm_storeu_si128(uv_out.as_mut_ptr().add(ccol * 2).cast(), interleaved);

        ccol += 8;
    }
    ccol
}

// ── RGBA8 → I420 chroma row (AVX2: 8 chroma samples / iter) ────────────────

/// Convert one pair of RGBA8 rows to U and V chroma samples using AVX2.
///
/// Processes 8 chroma samples (16 luma pixels) per iteration.
///
/// Returns the number of chroma samples converted (multiple of 8).
#[target_feature(enable = "avx2")]
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss, clippy::similar_names)]
pub(super) unsafe fn rgba8_to_chroma_row_avx2(
    rgba_row0: &[u8],
    rgba_row1: &[u8],
    u_out: &mut [u8],
    v_out: &mut [u8],
    chroma_width: usize,
    luma_width: usize,
) -> usize {
    use std::arch::x86_64::{
        _mm256_add_epi16, _mm256_add_epi32, _mm256_and_si256, _mm256_castsi256_si128,
        _mm256_extracti128_si256, _mm256_loadu_si256, _mm256_mullo_epi16, _mm256_packs_epi32,
        _mm256_packus_epi16, _mm256_permute4x64_epi64, _mm256_set1_epi16, _mm256_set1_epi32,
        _mm256_set_epi16, _mm256_setzero_si256, _mm256_srai_epi16, _mm256_srli_epi32,
        _mm_storel_epi64, _mm_unpacklo_epi32,
    };

    // Identical pipeline to the NV12 AVX2 variant; the only difference is
    // the planar store at the end (separate U and V planes, no interleave).
    let simd_width = chroma_width & !7;
    if simd_width == 0 || luma_width < 16 {
        return 0;
    }

    // Q8 fixed-point chroma coefficients (same as the other variants):
    //   Cb = ((-38·R − 74·G + 112·B + 128) >> 8) + 128
    //   Cr = ((112·R − 94·G −  18·B + 128) >> 8) + 128
    let coeff_cb_r = _mm256_set1_epi16(-38);
    let coeff_cb_g = _mm256_set1_epi16(-74);
    let coeff_cb_b = _mm256_set1_epi16(112);
    let coeff_cr_r = _mm256_set1_epi16(112);
    let coeff_cr_g = _mm256_set1_epi16(-94);
    let coeff_cr_b = _mm256_set1_epi16(-18);
    let rounding = _mm256_set1_epi16(128);
    let bias_128 = _mm256_set1_epi16(128);
    let zero = _mm256_setzero_si256();
    let channel_mask = _mm256_set1_epi32(0xFF);
    let even_mask = _mm256_set_epi16(0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1);

    let mut ccol = 0usize;
    while ccol < simd_width {
        let luma_col = ccol * 2;
        // Each iteration reads 16 luma pixels (64 bytes) per row.
        if luma_col + 16 > luma_width {
            break;
        }

        let ptr0 = rgba_row0.as_ptr().add(luma_col * 4);
        let ptr1 = rgba_row1.as_ptr().add(luma_col * 4);
        let px0_a = _mm256_loadu_si256(ptr0.cast());
        let px0_b = _mm256_loadu_si256(ptr0.add(32).cast());
        let px1_a = _mm256_loadu_si256(ptr1.cast());
        let px1_b = _mm256_loadu_si256(ptr1.add(32).cast());

        // 2×2 average for R channel.
        // `packs_epi32` works per 128-bit lane; `permute4x64(…, 0xD8)`
        // restores pixel order before the even+odd horizontal sum.
        let r0_a = _mm256_and_si256(px0_a, channel_mask);
        let r0_b = _mm256_and_si256(px0_b, channel_mask);
        let r1_a = _mm256_and_si256(px1_a, channel_mask);
        let r1_b = _mm256_and_si256(px1_b, channel_mask);
        let r_v_a = _mm256_add_epi32(r0_a, r1_a);
        let r_v_b = _mm256_add_epi32(r0_b, r1_b);
        let r_v = _mm256_permute4x64_epi64(_mm256_packs_epi32(r_v_a, r_v_b), 0xD8);
        let r_even = _mm256_and_si256(r_v, even_mask);
        let r_odd = _mm256_srli_epi32(r_v, 16);
        let r_sum = _mm256_add_epi16(r_even, r_odd);
        let r_avg = _mm256_srai_epi16(
            _mm256_add_epi16(_mm256_packs_epi32(r_sum, zero), _mm256_set1_epi16(2)),
            2,
        );

        // G channel (byte 1 of each pixel).
        let g0_a = _mm256_and_si256(_mm256_srli_epi32(px0_a, 8), channel_mask);
        let g0_b = _mm256_and_si256(_mm256_srli_epi32(px0_b, 8), channel_mask);
        let g1_a = _mm256_and_si256(_mm256_srli_epi32(px1_a, 8), channel_mask);
        let g1_b = _mm256_and_si256(_mm256_srli_epi32(px1_b, 8), channel_mask);
        let g_v_a = _mm256_add_epi32(g0_a, g1_a);
        let g_v_b = _mm256_add_epi32(g0_b, g1_b);
        let g_v = _mm256_permute4x64_epi64(_mm256_packs_epi32(g_v_a, g_v_b), 0xD8);
        let g_even = _mm256_and_si256(g_v, even_mask);
        let g_odd = _mm256_srli_epi32(g_v, 16);
        let g_sum = _mm256_add_epi16(g_even, g_odd);
        let g_avg = _mm256_srai_epi16(
            _mm256_add_epi16(_mm256_packs_epi32(g_sum, zero), _mm256_set1_epi16(2)),
            2,
        );

        // B channel (byte 2 of each pixel).
        let b0_a = _mm256_and_si256(_mm256_srli_epi32(px0_a, 16), channel_mask);
        let b0_b = _mm256_and_si256(_mm256_srli_epi32(px0_b, 16), channel_mask);
        let b1_a = _mm256_and_si256(_mm256_srli_epi32(px1_a, 16), channel_mask);
        let b1_b = _mm256_and_si256(_mm256_srli_epi32(px1_b, 16), channel_mask);
        let b_v_a = _mm256_add_epi32(b0_a, b1_a);
        let b_v_b = _mm256_add_epi32(b0_b, b1_b);
        let b_v = _mm256_permute4x64_epi64(_mm256_packs_epi32(b_v_a, b_v_b), 0xD8);
        let b_even = _mm256_and_si256(b_v, even_mask);
        let b_odd = _mm256_srli_epi32(b_v, 16);
        let b_sum = _mm256_add_epi16(b_even, b_odd);
        let b_avg = _mm256_srai_epi16(
            _mm256_add_epi16(_mm256_packs_epi32(b_sum, zero), _mm256_set1_epi16(2)),
            2,
        );

        // Cb / Cr coefficient multiplies (no overflow: averages are 0..=255).
        let cb_result = _mm256_add_epi16(
            _mm256_srai_epi16(
                _mm256_add_epi16(
                    _mm256_add_epi16(
                        _mm256_add_epi16(
                            _mm256_mullo_epi16(coeff_cb_r, r_avg),
                            _mm256_mullo_epi16(coeff_cb_g, g_avg),
                        ),
                        _mm256_mullo_epi16(coeff_cb_b, b_avg),
                    ),
                    rounding,
                ),
                8,
            ),
            bias_128,
        );

        let cr_result = _mm256_add_epi16(
            _mm256_srai_epi16(
                _mm256_add_epi16(
                    _mm256_add_epi16(
                        _mm256_add_epi16(
                            _mm256_mullo_epi16(coeff_cr_r, r_avg),
                            _mm256_mullo_epi16(coeff_cr_g, g_avg),
                        ),
                        _mm256_mullo_epi16(coeff_cr_b, b_avg),
                    ),
                    rounding,
                ),
                8,
            ),
            bias_128,
        );

        // Gather the 8 valid bytes from the two 128-bit lanes of each packed
        // result, then store each plane separately (I420 is planar).
        let cb_packed = _mm256_packus_epi16(cb_result, zero);
        let cr_packed = _mm256_packus_epi16(cr_result, zero);
        let cb_lo = _mm256_castsi256_si128(cb_packed);
        let cb_hi = _mm256_extracti128_si256(cb_packed, 1);
        let cr_lo = _mm256_castsi256_si128(cr_packed);
        let cr_hi = _mm256_extracti128_si256(cr_packed, 1);
        let cb8 = _mm_unpacklo_epi32(cb_lo, cb_hi);
        let cr8 = _mm_unpacklo_epi32(cr_lo, cr_hi);

        // 8 bytes per plane per iteration.
        _mm_storel_epi64(u_out.as_mut_ptr().add(ccol).cast(), cb8);
        _mm_storel_epi64(v_out.as_mut_ptr().add(ccol).cast(), cr8);

        ccol += 8;
    }
    ccol
}
diff --git a/crates/nodes/src/video/fonts.rs b/crates/nodes/src/video/fonts.rs
new file mode 100644 index 00000000..14588559 --- /dev/null +++ b/crates/nodes/src/video/fonts.rs @@ -0,0 +1,71 @@ +// SPDX-FileCopyrightText: © 2025 StreamKit Contributors +// +// SPDX-License-Identifier: MPL-2.0 + +//! Compile-time embedded font data for the bundled font set. +//! +//! All fonts in [`BUNDLED_FONTS`] are included in the binary via +//! `include_bytes!` so they work without any system font packages +//! installed. The DejaVu family is distributed under the permissive +//! Bitstream Vera / DejaVu license (see `assets/fonts/LICENSE-DejaVu.txt`). + +/// A font embedded in the binary at compile time. +pub struct BundledFont { + /// User-facing name used in `font_name` config fields. + pub name: &'static str, + /// Raw TTF bytes baked into the binary. + pub data: &'static [u8], +} + +/// Bundled font set — always available, no filesystem dependency. +/// +/// Order matters: the first entry is the default proportional font and +/// the third entry is the default monospace font (see [`DEFAULT_FONT`] +/// and [`DEFAULT_MONO_FONT`]). +pub static BUNDLED_FONTS: &[BundledFont] = &[ + BundledFont { + name: "dejavu-sans", + data: include_bytes!("../../../../assets/fonts/DejaVuSans.ttf"), + }, + BundledFont { + name: "dejavu-sans-bold", + data: include_bytes!("../../../../assets/fonts/DejaVuSans-Bold.ttf"), + }, + BundledFont { + name: "dejavu-sans-mono", + data: include_bytes!("../../../../assets/fonts/DejaVuSansMono.ttf"), + }, + BundledFont { + name: "dejavu-sans-mono-bold", + data: include_bytes!("../../../../assets/fonts/DejaVuSansMono-Bold.ttf"), + }, + BundledFont { + name: "dejavu-serif", + data: include_bytes!("../../../../assets/fonts/DejaVuSerif.ttf"), + }, + BundledFont { + name: "dejavu-serif-bold", + data: include_bytes!("../../../../assets/fonts/DejaVuSerif-Bold.ttf"), + }, +]; + +/// Default proportional font bytes (DejaVu Sans) — used when no font is +/// specified in compositor text overlays. 
+pub static DEFAULT_FONT_DATA: &[u8] = include_bytes!("../../../../assets/fonts/DejaVuSans.ttf"); + +/// Default monospace font bytes (DejaVu Sans Mono) — used by the colorbars +/// `draw_time` overlay. +pub static DEFAULT_MONO_FONT_DATA: &[u8] = + include_bytes!("../../../../assets/fonts/DejaVuSansMono.ttf"); + +/// Look up a bundled font by its user-facing name. +/// +/// Returns `None` if the name is not in the bundled set. +pub fn bundled_font_by_name(name: &str) -> Option<&'static [u8]> { + BUNDLED_FONTS.iter().find(|f| f.name == name).map(|f| f.data) +} + +/// Comma-separated list of bundled font names (for error messages). +pub fn bundled_font_names() -> String { + BUNDLED_FONTS.iter().map(|f| f.name).collect::>().join(", ") +} diff --git a/crates/nodes/src/video/mod.rs b/crates/nodes/src/video/mod.rs new file mode 100644 index 00000000..4cbb6d86 --- /dev/null +++ b/crates/nodes/src/video/mod.rs @@ -0,0 +1,358 @@ +// SPDX-FileCopyrightText: © 2025 StreamKit Contributors +// +// SPDX-License-Identifier: MPL-2.0 + +//! Video nodes and registration. + +use streamkit_core::types::PixelFormat; +use streamkit_core::{NodeRegistry, StreamKitError}; + +/// Default video frame duration in microseconds (~30 fps). +/// +/// Used as a fallback when incoming packets carry no duration metadata. +/// Shared across WebM muxing and MoQ transport. +pub const DEFAULT_VIDEO_FRAME_DURATION_US: u64 = 33_333; + +/// Parse a pixel format string into a [`PixelFormat`]. +/// +/// Accepts `"i420"`, `"nv12"`, `"rgba8"`, or `"rgba"` (case-insensitive). +/// +/// # Errors +/// +/// Returns [`StreamKitError::Configuration`] if `s` is not a recognised format name. +pub fn parse_pixel_format(s: &str) -> Result { + match s.to_lowercase().as_str() { + "i420" => Ok(PixelFormat::I420), + "nv12" => Ok(PixelFormat::Nv12), + "rgba8" | "rgba" => Ok(PixelFormat::Rgba8), + other => Err(StreamKitError::Configuration(format!( + "Unsupported pixel format '{other}'. Use 'i420', 'nv12', or 'rgba8'." 
+ ))), + } +} + +#[cfg(feature = "colorbars")] +pub mod colorbars; + +#[cfg(feature = "compositor")] +pub mod compositor; + +#[cfg(feature = "compositor")] +pub mod pixel_convert; + +#[cfg(feature = "vp9")] +pub mod vp9; + +#[cfg(any(feature = "colorbars", feature = "compositor"))] +pub(crate) mod fonts; + +// ── Shared font-rendering helpers ──────────────────────────────────────────── + +/// Measure the pixel dimensions a single-line text string would occupy when +/// rendered at `font_size`. Returns `(width, height)`. +/// +/// The width is the sum of advance widths. The height uses the same baseline +/// logic as [`blit_text_rgba`] and adds enough room for descenders. +#[allow( + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::cast_sign_loss, + clippy::cast_precision_loss +)] +pub fn measure_text(font: &fontdue::Font, font_size: f32, text: &str) -> (u32, u32) { + if text.is_empty() { + return (0, 0); + } + + let (ref_metrics, _) = font.rasterize('A', font_size); + let baseline_y = ref_metrics.height as f32; + + let mut total_width: f32 = 0.0; + let mut max_top: i32 = 0; // highest pixel above origin_y (always >= 0) + let mut max_bottom: i32 = 0; // lowest pixel below origin_y + + for ch in text.chars() { + let (metrics, _) = font.rasterize(ch, font_size); + + let gy = (baseline_y - metrics.ymin as f32) as i32 - metrics.height as i32; + let glyph_bottom = gy + metrics.height as i32; + + if gy < max_top { + max_top = gy; + } + if glyph_bottom > max_bottom { + max_bottom = glyph_bottom; + } + + total_width += metrics.advance_width; + } + + let w = total_width.ceil() as u32; + let h = + if max_bottom > max_top { (max_bottom - max_top) as u32 } else { font_size.ceil() as u32 }; + + (w, h) +} + +/// Alpha-blend a single text string into a packed RGBA8 buffer. +/// +/// `origin_x` / `origin_y` are the top-left pixel coordinates where the first +/// glyph begins. 
/// `color` is `[R, G, B, A]` — the alpha component modulates
/// coverage so semi-transparent text is supported.
///
/// The function clips to the buffer dimensions and stops early if the cursor
/// advances past `buf_width`.
#[allow(
    clippy::cast_possible_truncation,
    clippy::cast_possible_wrap,
    clippy::cast_sign_loss,
    clippy::cast_precision_loss,
    clippy::too_many_arguments
)]
pub fn blit_text_rgba(
    buf: &mut [u8],
    buf_width: u32,
    buf_height: u32,
    font: &fontdue::Font,
    font_size: f32,
    text: &str,
    origin_x: i32,
    origin_y: i32,
    color: [u8; 4],
) {
    let [cr, cg, cb, ca] = color;
    // Bytes per buffer row (packed RGBA8).
    let stride = buf_width as usize * 4;

    // Establish baseline from a reference glyph.
    let (ref_metrics, _) = font.rasterize('A', font_size);
    let baseline_y = ref_metrics.height as f32;

    let mut cursor_x: f32 = 0.0;

    for ch in text.chars() {
        let (metrics, bitmap) = font.rasterize(ch, font_size);

        // Top-left of this glyph's bitmap in buffer coordinates.
        let gx = origin_x + (cursor_x + metrics.xmin as f32) as i32;
        let gy = origin_y + (baseline_y - metrics.ymin as f32) as i32 - metrics.height as i32;

        for row in 0..metrics.height {
            let dst_y = gy + row as i32;
            // Clip rows/columns falling outside the buffer.
            if dst_y < 0 || dst_y >= buf_height as i32 {
                continue;
            }
            for col in 0..metrics.width {
                let dst_x = gx + col as i32;
                if dst_x < 0 || dst_x >= buf_width as i32 {
                    continue;
                }
                let coverage = bitmap[row * metrics.width + col];
                if coverage == 0 {
                    continue;
                }

                // Effective alpha = glyph coverage scaled by caller alpha
                // (both 0..=255).
                let alpha = u16::from(ca) * u16::from(coverage) / 255;
                if alpha == 0 {
                    continue;
                }
                let off = dst_y as usize * stride + dst_x as usize * 4;

                if alpha >= 255 {
                    // Fully opaque: plain overwrite.
                    buf[off] = cr;
                    buf[off + 1] = cg;
                    buf[off + 2] = cb;
                    buf[off + 3] = 255;
                } else {
                    // Integer source-over blend with +128 rounding.
                    let inv = 255 - alpha;
                    let dr = u16::from(buf[off]);
                    let dg = u16::from(buf[off + 1]);
                    let db = u16::from(buf[off + 2]);
                    let da = u16::from(buf[off + 3]);
                    buf[off] = ((u16::from(cr) * alpha + dr * inv + 128) / 255) as u8;
                    buf[off + 1] = ((u16::from(cg) * alpha + dg * inv + 128) / 255) as u8;
                    buf[off + 2] = ((u16::from(cb) * alpha + db * inv + 128) / 255) as u8;
                    buf[off + 3] = (alpha + (da * inv + 128) / 255).min(255) as u8;
                }
            }
        }

        cursor_x += metrics.advance_width;
        // Past the right edge — nothing further can be visible.
        if (origin_x as f32 + cursor_x) >= buf_width as f32 {
            break;
        }
    }
}

// ── Multi-line / word-wrap helpers ────────────────────────────────────────

/// Split `text` into wrapped lines that fit within `max_width` pixels.
///
/// Explicit `\n` characters always produce a line break. Within each
/// paragraph the text is word-wrapped (split on ASCII whitespace) so that
/// no line exceeds `max_width`. When a single word is wider than
/// `max_width` it is placed on its own line without further splitting.
///
/// If `max_width` is 0 the text is only split on explicit newlines (no
/// word-wrapping).
// NOTE(review): the return type reads `-> Vec {` — the generic parameter
// appears stripped by extraction; presumably `Vec<String>`. Confirm against
// the original source.
#[allow(clippy::cast_precision_loss)]
fn wrap_text_lines(
    font: &fontdue::Font,
    font_size: f32,
    text: &str,
    max_width: u32,
) -> Vec {
    let paragraphs: Vec<&str> = text.split('\n').collect();

    // max_width == 0 disables wrapping entirely.
    if max_width == 0 {
        return paragraphs.iter().map(|s| (*s).to_string()).collect();
    }

    let max_w = max_width as f32;
    // Width contributed by the space between two words.
    let space_advance = {
        let (m, _) = font.rasterize(' ', font_size);
        m.advance_width
    };

    let mut lines = Vec::new();

    for paragraph in paragraphs {
        if paragraph.is_empty() {
            lines.push(String::new());
            continue;
        }

        let words: Vec<&str> = paragraph.split_whitespace().collect();
        if words.is_empty() {
            lines.push(String::new());
            continue;
        }

        let mut current_line = String::new();
        let mut current_width: f32 = 0.0;

        for word in &words {
            let (word_w, _) = measure_text(font, font_size, word);
            let word_w_f = word_w as f32;

            if current_line.is_empty() {
                // First word on the line — always accept it.
                current_line.push_str(word);
                current_width = word_w_f;
            } else if current_width + space_advance + word_w_f <= max_w {
                // Fits on the current line.
                current_line.push(' ');
                current_line.push_str(word);
                current_width += space_advance + word_w_f;
            } else {
                // Doesn't fit — flush current line and start a new one.
                lines.push(std::mem::take(&mut current_line));
                current_line.push_str(word);
                current_width = word_w_f;
            }
        }

        // Flush the last partial line of the paragraph.
        if !current_line.is_empty() {
            lines.push(current_line);
        }
    }

    // Guarantee at least one (possibly empty) line for callers.
    if lines.is_empty() {
        lines.push(String::new());
    }

    lines
}

/// Measure the pixel dimensions of multi-line wrapped text.
///
/// Splits the input on explicit newlines and word-wraps each paragraph to
/// fit within `max_width` pixels (see [`wrap_text_lines`]). Returns the
/// bounding `(width, height)` of the full block of text.
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss, clippy::cast_precision_loss)]
pub fn measure_text_wrapped(
    font: &fontdue::Font,
    font_size: f32,
    text: &str,
    max_width: u32,
) -> (u32, u32) {
    if text.is_empty() {
        return (0, 0);
    }

    let lines = wrap_text_lines(font, font_size, text, max_width);
    let line_height = line_height_px(font, font_size);

    // Width of the block = widest individual line.
    let mut widest: u32 = 0;
    for line in &lines {
        if line.is_empty() {
            continue;
        }
        let (w, _) = measure_text(font, font_size, line);
        if w > widest {
            widest = w;
        }
    }

    let total_h = (lines.len() as f32 * line_height).ceil() as u32;
    (widest, total_h)
}

/// Blit multi-line wrapped text into a packed RGBA8 buffer.
///
/// The text is split on explicit `\n` and word-wrapped to `max_width`
/// pixels (see [`wrap_text_lines`]). Each resulting line is rendered via
/// [`blit_text_rgba`] at successive vertical offsets.
#[allow(
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss,
    clippy::cast_precision_loss,
    clippy::too_many_arguments
)]
pub fn blit_text_wrapped(
    buf: &mut [u8],
    buf_width: u32,
    buf_height: u32,
    font: &fontdue::Font,
    font_size: f32,
    text: &str,
    origin_x: i32,
    origin_y: i32,
    color: [u8; 4],
    max_width: u32,
) {
    let lines = wrap_text_lines(font, font_size, text, max_width);
    let line_height = line_height_px(font, font_size);

    for (i, line) in lines.iter().enumerate() {
        // Empty lines still advance the cursor via `i`, just render nothing.
        if line.is_empty() {
            continue;
        }
        let y = origin_y + (i as f32 * line_height).round() as i32;
        blit_text_rgba(buf, buf_width, buf_height, font, font_size, line, origin_x, y, color);
    }
}

/// Compute the line height (in pixels) for a font at the given size.
///
/// Uses a 1.2× multiplier on the reference glyph height, matching the
/// `line-height: 1.2` used in the UI's `CompositorCanvas`.
#[allow(clippy::cast_precision_loss)]
fn line_height_px(font: &fontdue::Font, font_size: f32) -> f32 {
    let (ref_metrics, _) = font.rasterize('A', font_size);
    ref_metrics.height as f32 * 1.2
}

/// Registers all available video nodes with the engine's registry.
#[allow(clippy::missing_const_for_fn)]
pub fn register_video_nodes(registry: &mut NodeRegistry) {
    #[cfg(feature = "colorbars")]
    colorbars::register_colorbars_nodes(registry);

    #[cfg(feature = "compositor")]
    compositor::register_compositor_nodes(registry);

    #[cfg(feature = "compositor")]
    pixel_convert::register_pixel_convert_nodes(registry);

    #[cfg(feature = "vp9")]
    vp9::register_vp9_nodes(registry);
}
diff --git a/crates/nodes/src/video/pixel_convert.rs b/crates/nodes/src/video/pixel_convert.rs
new file mode 100644
index 00000000..8be6fc41
--- /dev/null
+++ b/crates/nodes/src/video/pixel_convert.rs
@@ -0,0 +1,675 @@
// SPDX-FileCopyrightText: © 2025 StreamKit Contributors
//
// SPDX-License-Identifier: MPL-2.0

//! Pixel-format conversion node.
//!
//! Converts raw video frames between RGBA8, NV12, and I420 pixel formats.
//! Runs the CPU-heavy conversion on a persistent `spawn_blocking` thread
//! (same pattern as the VP9 encoder) and caches the output when the input
//! data `Arc` pointer hasn't changed (zero-cost passthrough for
//! static scenes).
//!
//! Supported conversions:
//! - RGBA8 → NV12
//! - RGBA8 → I420
//! - NV12 → RGBA8
//! - I420 → RGBA8
//!
//! Unsupported pairs (e.g. NV12 ↔ I420) return an error rather than
//! silently chaining two conversions.

use async_trait::async_trait;
use opentelemetry::{global, KeyValue};
use schemars::JsonSchema;
use serde::Deserialize;
use std::sync::Arc;
use std::time::Instant;
use streamkit_core::control::NodeControlMessage;
use streamkit_core::stats::NodeStatsTracker;
use streamkit_core::types::{Packet, PacketType, PixelFormat, VideoFormat, VideoFrame};
use streamkit_core::{
    config_helpers, get_codec_channel_capacity, packet_helpers, state_helpers, InputPin,
    NodeContext, NodeRegistry, OutputPin, PinCardinality, PooledVideoData, ProcessorNode,
    StreamKitError,
};
use tokio::sync::mpsc;

use super::parse_pixel_format;
use crate::video::compositor::pixel_ops::{
    i420_to_rgba8_buf, nv12_to_rgba8_buf, rgba8_to_i420_buf, rgba8_to_nv12_buf,
};

// ── Config ──────────────────────────────────────────────────────────────────

/// Configuration for the pixel format converter node.
#[derive(Deserialize, Debug, Clone, JsonSchema)]
#[serde(default)]
pub struct PixelConvertConfig {
    /// Target pixel format: `"nv12"` (default), `"i420"`, or `"rgba8"`.
    pub output_format: String,
}

impl Default for PixelConvertConfig {
    fn default() -> Self {
        Self { output_format: "nv12".to_string() }
    }
}

// ── Node ────────────────────────────────────────────────────────────────────

/// Converts raw video frames between pixel formats (RGBA8, NV12, I420).
///
/// When the input format already matches the target, the frame is forwarded
/// unchanged (zero allocation). When the input data `Arc`
/// pointer is identical to the previous frame, the cached converted output
/// is re-sent (ref-count bump only, no conversion).
pub struct PixelConvertNode {
    // Resolved once at construction from `PixelConvertConfig::output_format`.
    target_format: PixelFormat,
}

impl PixelConvertNode {
    /// Create a new pixel convert node with the given configuration.
    ///
    /// # Errors
    ///
    /// Returns an error if `output_format` is not a recognised pixel format.
    // NOTE(review): return-type generics appear stripped by extraction —
    // presumably `Result<Self, StreamKitError>` (the `?` propagates the
    // error from `parse_pixel_format`). Confirm against the original.
    pub fn new(config: &PixelConvertConfig) -> Result {
        let target_format = parse_pixel_format(&config.output_format)?;
        Ok(Self { target_format })
    }
}

#[async_trait]
impl ProcessorNode for PixelConvertNode {
    // NOTE(review): `Vec` generics stripped here and below — presumably
    // `Vec<InputPin>` / `Vec<OutputPin>`.
    fn input_pins(&self) -> Vec {
        // Accept raw video in any of the three supported formats; width and
        // height are unconstrained (None).
        vec![InputPin {
            name: "in".to_string(),
            accepts_types: vec![
                PacketType::RawVideo(VideoFormat {
                    width: None,
                    height: None,
                    pixel_format: PixelFormat::Rgba8,
                }),
                PacketType::RawVideo(VideoFormat {
                    width: None,
                    height: None,
                    pixel_format: PixelFormat::I420,
                }),
                PacketType::RawVideo(VideoFormat {
                    width: None,
                    height: None,
                    pixel_format: PixelFormat::Nv12,
                }),
            ],
            cardinality: PinCardinality::One,
        }]
    }

    fn output_pins(&self) -> Vec {
        // Single broadcast output pin producing the configured target format.
        vec![OutputPin {
            name: "out".to_string(),
            produces_type: PacketType::RawVideo(VideoFormat {
                width: None,
                height: None,
                pixel_format: self.target_format,
            }),
            cardinality: PinCardinality::Broadcast,
        }]
    }

    #[allow(clippy::too_many_lines)]
    // NOTE(review): `self: Box` is presumably `self: Box<Self>` (generic
    // stripped by extraction).
    async fn run(self: Box, mut context: NodeContext) -> Result<(), StreamKitError> {
        let node_name = context.output_sender.node_name().to_string();
        state_helpers::emit_initializing(&context.state_tx, &node_name);

        let mut input_rx = context.take_input("in")?;

        tracing::info!("PixelConvertNode starting: target_format={:?}", self.target_format);

        // ── Blocking conversion thread ──────────────────────────────────
        // Copies of everything the spawn_blocking closure needs to own.
        let target_format = self.target_format;
        let
            otel_node_name = node_name.clone();
        let video_pool = context.video_pool.clone();

        // NOTE(review): channel element generics appear stripped by
        // extraction — presumably `mpsc::channel::<VideoFrame>` and
        // `mpsc::channel::<Result<VideoFrame, String>>` (the result channel
        // carries Ok(VideoFrame) / Err(String) below).
        let (convert_tx, mut convert_rx) =
            mpsc::channel::(get_codec_channel_capacity());
        let (result_tx, mut result_rx) =
            mpsc::channel::>(get_codec_channel_capacity());

        let convert_task = tokio::task::spawn_blocking(move || {
            // OpenTelemetry metrics — created here so they are owned by the
            // blocking thread and can be updated without cross-thread sync.
            let meter = global::meter("skit_nodes");
            let frames_converted_counter = meter
                .u64_counter("pixel_convert.frames_converted")
                .with_description("Frames that required pixel format conversion")
                .build();
            let frames_passthrough_counter = meter
                .u64_counter("pixel_convert.frames_passthrough")
                .with_description("Frames forwarded unchanged (same format or same Arc pointer)")
                .build();
            let conversion_duration_histogram = meter
                .f64_histogram("pixel_convert.conversion_duration")
                .with_description("Seconds per pixel format conversion")
                .with_boundaries(
                    streamkit_core::metrics::HISTOGRAM_BOUNDARIES_CODEC_PACKET.to_vec(),
                )
                .build();
            let otel_attrs = [KeyValue::new("node", otel_node_name)];

            // Single-entry conversion cache, keyed on the input Arc's raw
            // pointer plus frame dimensions and target format.
            // NOTE(review): cache generics appear stripped — presumably
            // `Option<Arc<PooledVideoData>>` and `Option<PixelFormat>`.
            let mut last_input_ptr: usize = 0;
            let mut cached_output: Option> = None;
            let mut cached_output_format: Option = None;
            let mut cached_width: u32 = 0;
            let mut cached_height: u32 = 0;

            while let Some(frame) = convert_rx.blocking_recv() {
                // Fast path 1: format already matches — passthrough.
                if frame.pixel_format == target_format {
                    frames_passthrough_counter.add(1, &otel_attrs);
                    if result_tx.blocking_send(Ok(frame)).is_err() {
                        break;
                    }
                    continue;
                }

                // Fast path 2: identical Arc pointer — re-send cached output.
                // `last_input_ptr != 0` guards the initial sentinel value.
                let current_ptr = Arc::as_ptr(&frame.data) as usize;
                if current_ptr == last_input_ptr
                    && last_input_ptr != 0
                    && cached_output.is_some()
                    && cached_output_format == Some(target_format)
                    && cached_width == frame.width
                    && cached_height == frame.height
                {
                    #[allow(clippy::unwrap_used)] // guarded by is_some() check above
                    let cached_data = cached_output.clone().unwrap();
                    let result = VideoFrame::from_arc(
                        frame.width,
                        frame.height,
                        target_format,
                        cached_data,
                        frame.metadata.clone(),
                    );
                    match result {
                        Ok(out_frame) => {
                            frames_passthrough_counter.add(1, &otel_attrs);
                            if result_tx.blocking_send(Ok(out_frame)).is_err() {
                                break;
                            }
                        },
                        Err(err) => {
                            let _ = result_tx.blocking_send(Err(err.to_string()));
                        },
                    }
                    continue;
                }

                // Slow path: perform conversion.
                let convert_start = Instant::now();
                let result = convert_frame(&frame, target_format, video_pool.as_deref());
                let duration = convert_start.elapsed();

                match result {
                    Ok(out_frame) => {
                        frames_converted_counter.add(1, &otel_attrs);
                        conversion_duration_histogram.record(duration.as_secs_f64(), &otel_attrs);

                        // Update cache.
                        last_input_ptr = current_ptr;
                        cached_output = Some(Arc::clone(&out_frame.data));
                        cached_output_format = Some(target_format);
                        cached_width = out_frame.width;
                        cached_height = out_frame.height;

                        if result_tx.blocking_send(Ok(out_frame)).is_err() {
                            break;
                        }
                    },
                    Err(err) => {
                        let _ = result_tx.blocking_send(Err(err.to_string()));
                    },
                }
            }
        });

        state_helpers::emit_running(&context.state_tx, &node_name);

        let mut stats_tracker = NodeStatsTracker::new(node_name.clone(), context.stats_tx.clone());
        let batch_size = context.batch_size;

        // Dedicated task feeding video packets into the blocking converter.
        let convert_tx_clone = convert_tx.clone();
        let mut input_task = tokio::spawn(async move {
            loop {
                let Some(first_packet) = input_rx.recv().await else {
                    break;
                };

                let packet_batch =
                    packet_helpers::batch_packets_greedy(first_packet, &mut input_rx, batch_size);

                for packet in packet_batch {
                    // Non-video packets are silently dropped here.
                    if let Packet::Video(frame) = packet {
                        if convert_tx_clone.send(frame).await.is_err() {
                            tracing::error!(
                                "PixelConvertNode convert task has shut down unexpectedly"
                            );
                            return;
                        }
                    }
                }
            }
            tracing::info!("PixelConvertNode input stream closed");
        });

        // ── Forward loop (mirrors codec_forward_loop pattern) ───────────
        loop {
            tokio::select! {
                maybe_result = result_rx.recv() => {
                    match maybe_result {
                        Some(Ok(out_frame)) => {
                            // Passthrough-vs-conversion accounting happens on
                            // the blocking thread (OTel counters); here we
                            // only track per-node packet stats.
                            stats_tracker.received();
                            if context
                                .output_sender
                                .send("out", Packet::Video(out_frame))
                                .await
                                .is_err()
                            {
                                tracing::debug!("Output channel closed, stopping node");
                                break;
                            }
                            stats_tracker.sent();
                            stats_tracker.maybe_send();
                        },
                        Some(Err(err)) => {
                            stats_tracker.received();
                            stats_tracker.errored();
                            stats_tracker.maybe_send();
                            tracing::warn!("PixelConvertNode conversion error: {err}");
                        },
                        None => break,
                    }
                }
                Some(control_msg) = context.control_rx.recv() => {
                    if matches!(control_msg, NodeControlMessage::Shutdown) {
                        tracing::info!("PixelConvertNode received shutdown signal");
                        input_task.abort();
                        convert_task.abort();
                        drop(convert_tx);
                        break;
                    }
                }
                _ = &mut input_task => {
                    // Input finished — drop sender to signal blocking thread,
                    // then drain remaining results.
                    drop(convert_tx);
                    while let Some(maybe_result) = result_rx.recv().await {
                        match maybe_result {
                            Ok(out_frame) => {
                                stats_tracker.received();
                                if context
                                    .output_sender
                                    .send("out", Packet::Video(out_frame))
                                    .await
                                    .is_err()
                                {
                                    break;
                                }
                                stats_tracker.sent();
                                stats_tracker.maybe_send();
                            },
                            Err(err) => {
                                stats_tracker.received();
                                stats_tracker.errored();
                                stats_tracker.maybe_send();
                                tracing::warn!("PixelConvertNode conversion error: {err}");
                            },
                        }
                    }
                    break;
                }
            }
        }

        // Best-effort teardown of the blocking thread.
        convert_task.abort();
        let _ = convert_task.await;

        state_helpers::emit_stopped(&context.state_tx, &node_name, "input_closed");
        tracing::info!("PixelConvertNode shutting down.");
        Ok(())
    }
}

// ── Conversion helper ───────────────────────────────────────────────────────

/// Convert a `VideoFrame` to the given `target_format`.
///
/// Allocates the output buffer from `video_pool` when available.
///
/// # Errors
///
/// Returns an error for unsupported conversion pairs (e.g. NV12 ↔ I420).
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
fn convert_frame(
    frame: &VideoFrame,
    target_format: PixelFormat,
    video_pool: Option<&streamkit_core::VideoFramePool>,
) -> Result<VideoFrame, StreamKitError> {
    let w = frame.width as usize;
    let h = frame.height as usize;

    // Destination buffer size for the requested format.
    // NV12 and I420 are both 4:2:0: one full-resolution luma plane plus
    // half-resolution chroma; div_ceil keeps odd dimensions safe. Note the
    // NV12 and I420 arms compute the same byte count (chroma_w * chroma_h * 2)
    // — only the plane layout differs.
    let out_size = match target_format {
        PixelFormat::Rgba8 => w * h * 4,
        PixelFormat::Nv12 => {
            let chroma_w = w.div_ceil(2);
            let chroma_h = h.div_ceil(2);
            w * h + chroma_w * 2 * chroma_h
        },
        PixelFormat::I420 => {
            let chroma_w = w.div_ceil(2);
            let chroma_h = h.div_ceil(2);
            w * h + chroma_w * chroma_h * 2
        },
    };

    // Prefer the shared frame pool to avoid a per-frame heap allocation;
    // fall back to a fresh zeroed Vec when no pool was provided.
    let mut out_data = video_pool
        .map_or_else(|| PooledVideoData::from_vec(vec![0u8; out_size]), |pool| pool.get(out_size));

    // NOTE(review): the *_buf helpers are given only width/height, so they
    // presumably expect tightly packed planes — confirm against their
    // definitions if aligned/strided layouts can reach this node.
    match (frame.pixel_format, target_format) {
        (PixelFormat::Rgba8, PixelFormat::Nv12) => {
            rgba8_to_nv12_buf(frame.data(), frame.width, frame.height, out_data.as_mut_slice());
        },
        (PixelFormat::Rgba8, PixelFormat::I420) => {
            rgba8_to_i420_buf(frame.data(), frame.width, frame.height, out_data.as_mut_slice());
        },
        (PixelFormat::Nv12, PixelFormat::Rgba8) => {
            nv12_to_rgba8_buf(frame.data(), frame.width, frame.height, out_data.as_mut_slice());
        },
        (PixelFormat::I420, PixelFormat::Rgba8) => {
            i420_to_rgba8_buf(frame.data(), frame.width, frame.height, out_data.as_mut_slice());
        },
        // Direct YUV ↔ YUV conversions are intentionally unsupported; the
        // run loop surfaces this as a per-frame error, not a node failure.
        (src, dst) => {
            return Err(StreamKitError::Runtime(format!(
                "Unsupported pixel format conversion: {src:?} → {dst:?}. \
                 Only RGBA8 ↔ NV12 and RGBA8 ↔ I420 are supported."
            )));
        },
    }

    // Metadata (timestamps etc.) is carried over unchanged from the input.
    VideoFrame::from_pooled(
        frame.width,
        frame.height,
        target_format,
        out_data,
        frame.metadata.clone(),
    )
}

// ── Registration ──────────────────────────────────────────────────────────

use schemars::schema_for;
use streamkit_core::registry::StaticPins;

/// Registers `video::pixel_convert` with the node registry, supplying its
/// JSON config schema, static pin layout, tags, and UI description.
#[allow(clippy::expect_used, clippy::missing_panics_doc)]
pub fn register_pixel_convert_nodes(registry: &mut NodeRegistry) {
    // A default-config instance is constructed only to read the static pins.
    let default_node = PixelConvertNode::new(&PixelConvertConfig::default())
        .expect("default PixelConvertConfig should be valid");
    registry.register_static_with_description(
        "video::pixel_convert",
        |params| {
            let config: PixelConvertConfig = config_helpers::parse_config_optional(params)?;
            Ok(Box::new(PixelConvertNode::new(&config)?))
        },
        serde_json::to_value(schema_for!(PixelConvertConfig))
            .expect("PixelConvertConfig schema should serialize to JSON"),
        StaticPins { inputs: default_node.input_pins(), outputs: default_node.output_pins() },
        vec!["video".to_string(), "convert".to_string()],
        false,
        "Converts raw video frames between pixel formats (RGBA8, NV12, I420). \
         Insert upstream of nodes that require a specific format (e.g. VP9 encoder). \
         Passthrough when input format already matches the target.",
    );
}

// ── Tests ─────────────────────────────────────────────────────────────────

#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::test_utils::{
        assert_state_initializing, assert_state_running, assert_state_stopped, create_test_context,
        create_test_video_frame,
    };
    use std::collections::HashMap;
    use tokio::sync::mpsc;

    // When the incoming frame is already in the target format, the node must
    // forward the frame without copying it (same Arc pointer on output).
    #[tokio::test]
    async fn test_passthrough_same_format() {
        let (input_tx, input_rx) = mpsc::channel(10);
        let mut inputs = HashMap::new();
        inputs.insert("in".to_string(), input_rx);

        let (context, mock_sender, mut state_rx) = create_test_context(inputs, 10);

        // Target is NV12, send NV12 frames — should passthrough.
        let node = PixelConvertNode::new(&PixelConvertConfig { output_format: "nv12".to_string() })
            .unwrap();

        let node_handle = tokio::spawn(async move { Box::new(node).run(context).await });

        assert_state_initializing(&mut state_rx).await;
        assert_state_running(&mut state_rx).await;

        let frame = create_test_video_frame(64, 64, PixelFormat::Nv12, 128);
        let original_data_ptr = Arc::as_ptr(&frame.data) as usize;
        input_tx.send(Packet::Video(frame)).await.unwrap();

        // Closing the input channel drives the node through its drain path
        // and into the stopped state.
        drop(input_tx);
        assert_state_stopped(&mut state_rx).await;
        node_handle.await.unwrap().unwrap();

        let output_packets = mock_sender.get_packets_for_pin("out").await;
        assert_eq!(output_packets.len(), 1, "Expected 1 output packet");

        if let Packet::Video(out_frame) = &output_packets[0] {
            assert_eq!(out_frame.pixel_format, PixelFormat::Nv12);
            // Verify the Arc pointer is identical (zero-copy passthrough).
            let out_data_ptr = Arc::as_ptr(&out_frame.data) as usize;
            assert_eq!(
                original_data_ptr, out_data_ptr,
                "Passthrough should preserve the same Arc pointer"
            );
        } else {
            panic!("Expected Video packet");
        }
    }

    // Exercises the "fast path 2" cache in the run loop: two frames sharing
    // one Arc must produce two outputs sharing one (converted) Arc.
    #[tokio::test]
    async fn test_identical_frame_caching() {
        let (input_tx, input_rx) = mpsc::channel(10);
        let mut inputs = HashMap::new();
        inputs.insert("in".to_string(), input_rx);

        let (context, mock_sender, mut state_rx) = create_test_context(inputs, 10);

        let node = PixelConvertNode::new(&PixelConvertConfig { output_format: "nv12".to_string() })
            .unwrap();

        let node_handle = tokio::spawn(async move { Box::new(node).run(context).await });

        assert_state_initializing(&mut state_rx).await;
        assert_state_running(&mut state_rx).await;

        // Send the same RGBA8 frame twice (clone preserves the Arc).
        let frame = create_test_video_frame(64, 64, PixelFormat::Rgba8, 128);
        let frame_clone = frame.clone();
        assert_eq!(
            Arc::as_ptr(&frame.data) as usize,
            Arc::as_ptr(&frame_clone.data) as usize,
            "Cloned frames should share the same Arc"
        );

        input_tx.send(Packet::Video(frame)).await.unwrap();

        // Small delay to ensure the first frame is processed before sending
        // the second, so the caching logic can compare Arc pointers.
        // NOTE(review): timing-based; could flake under heavy CI load.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        input_tx.send(Packet::Video(frame_clone)).await.unwrap();

        drop(input_tx);
        assert_state_stopped(&mut state_rx).await;
        node_handle.await.unwrap().unwrap();

        let output_packets = mock_sender.get_packets_for_pin("out").await;
        assert_eq!(output_packets.len(), 2, "Expected 2 output packets");

        // Both outputs should be NV12.
        for pkt in &output_packets {
            if let Packet::Video(f) = pkt {
                assert_eq!(f.pixel_format, PixelFormat::Nv12);
            } else {
                panic!("Expected Video packet");
            }
        }

        // The second output should reuse the cached Arc (same pointer).
        if let (Packet::Video(f1), Packet::Video(f2)) = (&output_packets[0], &output_packets[1]) {
            let ptr1 = Arc::as_ptr(&f1.data) as usize;
            let ptr2 = Arc::as_ptr(&f2.data) as usize;
            assert_eq!(ptr1, ptr2, "Second frame should reuse cached Arc from first conversion");
        }
    }

    // End-to-end slow-path conversion through the node: RGBA8 in, NV12 out
    // with dimensions preserved.
    #[tokio::test]
    async fn test_rgba8_to_nv12_conversion() {
        let (input_tx, input_rx) = mpsc::channel(10);
        let mut inputs = HashMap::new();
        inputs.insert("in".to_string(), input_rx);

        let (context, mock_sender, mut state_rx) = create_test_context(inputs, 10);

        let node = PixelConvertNode::new(&PixelConvertConfig { output_format: "nv12".to_string() })
            .unwrap();

        let node_handle = tokio::spawn(async move { Box::new(node).run(context).await });

        assert_state_initializing(&mut state_rx).await;
        assert_state_running(&mut state_rx).await;

        let frame = create_test_video_frame(64, 64, PixelFormat::Rgba8, 128);
        input_tx.send(Packet::Video(frame)).await.unwrap();

        drop(input_tx);
        assert_state_stopped(&mut state_rx).await;
        node_handle.await.unwrap().unwrap();

        let output_packets = mock_sender.get_packets_for_pin("out").await;
        assert_eq!(output_packets.len(), 1);

        if let Packet::Video(out_frame) = &output_packets[0] {
            assert_eq!(out_frame.pixel_format, PixelFormat::Nv12);
            assert_eq!(out_frame.width, 64);
            assert_eq!(out_frame.height, 64);
            assert!(!out_frame.data().is_empty());
        } else {
            panic!("Expected Video packet");
        }
    }

    #[tokio::test]
    async fn test_rgba8_to_nv12_roundtrip() {
        // Convert RGBA8 → NV12 via the raw buf functions and then
        // NV12 → RGBA8, verifying pixel values are within tolerance.
        let width: u32 = 64;
        let height: u32 = 64;
        let w = width as usize;
        let h = height as usize;

        // Create a test RGBA8 buffer with known values.
        let mut rgba = vec![0u8; w * h * 4];
        for y in 0..h {
            for x in 0..w {
                let off = (y * w + x) * 4;
                rgba[off] = 200; // R
                rgba[off + 1] = 100; // G
                rgba[off + 2] = 50; // B
                rgba[off + 3] = 255; // A
            }
        }

        // RGBA8 → NV12
        let chroma_w = w.div_ceil(2);
        let chroma_h = h.div_ceil(2);
        let nv12_size = w * h + chroma_w * 2 * chroma_h;
        let mut nv12 = vec![0u8; nv12_size];
        rgba8_to_nv12_buf(&rgba, width, height, &mut nv12);

        // NV12 → RGBA8
        let mut decoded = vec![0u8; w * h * 4];
        nv12_to_rgba8_buf(&nv12, width, height, &mut decoded);

        // Verify pixel values are within tolerance (YUV roundtrip has some loss).
        // Tolerance of 3 absorbs 8-bit quantisation in both directions.
        for y in 0..h {
            for x in 0..w {
                let off = (y * w + x) * 4;
                let dr = (decoded[off] as i32 - 200).unsigned_abs();
                let dg = (decoded[off + 1] as i32 - 100).unsigned_abs();
                let db = (decoded[off + 2] as i32 - 50).unsigned_abs();
                assert!(dr <= 3, "R channel diff too large at ({x},{y}): {dr}");
                assert!(dg <= 3, "G channel diff too large at ({x},{y}): {dg}");
                assert!(db <= 3, "B channel diff too large at ({x},{y}): {db}");
            }
        }
    }

    // NV12 → I420 has no implementation; the per-frame error must be logged
    // and swallowed, producing no output and no node failure.
    #[tokio::test]
    async fn test_unsupported_conversion_pair() {
        let (input_tx, input_rx) = mpsc::channel(10);
        let mut inputs = HashMap::new();
        inputs.insert("in".to_string(), input_rx);

        let (context, mock_sender, mut state_rx) = create_test_context(inputs, 10);

        // Target is I420, but we'll send NV12 — unsupported pair.
        let node = PixelConvertNode::new(&PixelConvertConfig { output_format: "i420".to_string() })
            .unwrap();

        let node_handle = tokio::spawn(async move { Box::new(node).run(context).await });

        assert_state_initializing(&mut state_rx).await;
        assert_state_running(&mut state_rx).await;

        let frame = create_test_video_frame(64, 64, PixelFormat::Nv12, 128);
        input_tx.send(Packet::Video(frame)).await.unwrap();

        drop(input_tx);
        assert_state_stopped(&mut state_rx).await;
        node_handle.await.unwrap().unwrap();

        // The unsupported conversion should produce no output (error logged).
        let output_packets = mock_sender.get_packets_for_pin("out").await;
        assert_eq!(output_packets.len(), 0, "Unsupported conversion should produce no output");
    }

    // Config validation: an unknown output_format string is rejected at
    // construction time, before the node ever runs.
    #[test]
    fn test_invalid_output_format() {
        let result =
            PixelConvertNode::new(&PixelConvertConfig { output_format: "yuv444".to_string() });
        assert!(result.is_err());
    }
}
diff --git a/crates/nodes/src/video/vp9.rs b/crates/nodes/src/video/vp9.rs
new file mode 100644
index 00000000..f35677e7
--- /dev/null
+++ b/crates/nodes/src/video/vp9.rs
@@ -0,0 +1,1369 @@
// SPDX-FileCopyrightText: © 2025 StreamKit Contributors
//
// SPDX-License-Identifier: MPL-2.0

//! VP9 video codec nodes (CPU).
+ +use async_trait::async_trait; +use bytes::Bytes; +use opentelemetry::{global, KeyValue}; +use schemars::JsonSchema; +use serde::Deserialize; +use std::borrow::Cow; +use std::ffi::CStr; +use std::sync::Arc; +use std::time::Instant; +use streamkit_core::stats::NodeStatsTracker; +use streamkit_core::types::{ + EncodedVideoFormat, Packet, PacketMetadata, PacketType, PixelFormat, VideoCodec, VideoFormat, + VideoFrame, VideoLayout, +}; +use streamkit_core::{ + config_helpers, get_codec_channel_capacity, packet_helpers, state_helpers, InputPin, + NodeContext, NodeRegistry, OutputPin, PinCardinality, PooledVideoData, ProcessorNode, + StreamKitError, VideoFramePool, +}; +use tokio::sync::mpsc; +use vpx::vp8e_enc_control_id::{VP8E_SET_CPUUSED, VP8E_SET_ENABLEAUTOALTREF}; +use vpx::vpx_codec_cx_pkt_kind::VPX_CODEC_CX_FRAME_PKT; +use vpx::vpx_img_fmt::{VPX_IMG_FMT_I420, VPX_IMG_FMT_NV12}; +use vpx::vpx_kf_mode::VPX_KF_AUTO; +use vpx_sys as vpx; + +const VP9_TIMEBASE_DEN: i32 = 1_000_000; +const VP9_CONTENT_TYPE: &str = "video/vp9"; + +// libvpx ABI values are macros in vpx headers; libvpx-sys doesn't expose them. +// Values are derived from /usr/include/vpx headers (VPX_IMAGE/VPX_CODEC/VPX_ENCODER ABI). 
+const VPX_IMAGE_ABI_VERSION: i32 = 5; +const VPX_CODEC_ABI_VERSION: i32 = 4 + VPX_IMAGE_ABI_VERSION; +const VPX_DECODER_ABI_VERSION: i32 = 3 + VPX_CODEC_ABI_VERSION; +const VPX_EXT_RATECTRL_ABI_VERSION: i32 = 1; +const VPX_ENCODER_ABI_VERSION: i32 = 15 + VPX_CODEC_ABI_VERSION + VPX_EXT_RATECTRL_ABI_VERSION; + +const VPX_EFLAG_FORCE_KF: vpx::vpx_enc_frame_flags_t = 1; +const VPX_FRAME_IS_KEY: u32 = 0x1; +const VPX_DL_BEST_QUALITY: u64 = 0; +const VPX_DL_GOOD_QUALITY: u64 = 1_000_000; +const VPX_DL_REALTIME: u64 = 1; +const VPX_CODEC_CAP_ENCODER: u32 = 0x2; + +const VP9_DEFAULT_BITRATE_KBPS: u32 = 2500; +const VP9_DEFAULT_KF_INTERVAL: u32 = 120; +const VP9_DEFAULT_THREADS: u32 = 2; + +#[derive(Deserialize, Debug, JsonSchema, Clone)] +#[serde(default)] +pub struct Vp9DecoderConfig { + pub threads: u32, +} + +impl Default for Vp9DecoderConfig { + fn default() -> Self { + Self { threads: VP9_DEFAULT_THREADS } + } +} + +/// Controls the CPU time the VP9 encoder is allowed to spend per frame. +/// +/// Maps to the libvpx `deadline` parameter in `vpx_codec_encode`. +#[derive(Deserialize, Debug, Clone, Copy, PartialEq, Eq, JsonSchema)] +#[serde(rename_all = "snake_case")] +#[derive(Default)] +pub enum Vp9EncoderDeadline { + /// Real-time encoding – lowest latency, may sacrifice quality (VPX_DL_REALTIME). + #[default] + Realtime, + /// Good quality – allows up to ~1 second per frame (VPX_DL_GOOD_QUALITY). + GoodQuality, + /// Best quality – unlimited time per frame (VPX_DL_BEST_QUALITY). 
+ BestQuality, +} + +impl Vp9EncoderDeadline { + const fn as_vpx_deadline(self) -> u64 { + match self { + Self::Realtime => VPX_DL_REALTIME, + Self::GoodQuality => VPX_DL_GOOD_QUALITY, + Self::BestQuality => VPX_DL_BEST_QUALITY, + } + } +} + +#[derive(Deserialize, Debug, JsonSchema, Clone)] +#[serde(default)] +pub struct Vp9EncoderConfig { + pub bitrate_kbps: u32, + pub keyframe_interval: u32, + pub threads: u32, + pub deadline: Vp9EncoderDeadline, +} + +impl Default for Vp9EncoderConfig { + fn default() -> Self { + Self { + bitrate_kbps: VP9_DEFAULT_BITRATE_KBPS, + keyframe_interval: VP9_DEFAULT_KF_INTERVAL, + threads: VP9_DEFAULT_THREADS, + deadline: Vp9EncoderDeadline::default(), + } + } +} + +pub struct Vp9DecoderNode { + config: Vp9DecoderConfig, +} + +impl Vp9DecoderNode { + #[allow(clippy::missing_errors_doc)] + pub const fn new(config: Vp9DecoderConfig) -> Result { + Ok(Self { config }) + } +} + +#[async_trait] +impl ProcessorNode for Vp9DecoderNode { + fn input_pins(&self) -> Vec { + vec![InputPin { + name: "in".to_string(), + accepts_types: vec![PacketType::EncodedVideo(EncodedVideoFormat { + codec: VideoCodec::Vp9, + bitstream_format: None, + codec_private: None, + profile: None, + level: None, + })], + cardinality: PinCardinality::One, + }] + } + + fn output_pins(&self) -> Vec { + vec![OutputPin { + name: "out".to_string(), + produces_type: PacketType::RawVideo(VideoFormat { + width: None, + height: None, + pixel_format: PixelFormat::Nv12, + }), + cardinality: PinCardinality::Broadcast, + }] + } + + async fn run(self: Box, mut context: NodeContext) -> Result<(), StreamKitError> { + let node_name = context.output_sender.node_name().to_string(); + state_helpers::emit_initializing(&context.state_tx, &node_name); + + tracing::info!("Vp9DecoderNode starting"); + let mut input_rx = context.take_input("in")?; + let video_pool = context.video_pool.clone(); + + let meter = global::meter("skit_nodes"); + let packets_processed_counter = 
meter.u64_counter("vp9_packets_processed").build(); + let decode_duration_histogram = meter + .f64_histogram("vp9_decode_duration") + .with_boundaries(streamkit_core::metrics::HISTOGRAM_BOUNDARIES_CODEC_PACKET.to_vec()) + .build(); + + let (decode_tx, mut decode_rx) = + mpsc::channel::<(Bytes, Option)>(get_codec_channel_capacity()); + let (result_tx, mut result_rx) = + mpsc::channel::>(get_codec_channel_capacity()); + + let decoder_threads = self.config.threads; + let decode_task = tokio::task::spawn_blocking(move || { + let mut decoder = match Vp9Decoder::new(decoder_threads) { + Ok(decoder) => decoder, + Err(err) => { + let _ = result_tx.blocking_send(Err(err)); + return; + }, + }; + + while let Some((data, metadata)) = decode_rx.blocking_recv() { + let decode_start_time = Instant::now(); + let result = decoder.decode_packet(&data, metadata, video_pool.as_ref()); + decode_duration_histogram.record(decode_start_time.elapsed().as_secs_f64(), &[]); + + match result { + Ok(frames) => { + for frame in frames { + if result_tx.blocking_send(Ok(frame)).is_err() { + return; + } + } + }, + Err(err) => { + let _ = result_tx.blocking_send(Err(err)); + }, + } + } + }); + + state_helpers::emit_running(&context.state_tx, &node_name); + + let mut stats_tracker = NodeStatsTracker::new(node_name.clone(), context.stats_tx.clone()); + let batch_size = context.batch_size; + + let decode_tx_clone = decode_tx.clone(); + let mut input_task = tokio::spawn(async move { + loop { + let Some(first_packet) = input_rx.recv().await else { + break; + }; + + let packet_batch = + packet_helpers::batch_packets_greedy(first_packet, &mut input_rx, batch_size); + + for packet in packet_batch { + if let Packet::Binary { data, metadata, .. 
} = packet { + if decode_tx_clone.send((data, metadata)).await.is_err() { + tracing::error!( + "Vp9DecoderNode decode task has shut down unexpectedly" + ); + return; + } + } + } + } + tracing::info!("Vp9DecoderNode input stream closed"); + }); + + codec_forward_loop( + &mut context, + &mut result_rx, + &mut input_task, + decode_task, + decode_tx, + &packets_processed_counter, + &mut stats_tracker, + Packet::Video, + "Vp9DecoderNode", + ) + .await; + + state_helpers::emit_stopped(&context.state_tx, &node_name, "input_closed"); + tracing::info!("Vp9DecoderNode finished"); + Ok(()) + } +} + +pub struct Vp9EncoderNode { + config: Vp9EncoderConfig, +} + +impl Vp9EncoderNode { + #[allow(clippy::missing_errors_doc)] + pub const fn new(config: Vp9EncoderConfig) -> Result { + Ok(Self { config }) + } +} + +#[async_trait] +impl ProcessorNode for Vp9EncoderNode { + fn input_pins(&self) -> Vec { + vec![InputPin { + name: "in".to_string(), + accepts_types: vec![ + PacketType::RawVideo(VideoFormat { + width: None, + height: None, + pixel_format: PixelFormat::I420, + }), + PacketType::RawVideo(VideoFormat { + width: None, + height: None, + pixel_format: PixelFormat::Nv12, + }), + ], + cardinality: PinCardinality::One, + }] + } + + fn output_pins(&self) -> Vec { + vec![OutputPin { + name: "out".to_string(), + produces_type: PacketType::EncodedVideo(EncodedVideoFormat { + codec: VideoCodec::Vp9, + bitstream_format: None, + codec_private: None, + profile: None, + level: None, + }), + cardinality: PinCardinality::Broadcast, + }] + } + + fn content_type(&self) -> Option { + Some(VP9_CONTENT_TYPE.to_string()) + } + + async fn run(self: Box, mut context: NodeContext) -> Result<(), StreamKitError> { + let node_name = context.output_sender.node_name().to_string(); + state_helpers::emit_initializing(&context.state_tx, &node_name); + + tracing::info!("Vp9EncoderNode starting"); + let mut input_rx = context.take_input("in")?; + + let meter = global::meter("skit_nodes"); + let 
packets_processed_counter = meter.u64_counter("vp9_packets_processed").build(); + let encode_duration_histogram = meter + .f64_histogram("vp9_encode_duration") + .with_boundaries(streamkit_core::metrics::HISTOGRAM_BOUNDARIES_CODEC_PACKET.to_vec()) + .build(); + + let (encode_tx, mut encode_rx) = + mpsc::channel::<(VideoFrame, Option)>(get_codec_channel_capacity()); + let (result_tx, mut result_rx) = + mpsc::channel::>(get_codec_channel_capacity()); + + let encoder_config = self.config; + let encode_task = tokio::task::spawn_blocking(move || { + let mut encoder: Option = None; + let mut current_dimensions: Option<(u32, u32)> = None; + + while let Some((frame, metadata)) = encode_rx.blocking_recv() { + if frame.pixel_format == PixelFormat::Rgba8 { + let _ = + result_tx.blocking_send(Err("VP9 encoder requires NV12 or I420 input; \ + insert a video::pixel_convert node upstream" + .to_string())); + continue; + } + let encode_frame = frame; + + let frame_dimensions = (encode_frame.width, encode_frame.height); + if current_dimensions != Some(frame_dimensions) { + match Vp9Encoder::new(encode_frame.width, encode_frame.height, &encoder_config) + { + Ok(new_encoder) => { + encoder = Some(new_encoder); + current_dimensions = Some(frame_dimensions); + }, + Err(err) => { + let _ = result_tx.blocking_send(Err(err)); + continue; + }, + } + } + + let Some(encoder) = encoder.as_mut() else { + let _ = result_tx.blocking_send(Err("VP9 encoder not initialized".to_string())); + continue; + }; + + let encode_start_time = Instant::now(); + let result = encoder.encode_frame(&encode_frame, metadata); + encode_duration_histogram.record(encode_start_time.elapsed().as_secs_f64(), &[]); + + match result { + Ok(packets) => { + for packet in packets { + if result_tx.blocking_send(Ok(packet)).is_err() { + return; + } + } + }, + Err(err) => { + let _ = result_tx.blocking_send(Err(err)); + }, + } + } + + if let Some(encoder) = encoder.as_mut() { + match encoder.flush() { + Ok(packets) => { + for 
packet in packets { + if result_tx.blocking_send(Ok(packet)).is_err() { + return; + } + } + }, + Err(err) => { + let _ = result_tx.blocking_send(Err(err)); + }, + } + } + }); + + state_helpers::emit_running(&context.state_tx, &node_name); + + let mut stats_tracker = NodeStatsTracker::new(node_name.clone(), context.stats_tx.clone()); + let batch_size = context.batch_size; + + let encode_tx_clone = encode_tx.clone(); + let mut input_task = tokio::spawn(async move { + loop { + let Some(first_packet) = input_rx.recv().await else { + break; + }; + + let packet_batch = + packet_helpers::batch_packets_greedy(first_packet, &mut input_rx, batch_size); + + for packet in packet_batch { + if let Packet::Video(mut frame) = packet { + let metadata = frame.metadata.take(); + if encode_tx_clone.send((frame, metadata)).await.is_err() { + tracing::error!( + "Vp9EncoderNode encode task has shut down unexpectedly" + ); + return; + } + } + } + } + tracing::info!("Vp9EncoderNode input stream closed"); + }); + + codec_forward_loop( + &mut context, + &mut result_rx, + &mut input_task, + encode_task, + encode_tx, + &packets_processed_counter, + &mut stats_tracker, + |encoded| Packet::Binary { + data: encoded.data, + content_type: Some(Cow::Borrowed(VP9_CONTENT_TYPE)), + metadata: encoded.metadata, + }, + "Vp9EncoderNode", + ) + .await; + + state_helpers::emit_stopped(&context.state_tx, &node_name, "input_closed"); + tracing::info!("Vp9EncoderNode finished"); + Ok(()) + } +} + +/// Shared select-loop that forwards codec results to the output sender. +/// +/// Handles three concurrent events: +/// 1. Results arriving from the blocking codec task. +/// 2. Shutdown control messages. +/// 3. Input task completion (triggers drain of remaining results). +/// +/// `to_packet` converts a codec-specific result `T` into a [`Packet`]. 
#[allow(clippy::too_many_arguments)]
async fn codec_forward_loop<T, In>(
    context: &mut NodeContext,
    result_rx: &mut mpsc::Receiver<Result<T, String>>,
    input_task: &mut tokio::task::JoinHandle<()>,
    codec_task: tokio::task::JoinHandle<()>,
    codec_tx: mpsc::Sender<In>,
    counter: &opentelemetry::metrics::Counter<u64>,
    stats: &mut NodeStatsTracker,
    to_packet: impl Fn(T) -> Packet,
    label: &str,
) {
    /// Forwards a single successful codec result to the output sender.
    /// Returns `true` if the output channel is closed (caller should break).
    async fn forward_one(
        packet: Packet,
        context: &mut NodeContext,
        counter: &opentelemetry::metrics::Counter<u64>,
        stats: &mut NodeStatsTracker,
    ) -> bool {
        counter.add(1, &[KeyValue::new("status", "ok")]);
        stats.received();
        if context.output_sender.send("out", packet).await.is_err() {
            tracing::debug!("Output channel closed, stopping node");
            return true;
        }
        stats.sent();
        stats.maybe_send();
        false
    }

    /// Handles a codec error result by updating counters and logging.
    /// Errors are per-packet: the loop keeps running after one.
    fn handle_error(
        err: &str,
        counter: &opentelemetry::metrics::Counter<u64>,
        stats: &mut NodeStatsTracker,
        label: &str,
    ) {
        counter.add(1, &[KeyValue::new("status", "error")]);
        stats.received();
        stats.errored();
        stats.maybe_send();
        tracing::warn!("{label} codec error: {err}");
    }

    loop {
        tokio::select! {
            // Normal flow: results from the blocking codec task.
            maybe_result = result_rx.recv() => {
                match maybe_result {
                    Some(Ok(item)) => {
                        if forward_one(to_packet(item), context, counter, stats).await {
                            break;
                        }
                    }
                    Some(Err(err)) => handle_error(&err, counter, stats, label),
                    // Codec task dropped its sender: nothing more will come.
                    None => break,
                }
            }
            // Explicit shutdown request from the runtime.
            Some(control_msg) = context.control_rx.recv() => {
                if matches!(control_msg, streamkit_core::control::NodeControlMessage::Shutdown) {
                    tracing::info!("{label} received shutdown signal");
                    // NOTE: Aborting the input task and dropping codec_tx causes
                    // the codec thread to exit/flush, but because we break out
                    // here those flushed results are never sent downstream.
                    // Data loss on explicit shutdown is acceptable.
                    input_task.abort();
                    codec_task.abort();
                    drop(codec_tx);
                    break;
                }
            }
            // Input stream ended: close the codec's input so it flushes,
            // then drain every remaining result before exiting.
            _ = &mut *input_task => {
                drop(codec_tx);
                while let Some(maybe_result) = result_rx.recv().await {
                    match maybe_result {
                        Ok(item) => {
                            if forward_one(to_packet(item), context, counter, stats).await {
                                break;
                            }
                        }
                        Err(err) => handle_error(&err, counter, stats, label),
                    }
                }
                break;
            }
        }
    }

    // Idempotent: abort is a no-op if the task already finished; awaiting
    // reaps the JoinHandle so the blocking thread is not leaked.
    codec_task.abort();
    let _ = codec_task.await;
}

// A single encoded bitstream packet produced by the VP9 encoder, paired
// with the metadata of the source frame it originated from.
struct EncodedPacket {
    data: Bytes,
    metadata: Option<PacketMetadata>,
}

// Thin RAII wrapper around a libvpx VP9 decoder context; destroyed in Drop.
struct Vp9Decoder {
    ctx: vpx::vpx_codec_ctx_t,
}

impl Vp9Decoder {
    /// Initializes a libvpx VP9 decoder with the given thread count.
    /// Width/height are left at 0 so libvpx reads them from the bitstream.
    fn new(threads: u32) -> Result<Self, String> {
        let iface = unsafe {
            // SAFETY: libvpx returns a static codec interface pointer.
            vpx::vpx_codec_vp9_dx()
        };
        if iface.is_null() {
            return Err("VP9 decoder interface not available".to_string());
        }

        let mut ctx = unsafe {
            // SAFETY: vpx_codec_ctx_t is a plain C struct that can be zero-initialized.
            std::mem::zeroed()
        };
        let cfg = vpx::vpx_codec_dec_cfg_t { threads, w: 0, h: 0 };

        let res = unsafe {
            // SAFETY: ctx and cfg are valid and iface is non-null.
            vpx::vpx_codec_dec_init_ver(
                &raw mut ctx,
                iface,
                &raw const cfg,
                0,
                VPX_DECODER_ABI_VERSION,
            )
        };
        check_vpx(res, &raw mut ctx, "VP9 decoder init")?;

        Ok(Self { ctx })
    }

    /// Feeds one encoded packet to libvpx and collects every frame it emits.
    ///
    /// `metadata` is attached to the emitted frames: cloned while more
    /// frames follow, moved into the last one.
    fn decode_packet(
        &mut self,
        data: &[u8],
        metadata: Option<PacketMetadata>,
        video_pool: Option<&Arc<VideoFramePool>>,
    ) -> Result<Vec<VideoFrame>, String> {
        let data_len =
            u32::try_from(data.len()).map_err(|_| "VP9 packet too large for libvpx".to_string())?;
        let res = unsafe {
            // SAFETY: libvpx expects a valid buffer for the duration of the call.
            vpx::vpx_codec_decode(
                &raw mut self.ctx,
                data.as_ptr(),
                data_len,
                std::ptr::null_mut(),
                0,
            )
        };
        check_vpx(res, &raw mut self.ctx, "VP9 decode")?;

        // Most VP9 packets produce exactly one frame; pre-allocate for that
        // common case to avoid a heap allocation + realloc in the hot path.
+ let mut frames = Vec::with_capacity(1); + let mut iter: vpx::vpx_codec_iter_t = std::ptr::null_mut(); + let mut remaining_metadata = metadata; + + loop { + let image_ptr = unsafe { + // SAFETY: iter is managed by libvpx and image_ptr is valid until next call. + vpx::vpx_codec_get_frame(&raw mut self.ctx, &raw mut iter) + }; + if image_ptr.is_null() { + break; + } + + let image = unsafe { + // SAFETY: image_ptr is non-null and points to a valid vpx_image_t. + &*image_ptr + }; + + // Peek ahead: if another frame follows, clone metadata; otherwise move it. + let next_ptr = unsafe { + let mut peek_iter = iter; + vpx::vpx_codec_get_frame(&raw mut self.ctx, &raw mut peek_iter) + }; + let meta = if next_ptr.is_null() { + remaining_metadata.take() + } else { + remaining_metadata.clone() + }; + + let frame = copy_vpx_image(image, meta, video_pool)?; + frames.push(frame); + } + + Ok(frames) + } +} + +impl Drop for Vp9Decoder { + fn drop(&mut self) { + unsafe { + // SAFETY: ctx is initialized by libvpx and must be destroyed exactly once. + vpx::vpx_codec_destroy(&raw mut self.ctx); + } + } +} + +struct Vp9Encoder { + ctx: vpx::vpx_codec_ctx_t, + next_pts: i64, + deadline: u64, +} + +impl Vp9Encoder { + fn new(width: u32, height: u32, config: &Vp9EncoderConfig) -> Result { + let iface = unsafe { + // SAFETY: libvpx returns a static codec interface pointer. + vpx::vpx_codec_vp9_cx() + }; + if iface.is_null() { + return Err("VP9 encoder interface not available".to_string()); + } + let caps = unsafe { + // SAFETY: iface is non-null. + vpx::vpx_codec_get_caps(iface) + }; + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] + if (caps as u32 & VPX_CODEC_CAP_ENCODER) == 0 { + return Err("libvpx does not expose VP9 encoder capabilities".to_string()); + } + + let mut ctx = unsafe { + // SAFETY: vpx_codec_ctx_t is a plain C struct that can be zero-initialized. 
+ std::mem::zeroed() + }; + let mut cfg = std::mem::MaybeUninit::::uninit(); + let res = unsafe { + // SAFETY: cfg is valid for initialization and iface is non-null. + vpx::vpx_codec_enc_config_default(iface, cfg.as_mut_ptr(), 0) + }; + check_vpx(res, std::ptr::null_mut(), "VP9 encoder config")?; + let mut cfg = unsafe { + // SAFETY: vpx_codec_enc_config_default initializes cfg on success. + cfg.assume_init() + }; + + cfg.g_w = width; + cfg.g_h = height; + cfg.g_timebase.num = 1; + cfg.g_timebase.den = VP9_TIMEBASE_DEN; + cfg.rc_target_bitrate = config.bitrate_kbps.max(1); + cfg.g_threads = config.threads.max(1); + cfg.g_lag_in_frames = 0; + cfg.kf_mode = VPX_KF_AUTO; + cfg.kf_min_dist = 0; + cfg.kf_max_dist = config.keyframe_interval.max(1); + + let res = unsafe { + // SAFETY: ctx and cfg are valid and iface is non-null. + vpx::vpx_codec_enc_init_ver( + &raw mut ctx, + iface, + &raw const cfg, + 0, + VPX_ENCODER_ABI_VERSION, + ) + }; + if let Err(err) = check_vpx(res, &raw mut ctx, "VP9 encoder init") { + let cfg_summary = format!( + "w={width} h={height} timebase=1/{den} bitrate_kbps={} threads={} lag={} kf_max={}", + cfg.rc_target_bitrate, + cfg.g_threads, + cfg.g_lag_in_frames, + cfg.kf_max_dist, + den = cfg.g_timebase.den + ); + return Err(format!("{err} (cfg: {cfg_summary})")); + } + + unsafe { + // SAFETY: Control calls are valid after encoder initialization. 
+ set_codec_control(&raw mut ctx, VP8E_SET_ENABLEAUTOALTREF as i32, 0)?; + set_codec_control(&raw mut ctx, VP8E_SET_CPUUSED as i32, 6)?; + } + + Ok(Self { ctx, next_pts: 0, deadline: config.deadline.as_vpx_deadline() }) + } + + fn encode_frame( + &mut self, + frame: &VideoFrame, + metadata: Option, + ) -> Result, String> { + let vpx_fmt = match frame.pixel_format { + PixelFormat::I420 => VPX_IMG_FMT_I420, + PixelFormat::Nv12 => VPX_IMG_FMT_NV12, + other @ PixelFormat::Rgba8 => { + return Err(format!("VP9 encoder expects I420 or NV12 input, got {other:?}")); + }, + }; + + let layout = frame.layout(); + if frame.data_len() < layout.total_bytes() { + return Err(format!( + "VP9 encoder expected {} bytes, got {}", + layout.total_bytes(), + frame.data_len() + )); + } + let expected_layout = VideoLayout::aligned( + frame.width, + frame.height, + frame.pixel_format, + layout.stride_align(), + ); + if layout != expected_layout { + return Err(format!( + "VP9 encoder requires the canonical aligned {:?} layout", + frame.pixel_format + )); + } + + let mut image = std::mem::MaybeUninit::::uninit(); + let image_ptr = unsafe { + // SAFETY: frame data is valid for the duration of this call. + vpx::vpx_img_wrap( + image.as_mut_ptr(), + vpx_fmt, + frame.width, + frame.height, + layout.stride_align(), + frame.data.as_slice().as_ptr().cast_mut(), + ) + }; + if image_ptr.is_null() { + return Err(format!("Failed to wrap {:?} frame for VP9 encoder", frame.pixel_format)); + } + let image = unsafe { + // SAFETY: vpx_img_wrap initialized image on success. + image.assume_init() + }; + + let (pts, duration) = self.next_pts(metadata.as_ref()); + let mut flags: vpx::vpx_enc_frame_flags_t = 0; + if metadata.as_ref().and_then(|meta| meta.keyframe).unwrap_or(false) { + flags |= VPX_EFLAG_FORCE_KF; + } + + let res = unsafe { + // SAFETY: image is initialized and ctx is ready for encode. 
+ vpx::vpx_codec_encode( + &raw mut self.ctx, + &raw const image, + pts, + duration, + flags, + self.deadline, + ) + }; + check_vpx(res, &raw mut self.ctx, "VP9 encode")?; + + let packets = self.drain_packets(metadata); + + Ok(packets) + } + + fn flush(&mut self) -> Result, String> { + let mut output = Vec::new(); + for _ in 0..16 { + let res = unsafe { + // SAFETY: Passing a null image flushes delayed frames. + vpx::vpx_codec_encode(&raw mut self.ctx, std::ptr::null(), 0, 0, 0, self.deadline) + }; + check_vpx(res, &raw mut self.ctx, "VP9 encode flush")?; + + let mut packets = self.drain_packets(None); + if packets.is_empty() { + break; + } + output.append(&mut packets); + } + + Ok(output) + } + + fn drain_packets(&mut self, metadata: Option) -> Vec { + let mut packets = Vec::new(); + let mut iter: vpx::vpx_codec_iter_t = std::ptr::null_mut(); + let mut remaining_metadata = metadata; + loop { + let packet_ptr = unsafe { + // SAFETY: iter is managed by libvpx and packet_ptr is valid until next call. + vpx::vpx_codec_get_cx_data(&raw mut self.ctx, &raw mut iter) + }; + if packet_ptr.is_null() { + break; + } + + let packet = unsafe { + // SAFETY: packet_ptr is non-null and points to a valid vpx_codec_cx_pkt_t. + &*packet_ptr + }; + + if packet.kind != VPX_CODEC_CX_FRAME_PKT { + continue; + } + + let frame_pkt = unsafe { + // SAFETY: Union access for frame packet data. + packet.data.frame + }; + + let data: Bytes = unsafe { + // SAFETY: frame_pkt.buf is valid for frame_pkt.sz bytes. + // Copy into Bytes directly so the downstream Packet::Binary + // doesn't need a second Vec → Bytes conversion. + #[allow(clippy::cast_possible_truncation)] + Bytes::copy_from_slice(std::slice::from_raw_parts( + frame_pkt.buf as *const u8, + frame_pkt.sz as usize, + )) + }; + + let is_keyframe = (frame_pkt.flags as u32 & VPX_FRAME_IS_KEY) != 0; + + // Peek ahead: if another frame packet follows, clone metadata; otherwise move it. 
+ let next_ptr = unsafe { + let mut peek_iter = iter; + vpx::vpx_codec_get_cx_data(&raw mut self.ctx, &raw mut peek_iter) + }; + let meta = if next_ptr.is_null() { + remaining_metadata.take() + } else { + remaining_metadata.clone() + }; + + let output_metadata = merge_keyframe_metadata( + meta, + is_keyframe, + frame_pkt.pts, + frame_pkt.duration as u64, + ); + + packets.push(EncodedPacket { data, metadata: Some(output_metadata) }); + } + + packets + } + + fn next_pts(&mut self, metadata: Option<&PacketMetadata>) -> (i64, u64) { + // Default to 1µs rather than 0 so libvpx rate-control heuristics + // always see a non-zero duration. The PTS advance fallback already + // uses `pts + 1`, so this keeps the two paths consistent. + let duration = metadata.and_then(|meta| meta.duration_us).unwrap_or(1); + + let pts = + metadata.and_then(|meta| meta.timestamp_us).map_or(self.next_pts, u64::cast_signed); + + self.next_pts = if duration > 0 { pts + duration.cast_signed() } else { pts + 1 }; + (pts, duration) + } +} + +impl Drop for Vp9Encoder { + fn drop(&mut self) { + unsafe { + // SAFETY: ctx is initialized by libvpx and must be destroyed exactly once. + vpx::vpx_codec_destroy(&raw mut self.ctx); + } + } +} + +fn check_vpx( + res: vpx::vpx_codec_err_t, + ctx: *mut vpx::vpx_codec_ctx_t, + context: &str, +) -> Result<(), String> { + if res == vpx::VPX_CODEC_OK { + return Ok(()); + } + + let err = vpx_error(ctx, res); + let detail = if ctx.is_null() { + None + } else { + let detail_ptr = unsafe { + // SAFETY: libvpx returns a NUL-terminated error detail string. + vpx::vpx_codec_error_detail(ctx) + }; + if detail_ptr.is_null() { + None + } else { + Some(unsafe { + // SAFETY: detail_ptr is a valid C string. 
+ CStr::from_ptr(detail_ptr).to_string_lossy().into_owned() + }) + } + }; + + detail.map_or_else( + || Err(format!("{context}: {err}")), + |detail| Err(format!("{context}: {err} ({detail})")), + ) +} + +unsafe fn set_codec_control( + ctx: *mut vpx::vpx_codec_ctx_t, + ctrl_id: i32, + value: i32, +) -> Result<(), String> { + let res = vpx::vpx_codec_control_(ctx, ctrl_id, value); + check_vpx(res, ctx, "VP9 codec control") +} + +fn vpx_error(ctx: *mut vpx::vpx_codec_ctx_t, err: vpx::vpx_codec_err_t) -> String { + unsafe { + // SAFETY: libvpx returns a NUL-terminated error string. + let msg_ptr = if ctx.is_null() { + vpx::vpx_codec_err_to_string(err) + } else { + vpx::vpx_codec_error(ctx) + }; + if msg_ptr.is_null() { + "libvpx error".to_string() + } else { + CStr::from_ptr(msg_ptr).to_string_lossy().into_owned() + } + } +} + +/// Copy a decoded I420 vpx_image into an NV12 `VideoFrame`. +/// +/// libvpx always decodes VP9 to I420 (three separate Y, U, V planes). +/// We convert to NV12 on the fly by copying the Y plane as-is and +/// interleaving the U and V planes into a single UV plane. +/// This is a cheap operation — just zipping two half-size planes. +fn copy_vpx_image( + image: &vpx::vpx_image_t, + metadata: Option, + video_pool: Option<&Arc>, +) -> Result { + if image.fmt != VPX_IMG_FMT_I420 { + return Err("VP9 decoder produced non-I420 frame".to_string()); + } + + let width = image.d_w; + let height = image.d_h; + if width == 0 || height == 0 { + return Err("VP9 decoder produced empty frame".to_string()); + } + + // Output layout is NV12 (Y + interleaved UV). 
+ let nv12_layout = VideoLayout::packed(width, height, PixelFormat::Nv12); + let mut data = video_pool.map_or_else( + || PooledVideoData::from_vec(vec![0u8; nv12_layout.total_bytes()]), + |pool| pool.get(nv12_layout.total_bytes()), + ); + let data_slice = data.as_mut_slice(); + + let nv12_planes = nv12_layout.planes(); + let y_plane = nv12_planes[0]; + let uv_plane = nv12_planes[1]; + + // ── Copy Y plane (plane 0) — identical for I420 and NV12 ── + let y_src_ptr = image.planes[0]; + if y_src_ptr.is_null() { + return Err("VP9 decoder returned null Y plane".to_string()); + } + copy_plane( + &mut data_slice[y_plane.offset..y_plane.offset + y_plane.stride * y_plane.height as usize], + y_plane.stride, + y_src_ptr, + image.stride[0], + width as usize, + height as usize, + )?; + + // ── Interleave U + V into NV12's single UV plane ── + let u_src_ptr = image.planes[1]; + let v_src_ptr = image.planes[2]; + if u_src_ptr.is_null() || v_src_ptr.is_null() { + return Err("VP9 decoder returned null chroma plane".to_string()); + } + + let chroma_w = (width as usize).div_ceil(2); + let chroma_h = uv_plane.height as usize; + + if image.stride[1] <= 0 || image.stride[2] <= 0 { + return Err("Invalid source stride for VP9 chroma plane".to_string()); + } + + #[allow(clippy::cast_sign_loss)] + let u_src_stride = image.stride[1] as usize; + #[allow(clippy::cast_sign_loss)] + let v_src_stride = image.stride[2] as usize; + + for row in 0..chroma_h { + let u_row = unsafe { + // SAFETY: u_src_ptr is valid with u_src_stride bytes per row. + std::slice::from_raw_parts(u_src_ptr.add(row * u_src_stride), chroma_w) + }; + let v_row = unsafe { + // SAFETY: v_src_ptr is valid with v_src_stride bytes per row. 
+ std::slice::from_raw_parts(v_src_ptr.add(row * v_src_stride), chroma_w) + }; + let dst_start = uv_plane.offset + row * uv_plane.stride; + for col in 0..chroma_w { + data_slice[dst_start + col * 2] = u_row[col]; + data_slice[dst_start + col * 2 + 1] = v_row[col]; + } + } + + VideoFrame::from_pooled(width, height, PixelFormat::Nv12, data, metadata) + .map_err(|e| e.to_string()) +} + +fn copy_plane( + dst: &mut [u8], + dst_stride: usize, + src_ptr: *const u8, + src_stride: i32, + width: usize, + height: usize, +) -> Result<(), String> { + if src_stride <= 0 { + return Err("Invalid source stride for VP9 plane".to_string()); + } + #[allow(clippy::cast_sign_loss)] + let src_stride = src_stride as usize; + + for row in 0..height { + let src_row = unsafe { + // SAFETY: src_ptr points to a valid plane with src_stride bytes per row. + std::slice::from_raw_parts(src_ptr.add(row * src_stride), width) + }; + let dst_start = row * dst_stride; + let dst_end = dst_start + width; + if dst_end > dst.len() { + return Err("VP9 plane copy overflow".to_string()); + } + dst[dst_start..dst_end].copy_from_slice(src_row); + } + + Ok(()) +} + +const fn merge_keyframe_metadata( + metadata: Option, + keyframe: bool, + pts: i64, + duration: u64, +) -> PacketMetadata { + match metadata { + Some(mut meta) => { + meta.keyframe = Some(keyframe); + meta + }, + None => PacketMetadata { + timestamp_us: if pts >= 0 { Some(pts.cast_unsigned()) } else { None }, + duration_us: if duration > 0 { Some(duration) } else { None }, + sequence: None, + keyframe: Some(keyframe), + }, + } +} + +#[cfg(test)] +fn vp9_encoder_available() -> bool { + let iface = unsafe { vpx::vpx_codec_vp9_cx() }; + if iface.is_null() { + return false; + } + let caps = unsafe { vpx::vpx_codec_get_caps(iface) }; + u32::try_from(caps).is_ok_and(|caps_u32| (caps_u32 & VPX_CODEC_CAP_ENCODER) != 0) +} + +use schemars::schema_for; +use streamkit_core::registry::StaticPins; + +#[allow(clippy::expect_used, clippy::missing_panics_doc)] +pub 
fn register_vp9_nodes(registry: &mut NodeRegistry) { + let default_decoder = Vp9DecoderNode::new(Vp9DecoderConfig::default()) + .expect("default VP9 decoder config should be valid"); + registry.register_static_with_description( + "video::vp9::decoder", + |params| { + let config = config_helpers::parse_config_optional(params)?; + Ok(Box::new(Vp9DecoderNode::new(config)?)) + }, + serde_json::to_value(schema_for!(Vp9DecoderConfig)) + .expect("Vp9DecoderConfig schema should serialize to JSON"), + StaticPins { inputs: default_decoder.input_pins(), outputs: default_decoder.output_pins() }, + vec!["video".to_string(), "codecs".to_string(), "vp9".to_string()], + false, + "Decodes VP9-compressed packets into raw NV12 video frames. \ + Use this before CPU compositing or analysis pipelines.", + ); + + let default_encoder = Vp9EncoderNode::new(Vp9EncoderConfig::default()) + .expect("default VP9 encoder config should be valid"); + registry.register_static_with_description( + "video::vp9::encoder", + |params| { + let config = config_helpers::parse_config_optional(params)?; + Ok(Box::new(Vp9EncoderNode::new(config)?)) + }, + serde_json::to_value(schema_for!(Vp9EncoderConfig)) + .expect("Vp9EncoderConfig schema should serialize to JSON"), + StaticPins { inputs: default_encoder.input_pins(), outputs: default_encoder.output_pins() }, + vec!["video".to_string(), "codecs".to_string(), "vp9".to_string()], + false, + "Encodes raw video frames (NV12 or I420) into VP9 packets for transport or container muxing. 
\ + Insert a video::pixel_convert node upstream if the source outputs RGBA8.", + ); +} + +#[cfg(test)] +#[allow(clippy::unwrap_used, clippy::expect_used, clippy::disallowed_macros)] +mod tests { + use super::*; + use crate::test_utils::{ + assert_state_initializing, assert_state_running, assert_state_stopped, create_test_context, + create_test_video_frame, + }; + use std::collections::{HashMap, HashSet}; + use std::ffi::CStr; + use std::os::raw::c_char; + use tokio::sync::mpsc; + + fn vpx_string(ptr: *const c_char) -> String { + if ptr.is_null() { + return "null".to_string(); + } + unsafe { + // SAFETY: libvpx returns NUL-terminated C strings. + CStr::from_ptr(ptr).to_string_lossy().into_owned() + } + } + + fn dump_vpx_info() { + let version = unsafe { + // SAFETY: libvpx returns static string pointers. + vpx_string(vpx::vpx_codec_version_str()) + }; + let extra = unsafe { vpx_string(vpx::vpx_codec_version_extra_str()) }; + let build = unsafe { vpx_string(vpx::vpx_codec_build_config()) }; + eprintln!("libvpx version: {version} {extra}"); + eprintln!("libvpx build config: {build}"); + } + + #[tokio::test] + async fn test_vp9_encode_decode_roundtrip() { + dump_vpx_info(); + if !vp9_encoder_available() { + eprintln!("Skipping VP9 encode/decode roundtrip: encoder not available in libvpx"); + return; + } + + let (enc_input_tx, enc_input_rx) = mpsc::channel(10); + let mut enc_inputs = HashMap::new(); + enc_inputs.insert("in".to_string(), enc_input_rx); + + let (enc_context, enc_sender, mut enc_state_rx) = create_test_context(enc_inputs, 10); + let encoder_config = Vp9EncoderConfig { + keyframe_interval: 1, + bitrate_kbps: 800, + threads: 1, + ..Default::default() + }; + let encoder = Vp9EncoderNode::new(encoder_config.clone()).unwrap(); + + // Debug probe: run a direct encode to surface libvpx details if packets are missing. 
+ let mut probe_encoder = Vp9Encoder::new(64, 64, &encoder_config).unwrap(); + let mut probe_frame = create_test_video_frame(64, 64, PixelFormat::Nv12, 16); + probe_frame.metadata = Some(PacketMetadata { + timestamp_us: Some(1_000), + duration_us: Some(33_333), + sequence: Some(0), + keyframe: Some(true), + }); + match probe_encoder.encode_frame(&probe_frame, probe_frame.metadata.clone()) { + Ok(packets) => { + eprintln!("VP9 probe encode packets: {}", packets.len()); + if packets.is_empty() { + if let Ok(flushed) = probe_encoder.flush() { + eprintln!("VP9 probe flush packets: {}", flushed.len()); + } + let detail = unsafe { + // SAFETY: ctx is valid for the duration of the encoder. + vpx_string(vpx::vpx_codec_error_detail(&raw mut probe_encoder.ctx)) + }; + eprintln!("VP9 probe error detail: {detail}"); + } + }, + Err(err) => { + eprintln!("VP9 probe encode error: {err}"); + }, + } + + let enc_handle = tokio::spawn(async move { Box::new(encoder).run(enc_context).await }); + + assert_state_initializing(&mut enc_state_rx).await; + assert_state_running(&mut enc_state_rx).await; + + let mut expected_metadata = HashMap::new(); + for index in 0_u64..5 { + let timestamp = 1_000 + 33_333_u64 * index; + let duration: u64 = 33_333; + expected_metadata.insert(index, (timestamp, duration)); + + let mut frame = create_test_video_frame(64, 64, PixelFormat::Nv12, 16); + frame.metadata = Some(PacketMetadata { + timestamp_us: Some(timestamp), + duration_us: Some(duration), + sequence: Some(index), + keyframe: Some(true), + }); + enc_input_tx.send(Packet::Video(frame)).await.unwrap(); + } + drop(enc_input_tx); + + assert_state_stopped(&mut enc_state_rx).await; + enc_handle.await.unwrap().unwrap(); + + let encoded_packets = enc_sender.get_packets_for_pin("out").await; + assert!(!encoded_packets.is_empty(), "VP9 encoder produced no packets"); + let mut encoded_sequences = HashSet::new(); + + for packet in &encoded_packets { + let Packet::Binary { metadata, .. 
} = packet else { + continue; + }; + let meta = metadata.as_ref().expect("Encoded VP9 packet missing metadata"); + let seq = meta.sequence.expect("Encoded VP9 packet missing sequence"); + let (expected_ts, expected_dur) = expected_metadata + .get(&seq) + .copied() + .expect("Encoded VP9 packet has unexpected sequence"); + + assert_eq!( + meta.timestamp_us, + Some(expected_ts), + "Encoded VP9 packet timestamp mismatch" + ); + assert_eq!( + meta.duration_us, + Some(expected_dur), + "Encoded VP9 packet duration mismatch" + ); + encoded_sequences.insert(seq); + } + + assert_eq!( + encoded_sequences.len(), + expected_metadata.len(), + "Encoded VP9 packets did not cover all input frames" + ); + + let (dec_input_tx, dec_input_rx) = mpsc::channel(10); + let mut dec_inputs = HashMap::new(); + dec_inputs.insert("in".to_string(), dec_input_rx); + + let (dec_context, dec_sender, mut dec_state_rx) = create_test_context(dec_inputs, 10); + let decoder = Vp9DecoderNode::new(Vp9DecoderConfig::default()).unwrap(); + let dec_handle = tokio::spawn(async move { Box::new(decoder).run(dec_context).await }); + + assert_state_initializing(&mut dec_state_rx).await; + assert_state_running(&mut dec_state_rx).await; + + for packet in encoded_packets { + if let Packet::Binary { data, metadata, .. 
} = packet { + dec_input_tx + .send(Packet::Binary { + data, + content_type: Some(Cow::Borrowed(VP9_CONTENT_TYPE)), + metadata, + }) + .await + .unwrap(); + } + } + drop(dec_input_tx); + + assert_state_stopped(&mut dec_state_rx).await; + dec_handle.await.unwrap().unwrap(); + + let decoded_packets = dec_sender.get_packets_for_pin("out").await; + assert!(!decoded_packets.is_empty(), "VP9 decoder produced no frames"); + let mut decoded_sequences = HashSet::new(); + + for packet in decoded_packets { + match packet { + Packet::Video(frame) => { + assert_eq!(frame.width, 64); + assert_eq!(frame.height, 64); + assert_eq!(frame.pixel_format, PixelFormat::Nv12); + assert!(!frame.data().is_empty(), "Decoded frame should have data"); + + let meta = frame.metadata.as_ref().expect("Decoded VP9 frame missing metadata"); + let seq = meta.sequence.expect("Decoded VP9 frame missing sequence"); + let (expected_ts, expected_dur) = expected_metadata + .get(&seq) + .copied() + .expect("Decoded VP9 frame has unexpected sequence"); + + assert_eq!( + meta.timestamp_us, + Some(expected_ts), + "Decoded VP9 frame timestamp mismatch" + ); + assert_eq!( + meta.duration_us, + Some(expected_dur), + "Decoded VP9 frame duration mismatch" + ); + decoded_sequences.insert(seq); + }, + _ => panic!("Expected Video packet from VP9 decoder"), + } + } + + assert_eq!( + decoded_sequences.len(), + expected_metadata.len(), + "Decoded VP9 frames did not cover all input frames" + ); + } +} diff --git a/crates/plugin-native/Cargo.toml b/crates/plugin-native/Cargo.toml index 97e8086f..4adf6c1a 100644 --- a/crates/plugin-native/Cargo.toml +++ b/crates/plugin-native/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "streamkit-plugin-native" -version = "0.1.0" +version = "0.2.0" edition = "2021" authors = ["Claudio Costa ", "StreamKit Contributors"] description = "Native plugin runtime for StreamKit" diff --git a/crates/plugin-wasm/Cargo.toml b/crates/plugin-wasm/Cargo.toml index 57f249f4..7fde40be 100644 --- 
a/crates/plugin-wasm/Cargo.toml +++ b/crates/plugin-wasm/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "streamkit-plugin-wasm" -version = "0.1.0" +version = "0.2.0" edition = "2021" authors = ["Claudio Costa ", "StreamKit Contributors"] description = "WASM plugin runtime for StreamKit" diff --git a/crates/plugin-wasm/src/conversions.rs b/crates/plugin-wasm/src/conversions.rs index f712c411..987abb58 100644 --- a/crates/plugin-wasm/src/conversions.rs +++ b/crates/plugin-wasm/src/conversions.rs @@ -6,7 +6,8 @@ use crate::wit_types; use bytes::Bytes; use std::sync::Arc; use streamkit_core::types::{ - AudioFormat as CoreAudioFormat, CustomEncoding, CustomPacketData, PacketType as CorePacketType, + AudioCodec, AudioFormat as CoreAudioFormat, CustomEncoding, CustomPacketData, + EncodedAudioFormat, PacketType as CorePacketType, }; impl TryFrom for streamkit_core::types::Packet { @@ -74,6 +75,7 @@ impl From for wit_types::Packet { data, }) }, + streamkit_core::types::Packet::Video(frame) => Self::Binary(frame.data.to_vec()), streamkit_core::types::Packet::Binary { data, .. 
} => Self::Binary(data.to_vec()), } } @@ -91,7 +93,10 @@ impl From<&wit_types::PacketType> for CorePacketType { wit_types::SampleFormat::S16Le => streamkit_core::types::SampleFormat::S16Le, }, }), - wit_types::PacketType::OpusAudio => Self::OpusAudio, + wit_types::PacketType::OpusAudio => Self::EncodedAudio(EncodedAudioFormat { + codec: AudioCodec::Opus, + codec_private: None, + }), wit_types::PacketType::Text => Self::Text, wit_types::PacketType::Binary => Self::Binary, wit_types::PacketType::Custom(type_id) => Self::Custom { type_id: type_id.clone() }, diff --git a/docs/src/content/docs/getting-started/installation.md b/docs/src/content/docs/getting-started/installation.md index 5da535ab..d7ad9de2 100644 --- a/docs/src/content/docs/getting-started/installation.md +++ b/docs/src/content/docs/getting-started/installation.md @@ -44,6 +44,7 @@ Optional: - `cargo-deny` (`cargo install cargo-deny`) for license checks in `just lint` - `reuse` (`pip3 install --user reuse`) for SPDX license header checks in `just lint` (note: the apt package is too old) - `clang` and `libclang-dev` (`sudo apt install clang libclang-dev`) for building native ML plugins (e.g. whisper, sensevoice) +- `libvpx` + `pkg-config` if building with `--features video` (VP9 nodes) ### Build Steps diff --git a/docs/src/content/docs/reference/nodes/audio-opus-decoder.md b/docs/src/content/docs/reference/nodes/audio-opus-decoder.md index bf334dcd..12b2f5e0 100644 --- a/docs/src/content/docs/reference/nodes/audio-opus-decoder.md +++ b/docs/src/content/docs/reference/nodes/audio-opus-decoder.md @@ -16,7 +16,7 @@ Decodes Opus-compressed audio packets into raw PCM samples. 
Opus is the preferre ## Pins ### Inputs -- `in` accepts `OpusAudio` (one) +- `in` accepts `EncodedAudio(EncodedAudioFormat { codec: Opus })` (one) ### Outputs - `out` produces `RawAudio(AudioFormat { sample_rate: 48000, channels: 1, sample_format: F32 })` (broadcast) diff --git a/docs/src/content/docs/reference/nodes/audio-opus-encoder.md b/docs/src/content/docs/reference/nodes/audio-opus-encoder.md index 852dd7bb..95257773 100644 --- a/docs/src/content/docs/reference/nodes/audio-opus-encoder.md +++ b/docs/src/content/docs/reference/nodes/audio-opus-encoder.md @@ -19,7 +19,7 @@ Encodes raw PCM audio into Opus-compressed packets. Configurable bitrate, applic - `in` accepts `RawAudio(AudioFormat { sample_rate: 48000, channels: 1, sample_format: F32 })` (one) ### Outputs -- `out` produces `OpusAudio` (broadcast) +- `out` produces `EncodedAudio(EncodedAudioFormat { codec: Opus })` (broadcast) ## Parameters | Name | Type | Required | Default | Description | diff --git a/docs/src/content/docs/reference/nodes/containers-ogg-demuxer.md b/docs/src/content/docs/reference/nodes/containers-ogg-demuxer.md index 506b0d7d..0e586efb 100644 --- a/docs/src/content/docs/reference/nodes/containers-ogg-demuxer.md +++ b/docs/src/content/docs/reference/nodes/containers-ogg-demuxer.md @@ -18,7 +18,7 @@ Demuxes Ogg containers to extract Opus audio packets. Accepts binary Ogg data an - `in` accepts `Binary` (one) ### Outputs -- `out` produces `OpusAudio` (broadcast) +- `out` produces `EncodedAudio(EncodedAudioFormat { codec: Opus })` (broadcast) ## Parameters No parameters. diff --git a/docs/src/content/docs/reference/nodes/containers-ogg-muxer.md b/docs/src/content/docs/reference/nodes/containers-ogg-muxer.md index e474e925..3d46d9c0 100644 --- a/docs/src/content/docs/reference/nodes/containers-ogg-muxer.md +++ b/docs/src/content/docs/reference/nodes/containers-ogg-muxer.md @@ -15,7 +15,7 @@ Muxes Opus audio packets into an Ogg container. 
Produces streamable Ogg/Opus out ## Pins ### Inputs -- `in` accepts `OpusAudio` (one) +- `in` accepts `EncodedAudio(EncodedAudioFormat { codec: Opus })` (one) ### Outputs - `out` produces `Binary` (broadcast) diff --git a/docs/src/content/docs/reference/nodes/containers-webm-muxer.md b/docs/src/content/docs/reference/nodes/containers-webm-muxer.md index c321ac1b..ae8cd98a 100644 --- a/docs/src/content/docs/reference/nodes/containers-webm-muxer.md +++ b/docs/src/content/docs/reference/nodes/containers-webm-muxer.md @@ -2,20 +2,39 @@ # SPDX-FileCopyrightText: © 2025 StreamKit Contributors # SPDX-License-Identifier: MPL-2.0 title: "containers::webm::muxer" -description: "Muxes Opus audio into a WebM container. Produces streamable WebM/Opus output compatible with web browsers." +description: "Muxes Opus audio and/or VP9 video into a WebM container. Produces streamable WebM output compatible with web browsers." --- `kind`: `containers::webm::muxer` -Muxes Opus audio into a WebM container. Produces streamable WebM/Opus output compatible with web browsers. +Muxes Opus audio and/or VP9 video into a WebM container. Produces streamable WebM output compatible with web browsers. Supports audio-only, video-only, or combined audio+video muxing. ## Categories - `containers` - `webm` ## Pins + +Input pins use generic names — the media type (audio or video) is determined at +connection time from the upstream node's output type, not from the pin name. + +When `video_width` and `video_height` are **not** configured (default), a single +`in` pin is exposed, keeping backward compatibility with existing audio-only +pipelines (`needs: opus_encoder`). + +When video dimensions **are** configured, two pins (`in` + `in_1`) are exposed +so that both an audio and a video encoder can be connected. 
Use the map syntax +to target each pin explicitly: + +```yaml +needs: + in: opus_encoder + in_1: vp9_encoder +``` + ### Inputs -- `in` accepts `OpusAudio` (one) +- `in` accepts `EncodedAudio(Opus)` or `EncodedVideo(VP9)` (one) +- `in_1` accepts `EncodedAudio(Opus)` or `EncodedVideo(VP9)` (one) — only present when `video_width`/`video_height` > 0 ### Outputs - `out` produces `Binary` (broadcast) @@ -26,7 +45,9 @@ Muxes Opus audio into a WebM container. Produces streamable WebM/Opus output com | `channels` | `integer (uint32)` | no | `2` | Number of audio channels (1 for mono, 2 for stereo)
min: `0` | | `chunk_size` | `integer (uint)` | no | `65536` | The number of bytes to buffer before flushing to the output. Defaults to 65536.
min: `0` | | `sample_rate` | `integer (uint32)` | no | `48000` | Audio sample rate in Hz
min: `0` | -| `streaming_mode` | `string` | no | — | — | +| `streaming_mode` | `string` | no | — | Streaming mode: `"live"` for real-time streaming (no duration), `"file"` for complete files with duration | +| `video_width` | `integer (uint32)` | no | `0` | Video frame width in pixels. Set to > 0 together with `video_height` to enable the second input pin for video.
min: `0` | +| `video_height` | `integer (uint32)` | no | `0` | Video frame height in pixels. Set to > 0 together with `video_width` to enable the second input pin for video.
min: `0` |
diff --git a/docs/src/content/docs/reference/nodes/index.md b/docs/src/content/docs/reference/nodes/index.md index 9787a82b..695e6a80 100644 --- a/docs/src/content/docs/reference/nodes/index.md +++ b/docs/src/content/docs/reference/nodes/index.md @@ -30,6 +30,11 @@ Notes: - [`audio::pacer`](./audio-pacer/) - [`audio::resampler`](./audio-resampler/) +## `video` (2) + +- [`video::vp9::decoder`](./video-vp9-decoder/) +- [`video::vp9::encoder`](./video-vp9-encoder/) + ## `containers` (4) - [`containers::ogg::demuxer`](./containers-ogg-demuxer/) diff --git a/docs/src/content/docs/reference/nodes/transport-moq-peer.md b/docs/src/content/docs/reference/nodes/transport-moq-peer.md index 040b2e98..4ad861cd 100644 --- a/docs/src/content/docs/reference/nodes/transport-moq-peer.md +++ b/docs/src/content/docs/reference/nodes/transport-moq-peer.md @@ -17,10 +17,10 @@ Bidirectional MoQ peer for real-time audio communication. Acts as both publisher ## Pins ### Inputs -- `in` accepts `OpusAudio` (one) +- `in` accepts `EncodedAudio(EncodedAudioFormat { codec: Opus })` (one) ### Outputs -- `out` produces `OpusAudio` (broadcast) +- `out` produces `EncodedAudio(EncodedAudioFormat { codec: Opus })` (broadcast) ## Parameters | Name | Type | Required | Default | Description | diff --git a/docs/src/content/docs/reference/nodes/transport-moq-publisher.md b/docs/src/content/docs/reference/nodes/transport-moq-publisher.md index 2a27328e..c91dcdee 100644 --- a/docs/src/content/docs/reference/nodes/transport-moq-publisher.md +++ b/docs/src/content/docs/reference/nodes/transport-moq-publisher.md @@ -16,7 +16,7 @@ Publishes audio to a Media over QUIC (MoQ) broadcast. Sends Opus audio to subscr ## Pins ### Inputs -- `in` accepts `OpusAudio` (one) +- `in` accepts `EncodedAudio(EncodedAudioFormat { codec: Opus })` (one) ### Outputs No outputs. 
diff --git a/docs/src/content/docs/reference/nodes/transport-moq-subscriber.md b/docs/src/content/docs/reference/nodes/transport-moq-subscriber.md index 8bc9779d..4a5c58ef 100644 --- a/docs/src/content/docs/reference/nodes/transport-moq-subscriber.md +++ b/docs/src/content/docs/reference/nodes/transport-moq-subscriber.md @@ -19,7 +19,7 @@ Subscribes to a Media over QUIC (MoQ) broadcast. Receives Opus audio from a remo No inputs. ### Outputs -- `out` produces `OpusAudio` (broadcast) +- `out` produces `EncodedAudio(EncodedAudioFormat { codec: Opus })` (broadcast) ## Parameters | Name | Type | Required | Default | Description | diff --git a/docs/src/content/docs/reference/nodes/video-vp9-decoder.md b/docs/src/content/docs/reference/nodes/video-vp9-decoder.md new file mode 100644 index 00000000..d457d9e6 --- /dev/null +++ b/docs/src/content/docs/reference/nodes/video-vp9-decoder.md @@ -0,0 +1,50 @@ +--- +# SPDX-FileCopyrightText: © 2025 StreamKit Contributors +# SPDX-License-Identifier: MPL-2.0 +title: "video::vp9::decoder" +description: "Decodes VP9-compressed packets into raw NV12 video frames for CPU processing." +--- + +`kind`: `video::vp9::decoder` + +Decodes VP9-compressed packets into raw NV12 video frames for CPU processing. + +## Requirements +- Build with `--features video` and have `libvpx` available via `pkg-config`. + +## Categories +- `video` +- `codecs` +- `vp9` + +## Pins +### Inputs +- `in` accepts `EncodedVideo(EncodedVideoFormat { codec: Vp9 })` (one) + +### Outputs +- `out` produces `RawVideo(VideoFormat { width: *, height: *, pixel_format: Nv12 })` (broadcast) + +## Parameters +| Name | Type | Required | Default | Description | +| --- | --- | --- | --- | --- | +| `threads` | `integer` | no | `2` | Decoder worker threads | + + +
+Raw JSON Schema + +```json +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "properties": { + "threads": { + "default": 2, + "type": "integer" + } + }, + "title": "Vp9DecoderConfig", + "type": "object" +} +``` + +
diff --git a/docs/src/content/docs/reference/nodes/video-vp9-encoder.md b/docs/src/content/docs/reference/nodes/video-vp9-encoder.md new file mode 100644 index 00000000..1f8ec647 --- /dev/null +++ b/docs/src/content/docs/reference/nodes/video-vp9-encoder.md @@ -0,0 +1,60 @@ +--- +# SPDX-FileCopyrightText: © 2025 StreamKit Contributors +# SPDX-License-Identifier: MPL-2.0 +title: "video::vp9::encoder" +description: "Encodes raw I420 video frames into VP9 packets for transport or container muxing." +--- + +`kind`: `video::vp9::encoder` + +Encodes raw I420 video frames into VP9 packets for transport or container muxing. + +## Requirements +- Build with `--features video` and have `libvpx` available via `pkg-config`. + +## Categories +- `video` +- `codecs` +- `vp9` + +## Pins +### Inputs +- `in` accepts `RawVideo(VideoFormat { width: *, height: *, pixel_format: I420 })` (one) + +### Outputs +- `out` produces `EncodedVideo(EncodedVideoFormat { codec: Vp9 })` (broadcast) + +## Parameters +| Name | Type | Required | Default | Description | +| --- | --- | --- | --- | --- | +| `bitrate_kbps` | `integer` | no | `2500` | Target bitrate in kbps | +| `keyframe_interval` | `integer` | no | `120` | Maximum keyframe interval (frames) | +| `threads` | `integer` | no | `2` | Encoder worker threads | + + +
+Raw JSON Schema + +```json +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "properties": { + "bitrate_kbps": { + "default": 2500, + "type": "integer" + }, + "keyframe_interval": { + "default": 120, + "type": "integer" + }, + "threads": { + "default": 2, + "type": "integer" + } + }, + "title": "Vp9EncoderConfig", + "type": "object" +} +``` + +
diff --git a/docs/src/content/docs/reference/packets/binary.md b/docs/src/content/docs/reference/packets/binary.md index 7acd35aa..8bdc4945 100644 --- a/docs/src/content/docs/reference/packets/binary.md +++ b/docs/src/content/docs/reference/packets/binary.md @@ -23,7 +23,7 @@ Binary packets are carried as: { "data": "", "content_type": "application/octet-stream", - "metadata": { "timestamp_us": 0, "duration_us": 20000, "sequence": 42 } + "metadata": { "timestamp_us": 0, "duration_us": 20000, "sequence": 42, "keyframe": true } } ``` @@ -31,3 +31,4 @@ Notes: - `data` is base64-encoded for JSON transport. - `content_type` is optional and may be `null`. +- `metadata.keyframe` is optional and is used for encoded video packets. diff --git a/docs/src/content/docs/reference/packets/custom.md b/docs/src/content/docs/reference/packets/custom.md index fe31a1bf..47dd07ea 100644 --- a/docs/src/content/docs/reference/packets/custom.md +++ b/docs/src/content/docs/reference/packets/custom.md @@ -31,7 +31,7 @@ Use `Custom` when you need **structured, typed messages** that don't fit existin Prefer other packet types when they fit: -- Audio frames/streams: `/reference/packets/raw-audio/` or `/reference/packets/opus-audio/` +- Audio frames/streams: `/reference/packets/raw-audio/` or `/reference/packets/encoded-audio/` - Plain strings: `/reference/packets/text/` - Opaque bytes, blobs, or media: `/reference/packets/binary/` - Speech-to-text results: `/reference/packets/transcription/` @@ -171,6 +171,13 @@ Custom packets are carried as `Packet::Custom(Arc)`. 
], "format": "uint64", "minimum": 0 + }, + "keyframe": { + "description": "Keyframe flag for encoded video packets (and raw frames if applicable)", + "type": [ + "boolean", + "null" + ] } } } diff --git a/docs/src/content/docs/reference/packets/encoded-audio.md b/docs/src/content/docs/reference/packets/encoded-audio.md new file mode 100644 index 00000000..db87e73a --- /dev/null +++ b/docs/src/content/docs/reference/packets/encoded-audio.md @@ -0,0 +1,72 @@ +--- +# SPDX-FileCopyrightText: © 2025 StreamKit Contributors +# SPDX-License-Identifier: MPL-2.0 +title: "Encoded Audio" +description: "PacketType EncodedAudio structure" +--- + +`PacketType` id: `EncodedAudio` + +Type system: `PacketType::EncodedAudio(EncodedAudioFormat)` + +Runtime: `Packet::Binary { data, metadata, .. }` + +## UI Metadata +- `label`: `Encoded Audio` +- `color`: `#ff6b6b` +- `display_template`: `Encoded Audio ({codec})` +- `compat: wildcard fields (codec_private), color: `#ff6b6b`` + +## Structure +Encoded audio packets use `Packet::Binary`, with codec identity captured in the type system. + +### PacketType payload (`EncodedAudioFormat`) + +| Name | Type | Required | Default | Description | +| --- | --- | --- | --- | --- | +| `codec` | `string enum[Opus]` | yes | — | Encoded audio codec. | +| `codec_private` | `null | array` | no | — | Optional codec-specific extradata. Use `null` as a wildcard. | + +
+Raw JSON Schema + +```json +{ + "$defs": { + "AudioCodec": { + "description": "Supported encoded audio codecs.", + "enum": [ + "Opus" + ], + "type": "string" + } + }, + "$schema": "https://json-schema.org/draft/2020-12/schema", + "description": "Encoded audio format details (extensible for codec-specific config).", + "properties": { + "codec": { + "$ref": "#/$defs/AudioCodec" + }, + "codec_private": { + "description": "Optional codec-specific extradata.", + "items": { + "format": "uint8", + "maximum": 255, + "minimum": 0, + "type": "integer" + }, + "type": [ + "array", + "null" + ] + } + }, + "required": [ + "codec" + ], + "title": "EncodedAudioFormat", + "type": "object" +} +``` + +
diff --git a/docs/src/content/docs/reference/packets/encoded-video.md b/docs/src/content/docs/reference/packets/encoded-video.md new file mode 100644 index 00000000..e82695ad --- /dev/null +++ b/docs/src/content/docs/reference/packets/encoded-video.md @@ -0,0 +1,110 @@ +--- +# SPDX-FileCopyrightText: © 2025 StreamKit Contributors +# SPDX-License-Identifier: MPL-2.0 +title: "Encoded Video" +description: "PacketType EncodedVideo structure" +--- + +`PacketType` id: `EncodedVideo` + +Type system: `PacketType::EncodedVideo(EncodedVideoFormat)` + +Runtime: `Packet::Binary { data, metadata, .. }` + +## UI Metadata +- `label`: `Encoded Video` +- `color`: `#2980b9` +- `display_template`: `Encoded Video ({codec})` +- `compat: wildcard fields (bitstream_format, codec_private, profile, level), color: `#2980b9`` + +## Structure +Encoded video packets use `Packet::Binary`, with codec identity captured in the type system. + +### PacketType payload (`EncodedVideoFormat`) + +| Name | Type | Required | Default | Description | +| --- | --- | --- | --- | --- | +| `codec` | `string enum[Vp9, H264, Av1]` | yes | — | Encoded video codec. | +| `bitstream_format` | `null | string enum[AnnexB, Avcc]` | no | — | Bitstream format hint (primarily for H264). Use `null` as a wildcard. | +| `codec_private` | `null | array` | no | — | Optional codec-specific extradata. Use `null` as a wildcard. | +| `profile` | `null | string` | no | — | Optional codec profile hint. Use `null` as a wildcard. | +| `level` | `null | string` | no | — | Optional codec level hint. Use `null` as a wildcard. | + +
+Raw JSON Schema + +```json +{ + "$defs": { + "VideoBitstreamFormat": { + "description": "Bitstream format hints for video codecs (primarily H264).", + "enum": [ + "AnnexB", + "Avcc" + ], + "type": "string" + }, + "VideoCodec": { + "description": "Supported encoded video codecs.", + "enum": [ + "Vp9", + "H264", + "Av1" + ], + "type": "string" + } + }, + "$schema": "https://json-schema.org/draft/2020-12/schema", + "description": "Encoded video format details (extensible for codec-specific config).", + "properties": { + "bitstream_format": { + "anyOf": [ + { + "$ref": "#/$defs/VideoBitstreamFormat" + }, + { + "type": "null" + } + ], + "description": "Bitstream format hint (primarily for H264)." + }, + "codec": { + "$ref": "#/$defs/VideoCodec" + }, + "codec_private": { + "description": "Optional codec-specific extradata.", + "items": { + "format": "uint8", + "maximum": 255, + "minimum": 0, + "type": "integer" + }, + "type": [ + "array", + "null" + ] + }, + "level": { + "description": "Optional codec level hint.", + "type": [ + "string", + "null" + ] + }, + "profile": { + "description": "Optional codec profile hint.", + "type": [ + "string", + "null" + ] + } + }, + "required": [ + "codec" + ], + "title": "EncodedVideoFormat", + "type": "object" +} +``` + +
diff --git a/docs/src/content/docs/reference/packets/index.md b/docs/src/content/docs/reference/packets/index.md index f53f0f10..79c48136 100644 --- a/docs/src/content/docs/reference/packets/index.md +++ b/docs/src/content/docs/reference/packets/index.md @@ -18,7 +18,9 @@ curl http://localhost:4545/api/v1/schema/packets | PacketType | Link | Runtime representation | Notes | | --- | --- | --- | --- | | `RawAudio` | [**Raw Audio**](./raw-audio/) | `Packet::Audio(AudioFrame)` | compat: wildcard fields (sample_rate, channels, sample_format), color: `#f39c12` | -| `OpusAudio` | [**Opus Audio**](./opus-audio/) | `Packet::Binary { data, metadata, .. }` | compat: exact, color: `#ff6b6b` | +| `RawVideo` | [**Raw Video**](./raw-video/) | `Packet::Video(VideoFrame)` | compat: wildcard fields (width, height, pixel_format), color: `#1abc9c` | +| `EncodedAudio` | [**Encoded Audio**](./encoded-audio/) | `Packet::Binary { data, metadata, .. }` | compat: wildcard fields (codec_private), color: `#ff6b6b` | +| `EncodedVideo` | [**Encoded Video**](./encoded-video/) | `Packet::Binary { data, metadata, .. }` | compat: wildcard fields (bitstream_format, codec_private, profile, level), color: `#2980b9` | | `Text` | [**Text**](./text/) | `Packet::Text(Arc)` | compat: exact, color: `#4ecdc4` | | `Transcription` | [**Transcription**](./transcription/) | `Packet::Transcription(Arc)` | compat: exact, color: `#9b59b6` | | `Custom` | [**Custom**](./custom/) | `Packet::Custom(Arc)` | compat: wildcard fields (type_id), color: `#e67e22` | @@ -31,4 +33,4 @@ curl http://localhost:4545/api/v1/schema/packets `PacketType` serializes as: - A string for unit variants (e.g., `"Text"`, `"Binary"`). -- An object for payload variants (e.g., `{"RawAudio": {"sample_rate": 48000, ...}}`). +- An object for payload variants (e.g., `{"RawAudio": {"sample_rate": 48000, ...}}` or `{"EncodedAudio": {"codec": "Opus"}}`). 
diff --git a/docs/src/content/docs/reference/packets/opus-audio.md b/docs/src/content/docs/reference/packets/opus-audio.md deleted file mode 100644 index 8f90f61b..00000000 --- a/docs/src/content/docs/reference/packets/opus-audio.md +++ /dev/null @@ -1,22 +0,0 @@ ---- -# SPDX-FileCopyrightText: © 2025 StreamKit Contributors -# SPDX-License-Identifier: MPL-2.0 -title: "Opus Audio" -description: "PacketType OpusAudio structure" ---- - -`PacketType` id: `OpusAudio` - -Type system: `PacketType::OpusAudio` - -Runtime: `Packet::Binary { data, metadata, .. }` - -## UI Metadata -- `label`: `Opus Audio` -- `color`: `#ff6b6b` -- `compat: exact, color: `#ff6b6b`` - -## Structure -Opus packets use the `OpusAudio` packet type, but the runtime payload is still `Packet::Binary`. - -The Opus codec nodes encode/decode using `Packet::Binary { data, metadata, .. }` and label pins as `OpusAudio`. diff --git a/docs/src/content/docs/reference/packets/raw-video.md b/docs/src/content/docs/reference/packets/raw-video.md new file mode 100644 index 00000000..6c56c9a0 --- /dev/null +++ b/docs/src/content/docs/reference/packets/raw-video.md @@ -0,0 +1,99 @@ +--- +# SPDX-FileCopyrightText: © 2025 StreamKit Contributors +# SPDX-License-Identifier: MPL-2.0 +title: "Raw Video" +description: "PacketType RawVideo structure" +--- + +`PacketType` id: `RawVideo` + +Type system: `PacketType::RawVideo(VideoFormat)` + +Runtime: `Packet::Video(VideoFrame)` + +## UI Metadata +- `label`: `Raw Video` +- `color`: `#1abc9c` +- `display_template`: `Raw Video ({width|*}x{height|*}, {pixel_format})` +- `compat: wildcard fields (width, height), color: `#1abc9c`` + +## Structure +Raw video is defined by a `VideoFormat` in the type system and carried as `Packet::Video(VideoFrame)` at runtime. + +Use `null` for `width` or `height` when you want wildcard/unknown dimensions. 
+ +### PacketType payload (`VideoFormat`) + +| Name | Type | Required | Default | Description | +| --- | --- | --- | --- | --- | +| `width` | `null | integer (uint32)` | no | — | Frame width in pixels. `null` acts as a wildcard. | +| `height` | `null | integer (uint32)` | no | — | Frame height in pixels. `null` acts as a wildcard. | +| `pixel_format` | `string enum[Rgba8, I420]` | yes | — | Pixel format for raw frames. | + +
+Raw JSON Schema + +```json +{ + "$defs": { + "PixelFormat": { + "description": "Describes the pixel format of raw video frames.", + "enum": [ + "Rgba8", + "I420" + ], + "type": "string" + } + }, + "$schema": "https://json-schema.org/draft/2020-12/schema", + "description": "Contains the detailed metadata for a raw video stream.", + "properties": { + "height": { + "format": "uint32", + "minimum": 0, + "type": [ + "integer", + "null" + ] + }, + "pixel_format": { + "$ref": "#/$defs/PixelFormat" + }, + "width": { + "format": "uint32", + "minimum": 0, + "type": [ + "integer", + "null" + ] + } + }, + "required": [ + "pixel_format" + ], + "title": "VideoFormat", + "type": "object" +} +``` + +
+ +### Runtime payload (`VideoFrame`) + +`VideoFrame` is optimized for zero-copy fan-out. It contains: + +- `width` (u32) +- `height` (u32) +- `pixel_format` (`PixelFormat`) +- `layout` (`VideoLayout`, includes per-plane offsets/strides and `stride_align`) +- `data` (packed bytes; layout depends on the pixel format) +- `metadata` (`PacketMetadata`, optional) + +`VideoLayout` exposes: +- `plane_count` +- `planes[]` with `offset`, `stride`, `width`, `height` +- `total_bytes` +- `stride_align` (byte alignment used for each plane stride) + +StreamKit assumes raw video frames use a canonical aligned layout (as produced by `VideoLayout::aligned`). +Codec nodes may reject frames whose layout does not match the expected canonical layout. diff --git a/docs/src/content/docs/reference/packets/transcription.md b/docs/src/content/docs/reference/packets/transcription.md index 157691e7..c6fb3ccb 100644 --- a/docs/src/content/docs/reference/packets/transcription.md +++ b/docs/src/content/docs/reference/packets/transcription.md @@ -145,6 +145,13 @@ Transcriptions are carried as `Packet::Transcription(Arc)`. 
], "format": "uint64", "minimum": 0 + }, + "keyframe": { + "description": "Keyframe flag for encoded video packets (and raw frames if applicable)", + "type": [ + "boolean", + "null" + ] } } } diff --git a/e2e/playwright.config.ts b/e2e/playwright.config.ts index 53e7b158..490900ab 100644 --- a/e2e/playwright.config.ts +++ b/e2e/playwright.config.ts @@ -50,7 +50,7 @@ export default defineConfig({ `--use-file-for-fake-audio-capture=${fakeAudioPath}`, ], }, - permissions: ['microphone'], + permissions: ['microphone', 'camera'], }, }, ], diff --git a/e2e/tests/compositor-perf.spec.ts b/e2e/tests/compositor-perf.spec.ts new file mode 100644 index 00000000..ed1a506c --- /dev/null +++ b/e2e/tests/compositor-perf.spec.ts @@ -0,0 +1,423 @@ +// SPDX-FileCopyrightText: © 2025 StreamKit Contributors +// +// SPDX-License-Identifier: MPL-2.0 + +/** + * Layer 2 — Compositor slider interaction perf test. + * + * Creates a Webcam PiP pipeline session via the API, navigates to the monitor + * view where the full compositor node graph is rendered, then selects each + * layer and drags its opacity and rotation sliders while measuring re-renders + * via `window.__PERF_DATA__`. + * + * The test asserts that render counts stay within budget — specifically that + * slider interactions on one layer do NOT trigger expensive cascade + * re-renders in unrelated components (the same regression PR #89 fixed). + * + * NOTE: This test requires the Vite dev server (`just ui`) because the + * profiler store (`window.__PERF_DATA__`) is only exposed when + * `import.meta.env.DEV` is true. Point E2E_BASE_URL at + * http://localhost:3045 (or wherever the dev server runs). 
+ */ + +import { test, expect, request } from '@playwright/test'; + +import { ensureLoggedIn, getAuthHeaders } from './auth-helpers'; +import { + type ConsoleErrorCollector, + MOQ_BENIGN_PATTERNS, + createConsoleErrorCollector, +} from './test-helpers'; +import { + resetPerfData, + capturePerfData, + assertRenderBudget, + formatPerfSummary, +} from './perf-helpers'; + +// --------------------------------------------------------------------------- +// Pipeline YAML — Webcam PiP compositor +// Embedded so the test is self-contained and does not depend on file paths. +// --------------------------------------------------------------------------- + +const WEBCAM_PIP_YAML = ` +name: Webcam PiP (MoQ Stream) +description: Composites the user's webcam as picture-in-picture over colorbars with a text overlay +mode: dynamic + +nodes: + colorbars_bg: + kind: video::colorbars + params: + width: 1280 + height: 720 + fps: 30 + draw_time: true + + moq_peer: + kind: transport::moq::peer + params: + gateway_path: /moq/video + input_broadcast: input + output_broadcast: output + allow_reconnect: true + needs: + in: opus_encoder + in_1: vp9_encoder + + vp9_decoder: + kind: video::vp9::decoder + needs: + in: moq_peer.out_1 + + compositor: + kind: video::compositor + params: + width: 1280 + height: 720 + num_inputs: 2 + layers: + in_0: + opacity: 1.0 + z_index: 0 + in_1: + rect: + x: 880 + y: 20 + width: 380 + height: 285 + opacity: 0.95 + z_index: 1 + text_overlays: + - text: "Hello from StreamKit" + transform: + rect: + x: 40 + y: 660 + width: 400 + height: 40 + opacity: 1.0 + z_index: 2 + color: [255, 255, 255, 220] + font_size: 28 + font_name: dejavu-sans-bold + needs: + - colorbars_bg + - vp9_decoder + + pixel_convert: + kind: video::pixel_convert + params: + output_format: nv12 + needs: compositor + + vp9_encoder: + kind: video::vp9::encoder + params: + keyframe_interval: 30 + needs: pixel_convert + + opus_decoder: + kind: audio::opus::decoder + needs: moq_peer + + gain: + kind: 
audio::gain + params: + gain: 1.0 + needs: opus_decoder + + opus_encoder: + kind: audio::opus::encoder + needs: gain +`.trim(); + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/** + * Simulate dragging a Radix slider thumb horizontally by `deltaX` pixels. + * The thumb is located via its `role="slider"` within the given container. + */ +async function dragSliderThumb( + page: import('@playwright/test').Page, + container: import('@playwright/test').Locator, + deltaX: number +) { + const thumb = container.getByRole('slider'); + await thumb.waitFor({ state: 'visible', timeout: 5_000 }); + const box = await thumb.boundingBox(); + if (!box) throw new Error('Slider thumb has no bounding box'); + + const startX = box.x + box.width / 2; + const startY = box.y + box.height / 2; + + await page.mouse.move(startX, startY); + await page.mouse.down(); + + // Move in small increments to simulate a realistic drag that fires + // multiple onValueChange events. + const steps = 20; + const stepSize = deltaX / steps; + for (let i = 1; i <= steps; i++) { + await page.mouse.move(startX + stepSize * i, startY); + } + + await page.mouse.up(); +} + +// --------------------------------------------------------------------------- +// Test +// --------------------------------------------------------------------------- + +test.describe('Compositor Slider Perf — Cascade Re-render Budget', () => { + let collector: ConsoleErrorCollector; + let sessionId: string | null = null; + + test.beforeEach(async ({ page }) => { + collector = createConsoleErrorCollector(page); + }); + + test('slider drags stay within render budget across all compositor components', async ({ + page, + baseURL, + }) => { + // This test involves API session creation + multiple slider interactions. + test.setTimeout(120_000); + + // ── 1. 
Create Webcam PiP session via API ──────────────────────────── + // + // Using the API avoids the stream view flow and MoQ WebTransport + // connection, which is unreliable in headless CI environments. + + const apiContext = await request.newContext({ + baseURL: baseURL!, + extraHTTPHeaders: getAuthHeaders(), + }); + + const createResponse = await apiContext.post('/api/v1/sessions', { + data: { + name: `perf-test-${Date.now()}`, + yaml: WEBCAM_PIP_YAML, + }, + }); + + const responseText = await createResponse.text(); + expect( + createResponse.ok(), + `Failed to create session: ${responseText}` + ).toBeTruthy(); + + const createData = JSON.parse(responseText) as { session_id: string }; + sessionId = createData.session_id; + expect(sessionId).toBeTruthy(); + await apiContext.dispose(); + + // ── 2. Navigate to monitor view ───────────────────────────────────── + + await page.goto('/monitor'); + await ensureLoggedIn(page); + if (!page.url().includes('/monitor')) { + await page.goto('/monitor'); + } + await expect(page.getByTestId('monitor-view')).toBeVisible({ + timeout: 15_000, + }); + + // Wait for sessions list and click our session. + await expect(page.getByTestId('sessions-list')).toBeVisible({ + timeout: 10_000, + }); + + const sessionItem = page.getByTestId('session-item').first(); + await expect(sessionItem).toBeVisible({ timeout: 10_000 }); + await sessionItem.click(); + + // Wait for the React Flow canvas and compositor node to render. + await expect(page.locator('.react-flow__node').first()).toBeVisible({ + timeout: 15_000, + }); + + // Allow initial renders to settle. + await page.waitForTimeout(2_000); + + // ── 3. 
Verify dev-mode profiler is available ──────────────────────── + + const hasPerfData = await page.evaluate(() => { + const w = window as Window & { + __PERF_DATA__?: unknown; + __PERF_RESET__?: unknown; + }; + return !!w.__PERF_DATA__ && !!w.__PERF_RESET__; + }); + + if (!hasPerfData) { + test.skip( + true, + 'window.__PERF_DATA__ not found — test requires the Vite dev server (just ui)' + ); + } + + // ── 4. Locate compositor node and its layer list ──────────────────── + + // The compositor node is the React Flow node containing "Compositor". + const compositorNode = page.locator('.react-flow__node').filter({ + hasText: 'Compositor', + }); + await expect(compositorNode).toBeVisible({ timeout: 10_000 }); + + // Layer names in the PiP pipeline: "Text 0", "Input 1", "Input 0". + // These are plain
elements inside the layer list — no test IDs + // or
  • wrappers. We locate them by exact text content within the + // compositor node. + const layerNames = ['Text 0', 'Input 1', 'Input 0']; + const availableLayers: string[] = []; + + for (const name of layerNames) { + const layerDiv = compositorNode.getByText(name, { exact: true }); + if (await layerDiv.first().isVisible().catch(() => false)) { + availableLayers.push(name); + } + } + + if (availableLayers.length === 0) { + test.skip( + true, + 'No compositor layers found — pipeline may not have initialised' + ); + } + + console.log( + `Found ${availableLayers.length} layer(s): ${availableLayers.join(', ')}` + ); + + // ── 5. Measure slider interactions per layer ──────────────────────── + + // Reset perf data before our measurement window. + await resetPerfData(page); + + for (const layerName of availableLayers) { + // Click the layer in the layer list to select it and open inspector. + const layerDiv = compositorNode.getByText(layerName, { exact: true }); + await layerDiv.first().click(); + await page.waitForTimeout(500); // let inspector render + + // --- Opacity slider --- + // The inspector shows an "Opacity" label followed by a Radix slider + // (role="slider"). We locate the innermost div containing "Opacity" + // that also holds a slider thumb. + const opacitySection = compositorNode + .locator('div') + .filter({ hasText: /^Opacity/ }) + .filter({ has: page.getByRole('slider') }) + .first(); + + const hasOpacity = await opacitySection + .getByRole('slider') + .isVisible() + .catch(() => false); + + if (hasOpacity) { + console.log(` Dragging opacity slider for "${layerName}"`); + await dragSliderThumb(page, opacitySection, 40); + await page.waitForTimeout(100); + // Drag back to exercise more render cycles. + await dragSliderThumb(page, opacitySection, -40); + await page.waitForTimeout(100); + } + + // --- Rotation slider --- + // Similar approach: find the section labelled "Rotation" that + // contains a slider. 
(Rotation also has preset buttons like + // 0°/90°/180°/270° but we specifically target the slider.) + const rotationSection = compositorNode + .locator('div') + .filter({ hasText: /^Rotation/ }) + .filter({ has: page.getByRole('slider') }) + .first(); + + const hasRotation = await rotationSection + .getByRole('slider') + .isVisible() + .catch(() => false); + + if (hasRotation) { + console.log(` Dragging rotation slider for "${layerName}"`); + await dragSliderThumb(page, rotationSection, 60); + await page.waitForTimeout(100); + await dragSliderThumb(page, rotationSection, -60); + await page.waitForTimeout(100); + } + } + + // ── 6. Capture and assert render budgets ──────────────────────────── + + const snapshot = await capturePerfData(page); + console.log('\n' + formatPerfSummary(snapshot)); + + // CompositorNode itself will re-render on each slider tick — but with + // proper memoization the total should stay bounded. The budget below + // is generous (accommodates CI jitter) while still catching the + // pre-PR-#89 regression where every slider tick caused 94+ fiber + // re-renders across the entire tree. + // + // Observed baseline: ~385 renders for the full 3-layer scenario. + // Budget of 500 gives ~30% headroom for CI jitter while still + // catching the pre-PR-#89 regression (thousands of renders when + // every slider tick triggered 94+ fiber re-renders). + const compositorData = snapshot.components['CompositorNode']; + if (compositorData) { + assertRenderBudget(snapshot, 'CompositorNode', { + max: 500, + maxDuration: 5_000, + }); + } + + // The key cascade assertion: if there are OTHER profiled components + // (siblings rendered outside the active slider path), their render + // count should be dramatically lower than CompositorNode's. A cascade + // regression would show them rendering at a similar rate. 
+ for (const [id, data] of Object.entries(snapshot.components)) { + if (id === 'CompositorNode') continue; + // Sibling components should not exceed a fraction of the compositor's + // render count — generous 60% ceiling to allow for legitimate + // re-renders while catching full-cascade regressions. + if (compositorData && data.renderCount > compositorData.renderCount * 0.6) { + throw new Error( + `Cascade regression detected: "${id}" rendered ${data.renderCount} times ` + + `(${((data.renderCount / compositorData.renderCount) * 100).toFixed(0)}% of CompositorNode's ` + + `${compositorData.renderCount}). This suggests slider interactions are causing ` + + `expensive re-renders in unrelated components.` + ); + } + } + + // ── 7. Console error check ────────────────────────────────────────── + + const unexpected = collector.getUnexpected(MOQ_BENIGN_PATTERNS); + // Log but don't fail — monitor view may have transient warnings during + // session state transitions that aren't perf-related. + if (unexpected.length > 0) { + console.warn('Unexpected console errors (non-fatal):', unexpected); + } + }); + + // ── Cleanup ───────────────────────────────────────────────────────────── + + test.afterEach(async ({ baseURL }) => { + if (sessionId) { + try { + const apiContext = await request.newContext({ + baseURL: baseURL!, + extraHTTPHeaders: getAuthHeaders(), + }); + await apiContext.delete(`/api/v1/sessions/${sessionId}`); + await apiContext.dispose(); + } catch { + // Best-effort cleanup; ignore errors. 
+ } + sessionId = null; + } + }); +}); diff --git a/e2e/tests/convert.spec.ts b/e2e/tests/convert.spec.ts index 341aa35a..49673350 100644 --- a/e2e/tests/convert.spec.ts +++ b/e2e/tests/convert.spec.ts @@ -12,6 +12,7 @@ import { type ConsoleErrorCollector, createConsoleErrorCollector, verifyAudioPlayback, + verifyVideoPlayback, } from './test-helpers'; const repoRoot = path.resolve(import.meta.dirname, '..', '..'); @@ -171,3 +172,49 @@ test.describe('Convert View - Audio Mixing Pipeline', () => { expect(unexpected, `Unexpected console errors: ${unexpected.join('; ')}`).toHaveLength(0); }); }); + +test.describe('Convert View - Video Color Bars Pipeline', () => { + let collector: ConsoleErrorCollector; + + test.beforeEach(async ({ page }) => { + collector = createConsoleErrorCollector(page); + await page.goto('/convert'); + await ensureLoggedIn(page); + if (!page.url().includes('/convert')) { + await page.goto('/convert'); + } + await expect(page.getByTestId('convert-view')).toBeVisible(); + }); + + test('UI: select video colorbars template, generate, verify video player', async ({ page }) => { + // VP9 encoding can be slow; give the full pipeline up to 120s. + test.setTimeout(120_000); + + await expect(page.getByText('1. Select Pipeline Template')).toBeVisible(); + + const templateCard = page.getByText('Video Color Bars (VP9/WebM)', { + exact: true, + }); + await expect(templateCard).toBeVisible({ timeout: 10_000 }); + await templateCard.click(); + + // This is a no-input (generator) pipeline, so the button says "Generate". + const generateButton = page.getByRole('button', { name: /Generate/i }); + await expect(generateButton).toBeEnabled(); + await generateButton.click(); + + // Wait for the video output to appear. 
+ await expect(page.getByText('Converted Video')).toBeVisible({ + timeout: 90_000, + }); + + const playback = await verifyVideoPlayback(page); + expect(playback.found, 'Video element not found on page').toBe(true); + expect(playback.readyState, 'Video not ready').toBeGreaterThanOrEqual(1); + expect(playback.videoWidth, 'Video has no width').toBeGreaterThan(0); + expect(playback.videoHeight, 'Video has no height').toBeGreaterThan(0); + + const unexpected = collector.getUnexpected(); + expect(unexpected, `Unexpected console errors: ${unexpected.join('; ')}`).toHaveLength(0); + }); +}); diff --git a/e2e/tests/monitor.spec.ts b/e2e/tests/monitor.spec.ts index d0905e10..60c2c086 100644 --- a/e2e/tests/monitor.spec.ts +++ b/e2e/tests/monitor.spec.ts @@ -85,6 +85,64 @@ steps: } }); + test('deleted session does not reappear in the list (race condition regression)', async ({ + page, + baseURL, + }) => { + const apiContext = await request.newContext({ + baseURL: baseURL!, + extraHTTPHeaders: getAuthHeaders(), + }); + + try { + // Step 1: Create a session via the API. + const createResponse = await apiContext.post('/api/v1/sessions', { + data: { + name: testSessionName, + yaml: minimalPipelineYaml, + }, + }); + const responseText = await createResponse.text(); + expect(createResponse.ok(), `Create session failed: ${responseText}`).toBeTruthy(); + const createData = JSON.parse(responseText) as { session_id: string }; + sessionId = createData.session_id; + expect(sessionId).toBeTruthy(); + + // Step 2: Reload and wait for the session to appear. + await page.reload(); + await expect(page.getByTestId('monitor-view')).toBeVisible(); + await expect(page.getByTestId('sessions-list')).toBeVisible({ timeout: 10000 }); + + const sessionItem = page.getByTestId('session-item').filter({ hasText: testSessionName }); + await expect(sessionItem).toBeVisible({ timeout: 10000 }); + + // Step 3: Delete the session via the UI. 
+ await sessionItem.hover(); + const deleteButton = sessionItem.getByTestId('session-delete-btn'); + await expect(deleteButton).toBeVisible(); + await deleteButton.click(); + + const confirmModal = page.getByTestId('confirm-modal'); + await expect(confirmModal).toBeVisible(); + await confirmModal.getByRole('button', { name: /Confirm|Delete/i }).click(); + + // Step 4: Verify the session disappears. + await expect(sessionItem).toHaveCount(0, { timeout: 10000 }); + + // Step 5: Wait several seconds and assert the session does NOT reappear. + // The old code would re-fetch the session list immediately after the + // WebSocket event, and the server could still return the stale session + // causing a brief flicker. With the fix the optimistic removal prevents + // any reappearance. + await page.waitForTimeout(5000); + await expect(sessionItem).toHaveCount(0); + + sessionId = null; + } finally { + await apiContext.dispose(); + } + }); + test.afterEach(async ({ baseURL }) => { // Cleanup: ensure session is deleted even if test fails if (sessionId) { diff --git a/e2e/tests/perf-helpers.ts b/e2e/tests/perf-helpers.ts new file mode 100644 index 00000000..33989fb4 --- /dev/null +++ b/e2e/tests/perf-helpers.ts @@ -0,0 +1,179 @@ +// SPDX-FileCopyrightText: © 2025 StreamKit Contributors +// +// SPDX-License-Identifier: MPL-2.0 + +/** + * Playwright helpers for Layer 2 render-performance profiling. + * + * These utilities interact with the dev-only `window.__PERF_DATA__` store + * exposed by `ui/src/perf/profiler.ts`. They allow Playwright tests to: + * + * 1. Reset profiling data before a scenario. + * 2. Run an interaction (slider drag, button clicks, etc.). + * 3. Capture a snapshot of render metrics. + * 4. Compare against a previous snapshot or baseline. 
+ *
+ * @example
+ * await resetPerfData(page);
+ * await dragSlider(page, selector, 100);
+ * const snapshot = await capturePerfData(page);
+ * assertRenderBudget(snapshot, 'CompositorNode', { max: 60 });
+ */
+
+import type { Page } from '@playwright/test';
+
+// ── Types mirroring ui/src/perf/profiler.ts ──────────────────────────────────
+
+export interface PerfCommit {
+  id: string;
+  phase: 'mount' | 'update' | 'nested-update';
+  actualDuration: number;
+  baseDuration: number;
+  startTime: number;
+  commitTime: number;
+}
+
+export interface PerfComponentData {
+  renderCount: number;
+  totalDuration: number;
+  maxCommitDuration: number;
+  commits: PerfCommit[];
+}
+
+export interface PerfSnapshot {
+  components: Record<string, PerfComponentData>;
+  session: number;
+  startedAt: string;
+}
+
+// ── Core helpers ─────────────────────────────────────────────────────────────
+
+/**
+ * Reset the in-app perf profiler's accumulated render data.
+ * Must be called before the interaction you want to measure.
+ */
+export async function resetPerfData(page: Page): Promise<void> {
+  await page.evaluate(() => {
+    const w = window as Window & { __PERF_RESET__?: () => void };
+    if (w.__PERF_RESET__) {
+      w.__PERF_RESET__();
+    } else {
+      throw new Error(
+        'window.__PERF_RESET__ not found — is the app running in dev mode?'
+      );
+    }
+  });
+}
+
+/**
+ * Capture the current perf data snapshot from the running app.
+ */
+export async function capturePerfData(page: Page): Promise<PerfSnapshot> {
+  return page.evaluate(() => {
+    const w = window as Window & { __PERF_DATA__?: PerfSnapshot };
+    if (!w.__PERF_DATA__) {
+      throw new Error(
+        'window.__PERF_DATA__ not found — is the app running in dev mode?'
+      );
+    }
+    // Deep clone to avoid stale references.
+    return JSON.parse(JSON.stringify(w.__PERF_DATA__)) as PerfSnapshot;
+  });
+}
+
+// ── Comparison utilities ─────────────────────────────────────────────────────
+
+export interface RenderBudget {
+  /** Maximum allowed render count. 
Exceeding this fails the assertion. */ + max?: number; + /** Maximum total duration in ms. */ + maxDuration?: number; +} + +/** + * Assert that a component's render metrics fall within the given budget. + * Throws a descriptive error if the budget is exceeded. + */ +export function assertRenderBudget( + snapshot: PerfSnapshot, + componentId: string, + budget: RenderBudget +): void { + const data = snapshot.components[componentId]; + if (!data) { + throw new Error( + `No perf data for "${componentId}". ` + + `Available components: ${Object.keys(snapshot.components).join(', ') || '(none)'}` + ); + } + + if (budget.max !== undefined && data.renderCount > budget.max) { + throw new Error( + `"${componentId}" rendered ${data.renderCount} times, ` + + `exceeding budget of ${budget.max}.` + ); + } + + if (budget.maxDuration !== undefined && data.totalDuration > budget.maxDuration) { + throw new Error( + `"${componentId}" total render duration was ${data.totalDuration.toFixed(1)}ms, ` + + `exceeding budget of ${budget.maxDuration}ms.` + ); + } +} + +/** + * Compare two perf snapshots and return per-component deltas. + */ +export function compareSnapshots( + before: PerfSnapshot, + after: PerfSnapshot +): Record< + string, + { renderCountDelta: number; durationDelta: number; renderCount: number; totalDuration: number } +> { + const result: Record< + string, + { renderCountDelta: number; durationDelta: number; renderCount: number; totalDuration: number } + > = {}; + + for (const [id, afterData] of Object.entries(after.components)) { + const beforeData = before.components[id]; + result[id] = { + renderCount: afterData.renderCount, + totalDuration: afterData.totalDuration, + renderCountDelta: afterData.renderCount - (beforeData?.renderCount ?? 0), + durationDelta: afterData.totalDuration - (beforeData?.totalDuration ?? 0), + }; + } + + return result; +} + +/** + * Format a snapshot into a human-readable summary (for test output / CI logs). 
+ */ +export function formatPerfSummary(snapshot: PerfSnapshot): string { + const lines: string[] = [ + `Perf Snapshot (session ${snapshot.session}, started ${snapshot.startedAt})`, + '-'.repeat(60), + ]; + + const entries = Object.entries(snapshot.components).sort( + ([, a], [, b]) => b.renderCount - a.renderCount + ); + + for (const [id, data] of entries) { + lines.push( + ` ${id}: ${data.renderCount} renders, ` + + `${data.totalDuration.toFixed(1)}ms total, ` + + `${data.maxCommitDuration.toFixed(1)}ms max` + ); + } + + if (entries.length === 0) { + lines.push(' (no components profiled)'); + } + + return lines.join('\n'); +} diff --git a/e2e/tests/stream.spec.ts b/e2e/tests/stream.spec.ts index 4043e1c9..44f8c208 100644 --- a/e2e/tests/stream.spec.ts +++ b/e2e/tests/stream.spec.ts @@ -11,6 +11,7 @@ import { createConsoleErrorCollector, installAudioContextTracker, verifyAudioContextActive, + verifyCanvasRendering, } from './test-helpers'; test.describe('Stream View - Dynamic Pipeline', () => { @@ -200,3 +201,297 @@ test.describe('Stream View - Dynamic Pipeline', () => { } }); }); + +test.describe('Stream View - Video MoQ Color Bars Pipeline', () => { + let collector: ConsoleErrorCollector; + let sessionId: string | null = null; + + test.beforeEach(async ({ page }) => { + collector = createConsoleErrorCollector(page); + await installAudioContextTracker(page); + await page.goto('/stream'); + await ensureLoggedIn(page); + if (!page.url().includes('/stream')) { + await page.goto('/stream'); + } + await expect(page.getByTestId('stream-view')).toBeVisible(); + }); + + test('creates video session, connects via MoQ, verifies canvas rendering', async ({ + page, + baseURL, + }) => { + test.setTimeout(90_000); + + // Check MoQ gateway availability; skip if not configured. 
+ const configResponse = await page.request.get(`${baseURL}/api/v1/config`); + if (configResponse.ok()) { + const config = (await configResponse.json()) as { + moq_gateway_url?: string | null; + }; + if (!config.moq_gateway_url) { + test.skip(true, 'MoQ gateway not configured on this server'); + } + } + + // Select the video colorbars MoQ template. + const templateCard = page.getByText('Video Color Bars (MoQ Stream)', { + exact: true, + }); + await expect(templateCard).toBeVisible({ timeout: 10_000 }); + await templateCard.click(); + + // Create session. + const createButton = page.getByRole('button', { name: /Create Session/i }); + await expect(createButton).toBeEnabled({ timeout: 5_000 }); + await createButton.click(); + + const activeBadge = page.getByText('Session Active'); + await expect(activeBadge).toBeVisible({ timeout: 15_000 }); + + // Extract session ID for cleanup. + const sessionIdText = await page.getByText(/Session ID:/).textContent(); + sessionId = sessionIdText?.replace(/Session ID:\s*/, '').trim() ?? null; + + // Wait for MoQ connection (auto-connect or manual). + const connected = page.getByText('Relay: connected'); + const disconnected = page.getByText('Disconnected'); + const connectButton = page.getByRole('button', { + name: /Connect & Stream/i, + }); + + await expect(connected.or(connectButton)).toBeVisible({ timeout: 20_000 }); + + const isConnected = await connected.isVisible(); + if (!isConnected) { + await expect(connectButton).toBeEnabled({ timeout: 5_000 }); + await connectButton.click(); + await expect(connected.or(disconnected)).toBeVisible({ timeout: 20_000 }); + } + + const finalConnected = await connected.isVisible(); + if (finalConnected) { + // Wait for the watch path to go live. + await expect(page.getByText(/Watch: live/)).toBeVisible({ + timeout: 15_000, + }); + + // Scroll the canvas into view so the IntersectionObserver fires and + // the video decoder subscribes to the video/data track. 
+ const canvas = page.locator('canvas'); + await expect(canvas).toBeVisible({ timeout: 5_000 }); + await canvas.scrollIntoViewIfNeeded(); + + // Give the video decoder time to render a few frames onto the canvas. + await page.waitForTimeout(3_000); + + // Verify canvas is rendering non-black pixels (SMPTE color bars). + const canvasState = await verifyCanvasRendering(page); + expect(canvasState.found, 'Canvas element not found on page').toBe(true); + expect(canvasState.width, 'Canvas has no width').toBeGreaterThan(0); + expect(canvasState.height, 'Canvas has no height').toBeGreaterThan(0); + expect( + canvasState.hasNonBlackPixels, + 'Canvas should have rendered non-black pixels from color bars' + ).toBe(true); + + // Assert console errors before teardown. + const unexpected = collector.getUnexpected(MOQ_BENIGN_PATTERNS); + expect(unexpected, `Unexpected console errors: ${unexpected.join('; ')}`).toHaveLength(0); + collector.stop(); + + // Disconnect. + const disconnectButton = page.getByRole('button', { name: /^Disconnect$/i }).first(); + await expect(disconnectButton).toBeVisible(); + await disconnectButton.click(); + + await expect(disconnected).toBeVisible({ timeout: 10_000 }); + } else { + test.skip(true, 'MoQ WebTransport connection could not be established in this environment'); + } + + // Destroy session via UI. 
+ const destroyButton = page.getByRole('button', { + name: /Destroy Session/i, + }); + await expect(destroyButton).toBeVisible(); + await destroyButton.click(); + + const confirmModal = page.getByTestId('confirm-modal'); + await expect(confirmModal).toBeVisible(); + await confirmModal.getByRole('button', { name: /Destroy Session/i }).click(); + + await expect(page.getByRole('button', { name: /Create Session/i })).toBeVisible({ + timeout: 15_000, + }); + + sessionId = null; + }); + + test.afterEach(async ({ baseURL }) => { + if (sessionId) { + try { + const apiContext = await request.newContext({ + baseURL: baseURL!, + extraHTTPHeaders: getAuthHeaders(), + }); + await apiContext.delete(`/api/v1/sessions/${sessionId}`); + await apiContext.dispose(); + } catch { + // Best-effort cleanup; ignore errors. + } + sessionId = null; + } + }); +}); + +test.describe('Stream View - Webcam PiP Pipeline', () => { + let collector: ConsoleErrorCollector; + let sessionId: string | null = null; + + test.beforeEach(async ({ page }) => { + collector = createConsoleErrorCollector(page); + await installAudioContextTracker(page); + await page.goto('/stream'); + await ensureLoggedIn(page); + if (!page.url().includes('/stream')) { + await page.goto('/stream'); + } + await expect(page.getByTestId('stream-view')).toBeVisible(); + }); + + test('creates webcam PiP session, connects with audio+video, verifies stream is received', async ({ + page, + baseURL, + }) => { + test.setTimeout(90_000); + + // Check MoQ gateway availability; skip if not configured. + const configResponse = await page.request.get(`${baseURL}/api/v1/config`); + if (configResponse.ok()) { + const config = (await configResponse.json()) as { + moq_gateway_url?: string | null; + }; + if (!config.moq_gateway_url) { + test.skip(true, 'MoQ gateway not configured on this server'); + } + } + + // Select the webcam PiP template. 
+ const templateCard = page.getByText('Webcam PiP (MoQ Stream)', { + exact: true, + }); + await expect(templateCard).toBeVisible({ timeout: 10_000 }); + await templateCard.click(); + + // Create session. + const createButton = page.getByRole('button', { name: /Create Session/i }); + await expect(createButton).toBeEnabled({ timeout: 5_000 }); + await createButton.click(); + + const activeBadge = page.getByText('Session Active'); + await expect(activeBadge).toBeVisible({ timeout: 15_000 }); + + // Extract session ID for cleanup. + const sessionIdText = await page.getByText(/Session ID:/).textContent(); + sessionId = sessionIdText?.replace(/Session ID:\s*/, '').trim() ?? null; + + // Wait for MoQ connection (auto-connect or manual). + const connected = page.getByText('Relay: connected'); + const disconnected = page.getByText('Disconnected'); + const connectButton = page.getByRole('button', { + name: /Connect & Stream/i, + }); + + await expect(connected.or(connectButton)).toBeVisible({ timeout: 20_000 }); + + const isConnected = await connected.isVisible(); + if (!isConnected) { + await expect(connectButton).toBeEnabled({ timeout: 5_000 }); + await connectButton.click(); + await expect(connected.or(disconnected)).toBeVisible({ timeout: 20_000 }); + } + + const finalConnected = await connected.isVisible(); + if (finalConnected) { + // Wait for the watch path to go live. + await expect(page.getByText(/Watch: live/)).toBeVisible({ + timeout: 15_000, + }); + + // Scroll the canvas into view so the IntersectionObserver fires and + // the video decoder subscribes to the video/data track. + const canvas = page.locator('canvas'); + await expect(canvas).toBeVisible({ timeout: 5_000 }); + await canvas.scrollIntoViewIfNeeded(); + + // Give the pipeline time to process audio+video and render output. + await page.waitForTimeout(4_000); + + // Verify canvas is rendering (composited webcam PiP over colorbars). 
+ const canvasState = await verifyCanvasRendering(page); + expect(canvasState.found, 'Canvas element not found on page').toBe(true); + expect(canvasState.width, 'Canvas has no width').toBeGreaterThan(0); + expect(canvasState.height, 'Canvas has no height').toBeGreaterThan(0); + expect( + canvasState.hasNonBlackPixels, + 'Canvas should have rendered non-black pixels from composited video' + ).toBe(true); + + // Verify audio is being decoded and played (gain-filtered loopback). + const audioState = await verifyAudioContextActive(page); + expect( + audioState.running, + 'Expected at least one running AudioContext for audio playback' + ).toBeGreaterThan(0); + expect(audioState.maxCurrentTime, 'AudioContext should have advanced').toBeGreaterThan(0); + + // Assert console errors before teardown. + const unexpected = collector.getUnexpected(MOQ_BENIGN_PATTERNS); + expect(unexpected, `Unexpected console errors: ${unexpected.join('; ')}`).toHaveLength(0); + collector.stop(); + + // Disconnect. + const disconnectButton = page.getByRole('button', { name: /^Disconnect$/i }).first(); + await expect(disconnectButton).toBeVisible(); + await disconnectButton.click(); + + await expect(disconnected).toBeVisible({ timeout: 10_000 }); + } else { + test.skip(true, 'MoQ WebTransport connection could not be established in this environment'); + } + + // Destroy session via UI. 
+ const destroyButton = page.getByRole('button', { + name: /Destroy Session/i, + }); + await expect(destroyButton).toBeVisible(); + await destroyButton.click(); + + const confirmModal = page.getByTestId('confirm-modal'); + await expect(confirmModal).toBeVisible(); + await confirmModal.getByRole('button', { name: /Destroy Session/i }).click(); + + await expect(page.getByRole('button', { name: /Create Session/i })).toBeVisible({ + timeout: 15_000, + }); + + sessionId = null; + }); + + test.afterEach(async ({ baseURL }) => { + if (sessionId) { + try { + const apiContext = await request.newContext({ + baseURL: baseURL!, + extraHTTPHeaders: getAuthHeaders(), + }); + await apiContext.delete(`/api/v1/sessions/${sessionId}`); + await apiContext.dispose(); + } catch { + // Best-effort cleanup; ignore errors. + } + sessionId = null; + } + }); +}); \ No newline at end of file diff --git a/e2e/tests/test-helpers.ts b/e2e/tests/test-helpers.ts index aee7fc6c..87eeb0f5 100644 --- a/e2e/tests/test-helpers.ts +++ b/e2e/tests/test-helpers.ts @@ -169,3 +169,116 @@ export async function verifyAudioContextActive(page: Page): Promise<{ }; }); } + +/** + * Run inside the browser to verify a `