fix(qwen3-vl): per-segment mRoPE + vision under CP + THD packing#1308
fix(qwen3-vl): per-segment mRoPE + vision under CP + THD packing#1308Zhichenzzz wants to merge 9 commits into
Conversation
) Follow-up to #1272, which handled non-CP packed mRoPE and left CP as a logged dense fallback. Under context parallelism miles shards the THD row with the load-balanced zigzag layout (slice_with_cp), so the model sees only this ranks
…rg bug megatron-bridge 0.5.0 calls AllGatherVisionEmbeddings.apply(..., cp_group=...) in the Qwen3-VL vision_dp_when_cp path, but torch.autograd.Function.apply rejects keyword arguments. Install a shim (at import, alongside the mRoPE patch) whose .apply accepts cp_group as a kwarg and forwards it positionally. One of the blockers for end-to-end Qwen3-VL CP training (see #1296).
…1296) Completes the CP+packing path for Qwen3-VL. The bridge forward assumes the FULL unsharded input and re-shards internally (preprocess_packed_seqs) + inserts vision embeddings against a full mask, but miles pre-shards the THD row (slice_with_cp). That double-sharded and mismatched the vision mask vs the full vision-tower output. When the input is already CP-sharded (cu_seqlens_q[-1] == cp_size * local_len): - preprocess_packed_seqs is wrapped to be an identity that returns miles full-cu packed_seq_params, so the bridge does not re-split the already-local data while CP attention still sees the full cu_seqlens; - a per-rank vision-embed selector (select_local_vision_embeds) maps this rank local vision tokens to the matching slice of the full vision-tower output (and deepstack), via a reconstructed full row + zigzag local->full position map; - the bridge model.py gains a no-op _miles_select_local_vision_embeds hook at the vision/deepstack insertion sites that miles overrides at import. Plus calculate_per_token_loss wired into the bridge provider (CP asserts on it). Validated end-to-end: Qwen3-VL-2B geo3k, CP=2 TP=4 8xH200, THD packed, bridge mode, train_rollout_logprob_abs_diff = 0.0141 at step 0 (healthy, matches non-CP 0.011-0.016), rollout raw_reward 0.42, no crashes. The bridge-side hook is a 4-line change delivered separately as a patch (belongs upstream in megatron-bridge). Fixes #1296
There was a problem hiding this comment.
Code Review
This pull request introduces support for Qwen3-VL context parallelism (CP) and THD packed mRoPE reconstruction. It adds patches to handle CP-local vision embeddings, bypasses redundant re-sharding in preprocess_packed_seqs, and implements a shim for AllGatherVisionEmbeddings to accept cp_group as a keyword argument. Additionally, a comprehensive CPU unit test suite is added to verify the correctness of the zigzag reconstruction. The reviewer feedback suggests improving the robustness of the patches by having the _AllGatherVisionEmbeddingsKwargShim inherit from the original class to preserve its type hierarchy, and using *args and **kwargs in the preprocess_packed_seqs wrapper to guard against future signature changes.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| class _AllGatherVisionEmbeddingsKwargShim: | ||
| _miles_kwarg_shim = True | ||
|
|
||
| @staticmethod | ||
| def apply(input, seqlens_on_cp_ranks, cp_group=None): | ||
| return orig.apply(input, seqlens_on_cp_ranks, cp_group) |
There was a problem hiding this comment.
To ensure that _AllGatherVisionEmbeddingsKwargShim behaves identically to the original AllGatherVisionEmbeddings class (e.g., preserving class attributes, static methods, or satisfying issubclass / isinstance checks in downstream code), it is safer to have the shim inherit from orig instead of being a completely separate, plain class.
| class _AllGatherVisionEmbeddingsKwargShim: | |
| _miles_kwarg_shim = True | |
| @staticmethod | |
| def apply(input, seqlens_on_cp_ranks, cp_group=None): | |
| return orig.apply(input, seqlens_on_cp_ranks, cp_group) | |
| class _AllGatherVisionEmbeddingsKwargShim(orig): | |
| _miles_kwarg_shim = True | |
| @staticmethod | |
| def apply(input, seqlens_on_cp_ranks, cp_group=None): | |
| return orig.apply(input, seqlens_on_cp_ranks, cp_group) |
| def wrapped(input_ids, attention_mask, pre_process=True, pg_collection=None): | ||
| ctx = getattr(_tls, "cp_local", None) | ||
| if ctx is not None: | ||
| # already-local CP path: do not re-shard; return the data unchanged together with | ||
| # miles' full-cu packed_seq_params (callers ignore the psp; the model's CP attention | ||
| # uses the packed_seq_params passed into forward, which already has the full cu). | ||
| return input_ids, ctx["psp"] | ||
| return orig(input_ids, attention_mask, pre_process=pre_process, pg_collection=pg_collection) |
There was a problem hiding this comment.
To make the monkeypatched wrapped function more robust against future signature changes in preprocess_packed_seqs (e.g., if the bridge library adds or reorders arguments), it is highly recommended to use *args and **kwargs and extract input_ids dynamically. This prevents potential TypeError exceptions due to signature mismatches.
| def wrapped(input_ids, attention_mask, pre_process=True, pg_collection=None): | |
| ctx = getattr(_tls, "cp_local", None) | |
| if ctx is not None: | |
| # already-local CP path: do not re-shard; return the data unchanged together with | |
| # miles' full-cu packed_seq_params (callers ignore the psp; the model's CP attention | |
| # uses the packed_seq_params passed into forward, which already has the full cu). | |
| return input_ids, ctx["psp"] | |
| return orig(input_ids, attention_mask, pre_process=pre_process, pg_collection=pg_collection) | |
| def wrapped(*args, **kwargs): | |
| ctx = getattr(_tls, "cp_local", None) | |
| if ctx is not None: | |
| # already-local CP path: do not re-shard; return the data unchanged together with | |
| # miles' full-cu packed_seq_params (callers ignore the psp; the model's CP attention | |
| # uses the packed_seq_params passed into forward, which already has the full cu). | |
| input_ids = kwargs.get("input_ids") if "input_ids" in kwargs else args[0] | |
| return input_ids, ctx["psp"] | |
| return orig(*args, **kwargs) |
- rename the free helpers' first parameter self -> model (they are not methods; the model instance is passed in from the patched forward) - forward *args/**kwargs in the preprocess_packed_seqs identity wrapper instead of hard-coding the upstream signature - warn once at patch install when the running megatron-bridge lacks the _miles_select_local_vision_embeds hook, instead of silently mis-placing vision embeddings under CP (points at the matching Megatron-Bridge patch)
Megatron-Bridge PR #9 renamed its extension point to the vendor-neutral select_cp_local_vision_embeds; look that up first and keep the old _miles_-prefixed name as a fallback for older patched bridges.
Megatron-Bridge now selects CP-local vision embeddings natively inside Qwen3VLModel (no override hook), so remove select_local_vision_embeds, the local->full position mapping it needed, and the hook installation; keep a warning when the running bridge lacks the native support. The cp_local context now only carries packed_seq_params for the preprocess_packed_seqs identity wrapper.
Follow-up to #1272 (stacked on its branch, so this diff shows only the #1296 changes — merge after #1272).
What
Makes Qwen3-VL train end-to-end under context parallelism + THD sequence packing in bridge mode.
_build_packed_positionsall-gathers the per-rank rows, de-interleaves to the full row (_reassemble_full_row, unit-tested intests/fast/test_qwen3_vl_cp_mrope.py), rebuilds per-segment MRoPE, and re-slices into this rank's zigzag layout.preprocess_packed_seqsis wrapped to an identity that returns miles' full-cupacked_seq_params(CP attention still sees the full cu; the data isn't re-split).select_local_vision_embedsmaps each rank's local vision tokens to the matching slice of the full vision-tower output (and deepstack). Cooperates with a small hook in megatron-bridge (separate PR to radixark/Megatron-Bridge).calculate_per_token_lossinto the bridge provider (Qwen3-VL asserts on it under CP) and a defensiveAllGatherVisionEmbeddings.applykwarg shim.Validation (e2e)
Qwen3-VL-2B geo3k, CP=2 TP=4, 8×H200, THD packed, bridge mode: stable over 3 steps,
train_rollout_logprob_abs_diff0.0141 → 0.0146 (healthy, == non-CP 0.011–0.016),rollout/raw_reward~0.4, no crashes.Depends on
Fixes #1296
Update: cleanup pass + re-validation
self→model(they are free functions, not methods)preprocess_packed_seqsidentity wrapper forwards*args/**kwargsinstead of hard-coding the upstream signature_miles_select_local_vision_embedshook (instead of silently mis-placing vision embeddings under CP); points at the matching Megatron-Bridge patchRe-validated after cleanup: unit tests 6/6; Qwen3-VL-2B CP2 TP4 THD geo3k RL —
train_rollout_logprob_abs_diff0.0127–0.0131 (same healthy band as the original validation), coherent rollouts, no hook warning with the bridge patch installed.Note on stacking: this branch contains #1272; the cleanup lives here because this PR owns the final shape of
qwen3_vl_packed_mrope.py.Update: vision-embed plumbing removed (−53 lines)
Megatron-Bridge PR #9 now selects CP-local vision embeddings natively inside
Qwen3VLModel(no override hook), so this PR no longer carriesselect_local_vision_embeds, the local→full position mapping, or any hook installation. What remains on the miles side: per-segment mRoPE position reconstruction and thepreprocess_packed_seqsidentity wrapper (both still needed — they are about positions, not embeddings), plus a warning when the running bridge lacks the native support.Re-validated end-to-end after the removal (Qwen3-VL-2B, CP2 TP2, THD, geo3k RL):
train_rollout_logprob_abs_diff0.0130, coherent rollouts, no warnings, no crashes.