
fix: normalize kv cache indices to Python ints in krea_realtime_video #682

Open

livepeer-tessa wants to merge 2 commits into main from
fix/679-krea-kv-cache-index-tensor-pollution

Conversation

@livepeer-tessa
Copy link
Contributor

Summary

Fixes #679

Two related errors were appearing during chunk processing:

  1. FX symbolic trace of dynamo-optimized function
  2. _dispatch_keys TypeError: incompatible function arguments

Root Cause

After a cache reset, initialize_kv_cache() stores torch.tensor([0], dtype=torch.long) in the global_end_index and local_end_index slots. The krea_realtime_video CausalWanSelfAttention read these values without converting to Python ints, unlike every other pipeline (longlive, memflow, streamdiffusionv2 — all use .item()).

This caused all downstream arithmetic to produce tensors:

local_end_index = kv_cache["local_end_index"] + current_end - kv_cache["global_end_index"]
# ↑ tensor on first chunk after cache reset
cache_current_block_start = cache_len - frame_seqlen * self.num_frame_per_block  
# ↑ also a tensor

When cache_current_block_start (a tensor) was captured in score_mod and passed to torch.compile(flex_attention, dynamic=False), flex_attention re-traced score_mod on every chunk because the captured tensor object identity changed. The FX tracer then hit the already-compiled flex_attention, triggering the "FX symbolic trace" error. The _dispatch_keys TypeError followed from FakeTensors colliding with real CUDA tensors during that re-trace.
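The pollution described above can be reproduced in isolation. This is a minimal sketch (the cache keys match the pipeline, but `current_end` is an illustrative value): a single-element tensor stored in the cache turns every derived index into a tensor.

```python
import torch

# State after a cache reset, as initialize_kv_cache() leaves it:
kv_cache = {
    "global_end_index": torch.tensor([0], dtype=torch.long),
    "local_end_index": torch.tensor([0], dtype=torch.long),
}
current_end = 1560  # plain Python int from the caller (illustrative)

# Tensor + int - tensor -> tensor: the arithmetic silently promotes.
local_end_index = (
    kv_cache["local_end_index"] + current_end - kv_cache["global_end_index"]
)
print(type(local_end_index))  # <class 'torch.Tensor'>, not int
```

Every value computed from `local_end_index` afterwards, including `cache_current_block_start`, inherits the tensor type, which is what ends up captured in the `score_mod` closure.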

Fix

Two changes to CausalWanSelfAttention.forward():

1. Normalize cache indices to Python ints at the start of the caching block:

cache_global_end: int = int(kv_cache["global_end_index"])
cache_local_end: int = int(kv_cache["local_end_index"])

int() safely handles both Python ints (after the first chunk) and single-element tensors (first chunk after cache reset).
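Both code paths can be checked directly; a quick sketch of the two cases `int()` has to handle:

```python
import torch

# First chunk after a cache reset: the slot holds a 1-element tensor.
assert int(torch.tensor([0], dtype=torch.long)) == 0

# Subsequent chunks: the slot already holds a Python int.
assert int(7) == 7
```

Note that `int()` on a tensor with more than one element raises, so this relies on the cache slots always holding scalars, which `initialize_kv_cache()` guarantees.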

2. Use Python scalar literals in score_mod instead of freshly-created CUDA tensors:

_fs: int = frame_seqlen
_ccbs: int = cache_current_block_start  # now guaranteed int from fix #1
_ls: float = log_scale

def score_mod(score, b_idx, h_idx, q_idx, kv_idx):
    return torch.where((kv_idx >= _fs) & (kv_idx < _ccbs), score + _ls, score)

Python scalars are baked into the compiled graph as stable constants: torch.compile captures them once, so the per-chunk object-identity churn that forced a re-trace with tensor closures no longer occurs.
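The scalar-closure `score_mod` behaves identically to the tensor version when evaluated eagerly. A self-contained sketch with illustrative values (the real `frame_seqlen`, `cache_current_block_start`, and `log_scale` come from the pipeline):

```python
import torch

_fs = 4     # frame_seqlen (illustrative)
_ccbs = 8   # cache_current_block_start (illustrative)
_ls = 0.5   # log_scale (illustrative)

def score_mod(score, b_idx, h_idx, q_idx, kv_idx):
    # Bias only the KV positions inside [_fs, _ccbs); leave the rest alone.
    return torch.where((kv_idx >= _fs) & (kv_idx < _ccbs), score + _ls, score)

kv_idx = torch.arange(10)
score = torch.zeros(10)
out = score_mod(score, 0, 0, 0, kv_idx)
# positions 4..7 receive the +0.5 bias, everything else is unchanged
```

The scalars participate in tensor arithmetic via broadcasting, so no CUDA tensor has to be materialized inside the closure at all.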

Testing

The fix can be verified by:

  1. Loading krea-realtime-video on a GPU worker
  2. Streaming continuously for several minutes
  3. Confirming absence of the FX tracing / _dispatch_keys errors in logs

livepeer-robot added 2 commits March 13, 2026 06:41
On fal.ai GPU-H100 workers torch.cuda.is_available() can return True
(CUDA runtime is installed) while actual GPU access later fails with
'No CUDA GPUs are available'.  This happens when CUDA_VISIBLE_DEVICES
is set to an empty string or an invalid MIG UUID, or when the CUDA
context has not yet been initialised and lazy init fails.

Plugin pipelines like flashvsr are disproportionately affected because
their __init__ immediately allocates CUDA tensors (model loads + warmup
pass), unlike built-in pipelines that share an already-established CUDA
context.

Changes:
- pipeline_manager: add _assert_cuda_accessible() that forces lazy CUDA
  initialisation via a test tensor allocation, reporting device_count
  and CUDA_VISIBLE_DEVICES on failure.  Called before every plugin
  pipeline load so the error surfaces early with actionable context
  rather than inside the plugin's __init__.
- fal_app: log CUDA_VISIBLE_DEVICES and NVIDIA_VISIBLE_DEVICES at
  startup so future failures can be correlated with the worker
  environment at a glance.
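The guard described above could be sketched roughly as follows. This is a hypothetical reconstruction from the commit message (the function name matches, but the actual pipeline_manager implementation may differ):

```python
import os

import torch

def _assert_cuda_accessible() -> None:
    """Force lazy CUDA initialisation with a tiny allocation.

    Raises RuntimeError with device_count and CUDA_VISIBLE_DEVICES
    context if the allocation fails, so plugin pipelines fail fast
    with actionable information instead of deep inside __init__.
    """
    try:
        torch.zeros(1, device="cuda")
    except Exception as exc:
        raise RuntimeError(
            "CUDA is not accessible "
            f"(device_count={torch.cuda.device_count()}, "
            f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')!r})"
        ) from exc
```

Allocating a real tensor, rather than checking `torch.cuda.is_available()`, is the point: it forces lazy CUDA context initialisation and surfaces the 'No CUDA GPUs are available' failure mode that `is_available()` misses.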

Fixes #675

Signed-off-by: livepeer-robot <robot@livepeer.org>
Fixes #679

After a cache reset, initialize_kv_cache() stores
torch.tensor([0], dtype=torch.long) in the global_end_index and
local_end_index slots of the KV cache. The krea_realtime_video
CausalWanSelfAttention forward pass read these values directly without
converting to Python ints, causing all subsequent arithmetic
(local_end_index, cache_current_block_start, etc.) to also produce
tensors.

When cache_current_block_start was captured in score_mod as a closure
variable and passed to torch.compile(flex_attention, dynamic=False),
flex_attention attempted to re-trace score_mod on every chunk because
the captured tensor *object* identity changed each call. The FX tracer
then hit the already-compiled flex_attention, triggering:

  'Detected that you are using FX to symbolically trace a
   dynamo-optimized function. This is not supported at the moment.'

The _dispatch_keys TypeError followed from FakeTensors (used during
trace) colliding with real CUDA tensors captured in the closure.

Fix: extract cache_global_end and cache_local_end as Python ints using
int() at the top of the caching block. int() safely handles both Python
ints (already in the cache after the first chunk) and single-element
torch.Tensors (present on the first chunk after a cache reset).

Also replace the tensor-based score_mod constants (frame_seqlen_tensor,
cache_current_block_start_tensor, log_scale_tensor) with Python scalar
literals (_fs, _ccbs, _ls) that become stable graph constants, avoiding
both the FX re-trace and the _dispatch_keys collision. Other pipelines
(longlive, memflow, streamdiffusionv2) already use .item() for cache
index reads for the same reason.

Signed-off-by: Tessa <tessa@livepeer.org>
Signed-off-by: livepeer-robot <robot@livepeer.org>
@coderabbitai
Copy link

coderabbitai bot commented Mar 13, 2026

Important

Review skipped

Auto reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: c9deaf97-0121-4ccf-b69f-072c1ba2b1e7



@github-actions
Copy link
Contributor

🚀 fal.ai Preview Deployment

App ID daydream/scope-pr-682--preview
WebSocket wss://fal.run/daydream/scope-pr-682--preview/ws
Commit a26a3d5

Testing

Connect to this preview deployment by running this on your branch:

uv run build && SCOPE_CLOUD_APP_ID="daydream/scope-pr-682--preview/ws" uv run daydream-scope

🧪 E2E tests will run automatically against this deployment.

@github-actions
Copy link
Contributor

✅ E2E Tests passed

Status passed
fal App daydream/scope-pr-682--preview
Run View logs

Test Artifacts

Check the workflow run for screenshots.



Development

Successfully merging this pull request may close these issues.

krea-realtime-video: FX symbolic trace + _dispatch_keys errors during chunk processing
