
decouple the transformer and decoder #940

Merged
helloyongyang merged 4 commits into ModelTC:main from zhtshr:zht_dev
Mar 12, 2026

Conversation

@zhtshr
Contributor

zhtshr commented Mar 11, 2026

No description provided.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant architectural change by decoupling the transformer and decoder into separate services. The data processing pipeline now operates in two distinct phases, managed by an enhanced DataManager. Phase 1 handles data transfer from the EncoderService to the TransformerService, while Phase 2 manages the transfer of processed latents from the TransformerService to the newly introduced DecoderService for final video generation. This modularization improves the system's flexibility and scalability for distributed inference.

Highlights

  • Architectural Decoupling: The transformer and decoder components have been decoupled into distinct services, introducing a two-phase disaggregation model for data processing.
  • New Disaggregation Phases and Modes: A DisaggregationPhase enum was added to manage distinct processing stages (Phase1, Phase2), and the DisaggregationMode enum was extended with a DECODE option.
  • DataManager Refactoring: The DataManager class was refactored around DisaggregationPhase, enabling phase-specific thread management for encoding, transforming, and decoding; existing thread-starting methods were renamed, and new ones were added for Phase 2.
  • Introduction of DecoderService: A new DecoderService class was implemented to specifically handle VAE decoding and video generation, offloading this responsibility from the TransformerService.
  • Service Updates: Both EncoderService and TransformerService were updated to align with the new two-phase architecture, including changes to buffer allocation, data transfer logic, and metadata handling.
  • Protocol Streamlining: protocol.py was adjusted to simplify data allocation requests by replacing a generic config with specific buffer_sizes and removing session_id from RemoteBuffer.
  • Buffer Size Estimation: A new utility function, estimate_transformer_buffer_sizes, was added to utils.py for pre-calculating buffer requirements for the transformer's output.
  • Example Script Updates: Example service scripts (wan_i2v_service.py, wan_t2v_service.py) were updated to orchestrate the new three-service (Encoder, Transformer, Decoder) workflow.
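The enum additions described in the highlights can be sketched roughly as follows. This is a minimal sketch reconstructed from the changelog; the member values, ordering, and docstrings are assumptions, not taken from the diff:

```python
from enum import Enum, auto


class DisaggregationPhase(Enum):
    """Which leg of the disaggregated pipeline a DataManager serves."""
    NULL = auto()
    PHASE1 = auto()  # encoder -> transformer transfer
    PHASE2 = auto()  # transformer -> decoder transfer


class DisaggregationMode(Enum):
    """Role of the local service within a phase."""
    ENCODE = auto()
    TRANSFORMER = auto()
    DECODE = auto()  # added by this PR for the new DecoderService
```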


Changelog
  • lightx2v/disagg/conn.py
    • Added DisaggregationPhase enum with NULL, PHASE1, and PHASE2 states.
    • Extended DisaggregationMode enum with a DECODE state.
    • Modified DataManager's __init__ to accept disaggregation_phase and disaggregation_mode, and to conditionally start threads based on these phases and modes.
    • Renamed start_encode_thread to start_phase1_encode_thread.
    • Renamed start_transformer_thread to start_phase1_transformer_thread.
    • Introduced start_phase2_transformer_thread and start_phase2_decode_thread methods.
    • Updated enqueue_request and check_status logic to consider disaggregation_phase.
  • lightx2v/disagg/examples/wan_i2v_service.py
    • Imported threading and DecoderService.
    • Added decoder_engine_rank to the configuration.
    • Defined run_decoder function to initialize and process the DecoderService.
    • Integrated decoder_thread into the service startup and join sequence.
  • lightx2v/disagg/examples/wan_t2v_service.py
    • Imported threading and DecoderService.
    • Updated model and save paths to /root/zht/LightX2V/.
    • Added decoder_engine_rank to the configuration.
    • Defined run_decoder function to initialize and process the DecoderService.
    • Integrated decoder_thread into the service startup and join sequence.
  • lightx2v/disagg/mooncake.py
    • Removed logger.error calls before raising RuntimeError in register and deregister methods.
  • lightx2v/disagg/protocol.py
    • Modified AllocationRequest to use buffer_sizes (List[int]) instead of a generic config (Dict[str, Any]).
    • Removed session_id from the RemoteBuffer dataclass.
  • lightx2v/disagg/services/decoder.py
    • Added new file decoder.py implementing DecoderService.
    • DecoderService initializes DataManager for Phase 2 (DECODE mode) and DataReceiver.
    • Loads VAE decoder models.
    • Allocates memory for receiving latents and metadata.
    • Processes received latents, decodes them using VAE, and saves the resulting video.
    • Includes utility functions for buffer viewing and SHA256 hashing.
    • Implements release_memory to deregister RDMA buffers.
  • lightx2v/disagg/services/encoder.py
    • Imported DisaggregationPhase, AllocationRequest, and RemoteBuffer.
    • Added transformer_engine_rank and decoder_engine_rank to __init__.
    • Refactored __init__ to use AllocationRequest and DisaggregationPhase.PHASE1 for DataManager initialization.
    • Renamed alloc_bufs to alloc_memory and updated its signature to accept AllocationRequest and return MemoryHandle.
    • Modified process method to use _buffer_view more efficiently and include dtype and hash in metadata.
    • Removed hardcoded buffer size calculations within process and relied on _buffer_view and _sha256_tensor.
  • lightx2v/disagg/services/transformer.py
    • Imported DataSender, DisaggregationPhase, and removed F (functional), load_wan_vae_decoder, save_to_video, wan_vae_to_comfy.
    • Removed vae_decoder attribute.
    • Introduced rdma_buffer1 and rdma_buffer2 for two-phase buffer management.
    • Added encoder_engine_rank, transformer_engine_rank, decoder_engine_rank to __init__.
    • Refactored __init__ to initialize two DataManager instances: data_mgr1 (for receiving from encoder, Phase 1) and data_mgr2 (for sending to decoder, Phase 2).
    • Modified alloc_memory to accept DisaggregationPhase and manage buffers for both phases.
    • Removed _align_vae_to_latents utility function.
    • Updated process to receive data from data_receiver (Phase 1), perform transformer inference, and then send latents and metadata to the DecoderService via data_sender (Phase 2).
    • Removed VAE decoding and video saving logic from this service.
    • Updated release_memory to deregister buffers from both data_mgr1 and data_mgr2.
  • lightx2v/disagg/utils.py
    • Added estimate_transformer_buffer_sizes function to calculate buffer requirements for transformer outputs (latents and metadata).
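As an illustration of the new utility, a buffer-size estimator along these lines would pre-compute the two RDMA allocations (latents plus metadata). The signature, the 2-byte bf16 element size, and the fixed metadata budget are all assumptions for this sketch, not the PR's actual code:

```python
import math


def estimate_transformer_buffer_sizes(latent_shape, dtype_bytes=2, metadata_bytes=4096):
    # Hypothetical sketch: one buffer sized for the latents tensor
    # (e.g. a (C, T, H, W) shape at 2 bytes/element for bf16) and one
    # fixed-size buffer for the serialized metadata blob.
    latents_bytes = math.prod(latent_shape) * dtype_bytes
    return [latents_bytes, metadata_bytes]
```

A list of sizes in this shape is presumably what the new buffer_sizes field of AllocationRequest carries, per the protocol.py change above.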
Activity
  • The pull request introduces a significant architectural change by decoupling the transformer and decoder.
  • New DisaggregationPhase and DECODE mode were added to manage the new two-phase data flow.
  • A new DecoderService was implemented to handle VAE decoding and video saving.
  • Existing DataManager, EncoderService, and TransformerService were refactored to support this new phased approach, including changes to buffer allocation and data transfer logic.
  • Example scripts were updated to reflect the new three-service orchestration.

@gemini-code-assist bot left a comment

Code Review

This pull request decouples the transformer and decoder stages through a two-phase data transfer model, a significant architectural improvement. However, a critical vulnerability was identified: the DataManager neither authenticates nor properly validates remote endpoints and memory addresses, opening the door to Server-Side Request Forgery (SSRF) and unauthorized remote memory corruption via RDMA. The review also flags significant code duplication within the DataManager class and potential race conditions in the thread startup sequences of the example service files.

Comment on lines +249 to +294
```python
                ) = self.server_socket.recv_multipart()
                if bootstrap_room.decode("ascii") == "None":
                    continue
                endpoint = endpoint.decode("ascii")
                mooncake_session_id = mooncake_session_id.decode("ascii")
                bootstrap_room = int(bootstrap_room.decode("ascii"))
                decode_ptrs = list(struct.unpack(f"{len(decode_ptrs)//8}Q", decode_ptrs))
                logger.info(
                    "Transformer received ZMQ: endpoint=%s session_id=%s room=%s decode_ptrs=%s",
                    endpoint,
                    mooncake_session_id,
                    bootstrap_room,
                    decode_ptrs,
                )
                self.waiting_pool[bootstrap_room] = (
                    endpoint,
                    mooncake_session_id,
                    decode_ptrs,
                )
                self.transfer_event.set()

        threading.Thread(target=transformer_thread).start()

        def transfer_thread():
            while True:
                self.transfer_event.wait()
                self.transfer_event.clear()
                bootstrap_room_ready = self.request_pool.keys()
                bootstrap_room_request = self.waiting_pool.keys()
                for room in list(bootstrap_room_request):
                    if room not in list(bootstrap_room_ready):
                        continue
                    status = DataPoll.Transferring
                    self.request_status[room] = status
                    (
                        endpoint,
                        mooncake_session_id,
                        decode_ptrs,
                    ) = self.waiting_pool.pop(room)
                    self.sync_status_to_transformer_endpoint(endpoint, room)
                    transformer_data_ptrs = self.request_pool.pop(room)
                    ret = self.send_data(
                        mooncake_session_id,
                        transformer_data_ptrs,
                        decode_ptrs,
                    )
```

Severity: high (security)

The DataManager class receives a mooncake_session_id and remote memory addresses (decode_ptrs) from an unauthenticated ZMQ PULL socket (line 249). These values are passed directly to the Mooncake transfer engine's transfer_sync_write method via send_data (lines 290-294). This allows an unauthenticated attacker to trigger RDMA writes to arbitrary memory locations on any node whose Mooncake session ID they can guess or obtain. This could lead to unauthorized data modification or memory corruption on other nodes in the cluster.
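One way to close this gap, sketched here with hypothetical names and a placeholder key (nothing below is from the PR), is to require an HMAC tag over the multipart frames so the receiver can reject unauthenticated messages before acting on the session ID or pointers:

```python
import hashlib
import hmac

# Assumption: a shared secret provisioned out-of-band to cluster nodes.
SHARED_KEY = b"replace-with-provisioned-secret"


def sign_frames(*frames: bytes) -> bytes:
    # Sender appends this tag as a final ZMQ frame.
    return hmac.new(SHARED_KEY, b"".join(frames), hashlib.sha256).digest()


def verify_frames(tag: bytes, *frames: bytes) -> bool:
    # Receiver recomputes the tag and compares in constant time before
    # trusting endpoint, mooncake_session_id, or decode_ptrs.
    return hmac.compare_digest(tag, sign_frames(*frames))
```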

Comment on lines +249 to +288
```python
                ) = self.server_socket.recv_multipart()
                if bootstrap_room.decode("ascii") == "None":
                    continue
                endpoint = endpoint.decode("ascii")
                mooncake_session_id = mooncake_session_id.decode("ascii")
                bootstrap_room = int(bootstrap_room.decode("ascii"))
                decode_ptrs = list(struct.unpack(f"{len(decode_ptrs)//8}Q", decode_ptrs))
                logger.info(
                    "Transformer received ZMQ: endpoint=%s session_id=%s room=%s decode_ptrs=%s",
                    endpoint,
                    mooncake_session_id,
                    bootstrap_room,
                    decode_ptrs,
                )
                self.waiting_pool[bootstrap_room] = (
                    endpoint,
                    mooncake_session_id,
                    decode_ptrs,
                )
                self.transfer_event.set()

        threading.Thread(target=transformer_thread).start()

        def transfer_thread():
            while True:
                self.transfer_event.wait()
                self.transfer_event.clear()
                bootstrap_room_ready = self.request_pool.keys()
                bootstrap_room_request = self.waiting_pool.keys()
                for room in list(bootstrap_room_request):
                    if room not in list(bootstrap_room_ready):
                        continue
                    status = DataPoll.Transferring
                    self.request_status[room] = status
                    (
                        endpoint,
                        mooncake_session_id,
                        decode_ptrs,
                    ) = self.waiting_pool.pop(room)
                    self.sync_status_to_transformer_endpoint(endpoint, room)
```

Severity: medium (security)

The DataManager class is vulnerable to Server-Side Request Forgery (SSRF) because it uses an unauthenticated endpoint string from a ZMQ PULL socket (line 249) to initiate connections to arbitrary hosts in sync_status_to_transformer_endpoint (line 288). This allows attackers to probe internal networks or exhaust resources via @cache. This vulnerability is present within duplicated thread starter methods like start_phase2_transformer_thread and start_phase1_encode_thread. Consolidating these into generic private methods, such as _start_sender_thread and _start_receiver_thread, would centralize the logic for implementing crucial authentication and validation for the endpoint.
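Independent of that consolidation, the endpoint itself can be checked against a provisioned allowlist before any connection is made back to it. The host set and helper name below are illustrative only, not part of the PR:

```python
from urllib.parse import urlparse

# Assumption: the set of legitimate cluster hosts is known at deploy time.
ALLOWED_HOSTS = {"10.0.0.11", "10.0.0.12"}


def validate_endpoint(endpoint: str) -> str:
    # Reject endpoints outside the cluster before connecting back to them,
    # cutting off the SSRF path through sync_status_to_transformer_endpoint.
    parsed = urlparse(endpoint if "://" in endpoint else f"tcp://{endpoint}")
    if parsed.hostname not in ALLOWED_HOSTS:
        raise ValueError(f"untrusted endpoint: {endpoint}")
    return endpoint
```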

Comment on lines +92 to +111
```python
        if self.disaggregation_phase == DisaggregationPhase.PHASE1:
            if self.disaggregation_mode == DisaggregationMode.ENCODE:
                self.waiting_pool: WaitingPoolType = {}
                self.transfer_event = threading.Event()
                self.start_phase1_encode_thread()
            elif self.disaggregation_mode == DisaggregationMode.TRANSFORMER:
                self.start_phase1_transformer_thread()
            else:
                raise ValueError(f"Unsupported DisaggregationMode in this phase: {self.disaggregation_phase}, {self.disaggregation_mode}")
        elif self.disaggregation_phase == DisaggregationPhase.PHASE2:
            if self.disaggregation_mode == DisaggregationMode.TRANSFORMER:
                self.waiting_pool: WaitingPoolType = {}
                self.transfer_event = threading.Event()
                self.start_phase2_transformer_thread()  # TODO: start_p2_transformer_thread
            elif self.disaggregation_mode == DisaggregationMode.DECODE:
                self.start_phase2_decode_thread()  # TODO: start_p2_decode_thread
            else:
                raise ValueError(f"Unsupported DisaggregationMode in this phase: {self.disaggregation_phase}, {self.disaggregation_mode}")
        else:
            raise ValueError(f"Unsupported DisaggregationPhase: {self.disaggregation_phase}")
```

Severity: medium

The initialization of self.waiting_pool and self.transfer_event is duplicated for sender modes (PHASE1/ENCODE and PHASE2/TRANSFORMER). This could be consolidated by first determining if the current configuration is a 'sender' and then performing the initialization once to improve maintainability and reduce code repetition.
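The duplicated setup could be hoisted behind a single role check, for example as below. This is a sketch with string stand-ins for the PR's enums, and the helper name is invented:

```python
import threading

# Stand-ins for the (phase, mode) pairs that act as RDMA senders.
SENDER_ROLES = {("PHASE1", "ENCODE"), ("PHASE2", "TRANSFORMER")}


def init_role_state(phase: str, mode: str) -> dict:
    # Create the waiting pool and transfer event exactly once for any
    # sender role, instead of repeating the setup in each branch.
    state = {}
    if (phase, mode) in SENDER_ROLES:
        state["waiting_pool"] = {}
        state["transfer_event"] = threading.Event()
    return state
```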

@helloyongyang helloyongyang merged commit b74f415 into ModelTC:main Mar 12, 2026
1 check passed