fix(qwen3-vl): per-segment mRoPE + vision under CP + THD packing #1308

Open
Zhichenzzz wants to merge 9 commits into zhichen/qwen3-vl-thd-miles-hijack from fix/1296-qwen3vl-cp-mrope

@Zhichenzzz Zhichenzzz commented Jun 8, 2026

Contributor

Follow-up to #1272 (stacked on its branch, so this diff shows only the #1296 changes — merge after #1272).

What

Makes Qwen3-VL train end-to-end under context parallelism + THD sequence packing in bridge mode.

  1. Per-segment mRoPE under CP: when the THD row is CP-sharded (zigzag), _build_packed_positions all-gathers the per-rank rows, de-interleaves them into the full row (_reassemble_full_row, unit-tested in tests/fast/test_qwen3_vl_cp_mrope.py), rebuilds per-segment mRoPE, and re-slices into this rank's zigzag layout.
  2. Don't double-shard: when the input is already CP-local, the bridge's internal preprocess_packed_seqs is wrapped into an identity that returns miles' full-cu packed_seq_params (CP attention still sees the full cu; the data isn't re-split).
  3. CP-local vision embeds: select_local_vision_embeds maps each rank's local vision tokens to the matching slice of the full vision-tower output (and deepstack). Cooperates with a small hook in megatron-bridge (separate PR to radixark/Megatron-Bridge).
  4. Wires calculate_per_token_loss into the bridge provider (Qwen3-VL asserts on it under CP) and adds a defensive AllGatherVisionEmbeddings.apply kwarg shim.
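
The zigzag round trip in item 1 can be illustrated with a minimal sketch. This is not the PR's code: it assumes the common load-balanced scheme in which every packed segment is cut into 2 * cp_size equal chunks and rank r keeps chunk r plus chunk 2 * cp_size - 1 - r. Plain Python lists stand in for tensors and for the all-gather, and zigzag_shard and reassemble_full_row are hypothetical names (the PR's own helpers are slice_with_cp and _reassemble_full_row, whose exact behavior is not shown on this page).

```python
from typing import List, Sequence


def zigzag_shard(row: Sequence[int], cu_seqlens: Sequence[int],
                 cp_size: int, rank: int) -> List[int]:
    """Return one rank's slice of a packed row, zigzag-sharded per segment.

    Assumption: every segment length is divisible by 2 * cp_size.
    """
    local: List[int] = []
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        seg = row[start:end]
        chunk = len(seg) // (2 * cp_size)
        mirror = 2 * cp_size - 1 - rank  # index of the matching back chunk
        local.extend(seg[rank * chunk:(rank + 1) * chunk])
        local.extend(seg[mirror * chunk:(mirror + 1) * chunk])
    return local


def reassemble_full_row(per_rank_rows: Sequence[Sequence[int]],
                        cu_seqlens: Sequence[int], cp_size: int) -> List[int]:
    """Invert zigzag_shard given every rank's local row (as after an all-gather)."""
    full: List[int] = []
    cursors = [0] * cp_size  # read position inside each rank's local row
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        chunk = (end - start) // (2 * cp_size)
        chunks: List[Sequence[int]] = [()] * (2 * cp_size)
        for rank in range(cp_size):
            pos = cursors[rank]
            chunks[rank] = per_rank_rows[rank][pos:pos + chunk]
            chunks[2 * cp_size - 1 - rank] = per_rank_rows[rank][pos + chunk:pos + 2 * chunk]
            cursors[rank] = pos + 2 * chunk
        for piece in chunks:
            full.extend(piece)
    return full


# Two packed segments of length 8 and 16, sharded over cp_size = 2.
row = list(range(24))
cu_seqlens = [0, 8, 24]
shards = [zigzag_shard(row, cu_seqlens, cp_size=2, rank=r) for r in range(2)]
restored = reassemble_full_row(shards, cu_seqlens, cp_size=2)
```

Each rank ends up with one chunk from the front and one from the back of every segment, which is what balances causal-attention work across ranks; once the full row is restored, per-segment positions can be rebuilt on it and sliced again with the same layout.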

Validation (e2e)

Qwen3-VL-2B geo3k, CP=2 TP=4, 8×H200, THD packed, bridge mode: stable over 3 steps, train_rollout_logprob_abs_diff 0.0141 → 0.0146 (healthy, == non-CP 0.011–0.016), rollout/raw_reward ~0.4, no crashes.

Depends on

Fixes #1296


Update: cleanup pass + re-validation

  • helpers' first parameter renamed self → model (they are free functions, not methods)
  • the preprocess_packed_seqs identity wrapper forwards *args/**kwargs instead of hard-coding the upstream signature
  • the patch warns once at install when the running megatron-bridge lacks the _miles_select_local_vision_embeds hook (instead of silently mis-placing vision embeddings under CP); points at the matching Megatron-Bridge patch
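
A warn-once check of the kind described in the last bullet might look like the sketch below. It is an illustration under assumptions, not the PR's implementation: the logger name, check_bridge_hook, and the two stand-in classes are hypothetical, and looking the hook up as a class attribute is a guess about where the bridge exposes it.

```python
import logging

logger = logging.getLogger("qwen3_vl_patch_sketch")  # hypothetical logger name
_warned_hooks = set()  # hook names already reported in this process


def check_bridge_hook(model_cls, hook_name="_miles_select_local_vision_embeds"):
    """Return True if the hook exists; otherwise log a single warning."""
    if hasattr(model_cls, hook_name):
        return True
    if hook_name not in _warned_hooks:
        logger.warning(
            "megatron-bridge lacks %s; vision embeddings may be misplaced under CP",
            hook_name,
        )
        _warned_hooks.add(hook_name)
    return False


class PatchedBridgeModel:
    """Stand-in for a bridge model that carries the hook."""

    @staticmethod
    def _miles_select_local_vision_embeds(embeds):
        return embeds


class UnpatchedBridgeModel:
    """Stand-in for a bridge model without the hook."""
```

Calling check_bridge_hook(UnpatchedBridgeModel) repeatedly logs only the first time, so a missing bridge patch is visible at install without flooding the training log.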

Re-validated after cleanup: unit tests 6/6; Qwen3-VL-2B CP2 TP4 THD geo3k RL — train_rollout_logprob_abs_diff 0.0127–0.0131 (same healthy band as the original validation), coherent rollouts, no hook warning with the bridge patch installed.

Note on stacking: this branch contains #1272; the cleanup lives here because this PR owns the final shape of qwen3_vl_packed_mrope.py.


Update: vision-embed plumbing removed (−53 lines)

Megatron-Bridge PR #9 now selects CP-local vision embeddings natively inside Qwen3VLModel (no override hook), so this PR no longer carries select_local_vision_embeds, the local→full position mapping, or any hook installation. What remains on the miles side: per-segment mRoPE position reconstruction and the preprocess_packed_seqs identity wrapper (both still needed — they are about positions, not embeddings), plus a warning when the running bridge lacks the native support.
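
To make "per-segment position reconstruction" concrete, here is a deliberately simplified sketch for a text-only packed row. It assumes that positions restart at 0 in every packed segment and that, for pure-text tokens, the three mRoPE axes (temporal, height, width) share the same index; vision tokens need grid-dependent positions that this sketch does not cover. per_segment_text_positions is a hypothetical name, not a function from miles or megatron-bridge.

```python
from typing import List, Sequence


def per_segment_text_positions(cu_seqlens: Sequence[int]) -> List[List[int]]:
    """Build 3-axis position ids for a packed, text-only row.

    cu_seqlens holds cumulative segment boundaries, e.g. [0, 3, 5] for two
    segments of length 3 and 2.
    """
    positions: List[int] = []
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        # Each segment is an independent sample, so its positions start at 0.
        positions.extend(range(end - start))
    # Text-only simplification: the same index on all three axes.
    return [list(positions) for _ in range(3)]


pos = per_segment_text_positions([0, 3, 5])
```

Under CP the same construction has to run on the full row first, because a rank that only holds two chunks of a segment cannot know where inside the segment those chunks start.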

Re-validated end-to-end after the removal (Qwen3-VL-2B, CP2 TP2, THD, geo3k RL): train_rollout_logprob_abs_diff 0.0130, coherent rollouts, no warnings, no crashes.


Follow-up to #1272, which handled non-CP packed mRoPE and left CP as a logged
dense fallback. Under context parallelism miles shards the THD row with the
load-balanced zigzag layout (slice_with_cp), so the model sees only this rank's
…rg bug

megatron-bridge 0.5.0 calls AllGatherVisionEmbeddings.apply(..., cp_group=...) in the
Qwen3-VL vision_dp_when_cp path, but torch.autograd.Function.apply rejects keyword
arguments. Install a shim (at import, alongside the mRoPE patch) whose .apply accepts
cp_group as a kwarg and forwards it positionally. One of the blockers for end-to-end
Qwen3-VL CP training (see #1296).
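
The shim pattern from this commit message can be sketched without torch. PositionalOnlyOp is a hypothetical stand-in for an op whose apply takes positional arguments only, and make_kwarg_shim is an illustrative helper; only the marker attribute and the apply signature mirror the snippet under review.

```python
class PositionalOnlyOp:
    """Stand-in for an op whose apply() rejects keyword arguments."""

    @staticmethod
    def apply(*args):
        tensor, seqlens_on_cp_ranks, cp_group = args
        return tensor, seqlens_on_cp_ranks, cp_group


def make_kwarg_shim(orig):
    """Build a class whose apply() accepts cp_group as a keyword."""

    class _KwargShim:
        _miles_kwarg_shim = True  # marker attribute, as in the reviewed snippet

        @staticmethod
        def apply(input, seqlens_on_cp_ranks, cp_group=None):
            # Forward everything positionally so the original never sees kwargs.
            return orig.apply(input, seqlens_on_cp_ranks, cp_group)

    return _KwargShim


shim = make_kwarg_shim(PositionalOnlyOp)
result = shim.apply("embeds", [4, 4], cp_group="cp0")

try:
    PositionalOnlyOp.apply("embeds", [4, 4], cp_group="cp0")
    direct_call_failed = False
except TypeError:
    direct_call_failed = True
```

The call through the shim succeeds, while the same keyword call made directly against the stand-in raises TypeError, which is the failure mode the commit message describes.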
…1296)

Completes the CP+packing path for Qwen3-VL. The bridge forward assumes the FULL
unsharded input and re-shards internally (preprocess_packed_seqs) + inserts vision
embeddings against a full mask, but miles pre-shards the THD row (slice_with_cp).
That double-sharded the data and mismatched the vision mask against the full vision-tower output.

When the input is already CP-sharded (cu_seqlens_q[-1] == cp_size * local_len):
- preprocess_packed_seqs is wrapped to be an identity that returns miles' full-cu
  packed_seq_params, so the bridge does not re-split the already-local data while
  CP attention still sees the full cu_seqlens;
- a per-rank vision-embed selector (select_local_vision_embeds) maps this rank's local
  vision tokens to the matching slice of the full vision-tower output (and deepstack),
  via a reconstructed full row + zigzag local->full position map;
- the bridge model.py gains a no-op _miles_select_local_vision_embeds hook at the
  vision/deepstack insertion sites that miles overrides at import.
Plus calculate_per_token_loss wired into the bridge provider (CP asserts on it).
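
The detection condition quoted above, cu_seqlens_q[-1] == cp_size * local_len, can be written as a small predicate. This is a sketch: is_cp_local_input is a hypothetical name, and the cp_size > 1 guard is an added assumption so that the non-CP case (where the equality holds trivially) is not treated as pre-sharded.

```python
from typing import Sequence


def is_cp_local_input(cu_seqlens_q: Sequence[int], local_len: int, cp_size: int) -> bool:
    """True when the row handed to the model is already one CP rank's shard.

    cu_seqlens_q[-1] is the total token count of the full packed row, while
    local_len is the number of tokens this rank actually received.
    """
    full_len = cu_seqlens_q[-1]
    return cp_size > 1 and full_len == cp_size * local_len


# Full packed row of 24 tokens, CP=2: a 12-token input is already sharded.
already_sharded = is_cp_local_input([0, 8, 24], local_len=12, cp_size=2)
# The same row passed whole is not.
full_input = is_cp_local_input([0, 8, 24], local_len=24, cp_size=2)
```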

Validated end-to-end: Qwen3-VL-2B geo3k, CP=2 TP=4 8xH200, THD packed, bridge mode,
train_rollout_logprob_abs_diff = 0.0141 at step 0 (healthy, matches non-CP 0.011-0.016),
rollout raw_reward 0.42, no crashes. The bridge-side hook is a 4-line change delivered
separately as a patch (belongs upstream in megatron-bridge).

Fixes #1296

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor


Code Review

This pull request introduces support for Qwen3-VL context parallelism (CP) and THD packed mRoPE reconstruction. It adds patches to handle CP-local vision embeddings, bypasses redundant re-sharding in preprocess_packed_seqs, and implements a shim for AllGatherVisionEmbeddings to accept cp_group as a keyword argument. Additionally, a comprehensive CPU unit test suite is added to verify the correctness of the zigzag reconstruction. The reviewer feedback suggests improving the robustness of the patches by having the _AllGatherVisionEmbeddingsKwargShim inherit from the original class to preserve its type hierarchy, and using *args and **kwargs in the preprocess_packed_seqs wrapper to guard against future signature changes.


Comment on lines +43 to +48
class _AllGatherVisionEmbeddingsKwargShim:
    _miles_kwarg_shim = True

    @staticmethod
    def apply(input, seqlens_on_cp_ranks, cp_group=None):
        return orig.apply(input, seqlens_on_cp_ranks, cp_group)


medium

To ensure that _AllGatherVisionEmbeddingsKwargShim behaves identically to the original AllGatherVisionEmbeddings class (e.g., preserving class attributes, static methods, or satisfying issubclass / isinstance checks in downstream code), it is safer to have the shim inherit from orig instead of being a completely separate, plain class.

Suggested change
- class _AllGatherVisionEmbeddingsKwargShim:
-     _miles_kwarg_shim = True
-
-     @staticmethod
-     def apply(input, seqlens_on_cp_ranks, cp_group=None):
-         return orig.apply(input, seqlens_on_cp_ranks, cp_group)
+ class _AllGatherVisionEmbeddingsKwargShim(orig):
+     _miles_kwarg_shim = True
+
+     @staticmethod
+     def apply(input, seqlens_on_cp_ranks, cp_group=None):
+         return orig.apply(input, seqlens_on_cp_ranks, cp_group)

Comment on lines +132 to +139
def wrapped(input_ids, attention_mask, pre_process=True, pg_collection=None):
    ctx = getattr(_tls, "cp_local", None)
    if ctx is not None:
        # already-local CP path: do not re-shard; return the data unchanged together with
        # miles' full-cu packed_seq_params (callers ignore the psp; the model's CP attention
        # uses the packed_seq_params passed into forward, which already has the full cu).
        return input_ids, ctx["psp"]
    return orig(input_ids, attention_mask, pre_process=pre_process, pg_collection=pg_collection)


medium

To make the monkeypatched wrapped function more robust against future signature changes in preprocess_packed_seqs (e.g., if the bridge library adds or reorders arguments), it is highly recommended to use *args and **kwargs and extract input_ids dynamically. This prevents potential TypeError exceptions due to signature mismatches.

Suggested change
- def wrapped(input_ids, attention_mask, pre_process=True, pg_collection=None):
-     ctx = getattr(_tls, "cp_local", None)
-     if ctx is not None:
-         # already-local CP path: do not re-shard; return the data unchanged together with
-         # miles' full-cu packed_seq_params (callers ignore the psp; the model's CP attention
-         # uses the packed_seq_params passed into forward, which already has the full cu).
-         return input_ids, ctx["psp"]
-     return orig(input_ids, attention_mask, pre_process=pre_process, pg_collection=pg_collection)
+ def wrapped(*args, **kwargs):
+     ctx = getattr(_tls, "cp_local", None)
+     if ctx is not None:
+         # already-local CP path: do not re-shard; return the data unchanged together with
+         # miles' full-cu packed_seq_params (callers ignore the psp; the model's CP attention
+         # uses the packed_seq_params passed into forward, which already has the full cu).
+         input_ids = kwargs.get("input_ids") if "input_ids" in kwargs else args[0]
+         return input_ids, ctx["psp"]
+     return orig(*args, **kwargs)
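
How such a wrapper behaves with and without the thread-local context can be exercised in isolation. The sketch below reuses the names _tls, cp_local and psp from the snippet; wrap_preprocess, the cp_local_context context manager, and fake_preprocess (a stand-in for the bridge's preprocess_packed_seqs) are hypothetical and only serve the demonstration.

```python
import threading
from contextlib import contextmanager

_tls = threading.local()


def wrap_preprocess(orig):
    """Wrap orig so it becomes an identity while a cp_local context is active."""

    def wrapped(*args, **kwargs):
        ctx = getattr(_tls, "cp_local", None)
        if ctx is not None:
            # Input is already CP-local: hand it back untouched with the full-cu psp.
            input_ids = kwargs["input_ids"] if "input_ids" in kwargs else args[0]
            return input_ids, ctx["psp"]
        return orig(*args, **kwargs)

    return wrapped


@contextmanager
def cp_local_context(psp):
    """Mark the current thread's forward pass as running on pre-sharded input."""
    _tls.cp_local = {"psp": psp}
    try:
        yield
    finally:
        _tls.cp_local = None


def fake_preprocess(input_ids, attention_mask, pre_process=True, pg_collection=None):
    """Stand-in that 're-shards' by keeping every other token."""
    return input_ids[::2], "resharded-psp"


wrapped = wrap_preprocess(fake_preprocess)
outside = wrapped([1, 2, 3, 4], None)
with cp_local_context("full-cu-psp"):
    inside = wrapped([1, 2, 3, 4], None)
    inside_kw = wrapped(input_ids=[5, 6], attention_mask=None)
after = wrapped([1, 2, 3, 4], None)
```

Keeping the switch in thread-local state means the wrapper needs no extra argument and falls back to the original function as soon as the context exits.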

- rename the free helpers' first parameter self -> model (they are not
  methods; the model instance is passed in from the patched forward)
- forward *args/**kwargs in the preprocess_packed_seqs identity wrapper
  instead of hard-coding the upstream signature
- warn once at patch install when the running megatron-bridge lacks the
  _miles_select_local_vision_embeds hook, instead of silently
  mis-placing vision embeddings under CP (points at the matching
  Megatron-Bridge patch)

Megatron-Bridge PR #9 renamed its extension point to the vendor-neutral
select_cp_local_vision_embeds; look that up first and keep the old
_miles_-prefixed name as a fallback for older patched bridges.

Megatron-Bridge now selects CP-local vision embeddings natively inside
Qwen3VLModel (no override hook), so remove select_local_vision_embeds,
the local->full position mapping it needed, and the hook installation;
keep a warning when the running bridge lacks the native support. The
cp_local context now only carries packed_seq_params for the
preprocess_packed_seqs identity wrapper.
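
For reference, the intermediate lookup described in the commit messages above (try select_cp_local_vision_embeds first, fall back to the _miles_-prefixed name) could be sketched as follows; the final commit removes this hook path altogether. find_vision_embed_hook and the stand-in classes are hypothetical, and resolving the hook as a class attribute is an assumption.

```python
HOOK_NAMES = ("select_cp_local_vision_embeds", "_miles_select_local_vision_embeds")


def find_vision_embed_hook(model_cls):
    """Return (name, hook) for the first hook name the class exposes, else (None, None)."""
    for name in HOOK_NAMES:
        hook = getattr(model_cls, name, None)
        if hook is not None:
            return name, hook
    return None, None


class NewBridge:
    """Stand-in for a bridge exposing the vendor-neutral name."""

    @staticmethod
    def select_cp_local_vision_embeds(embeds):
        return embeds


class OldPatchedBridge:
    """Stand-in for an older patched bridge with the prefixed name."""

    @staticmethod
    def _miles_select_local_vision_embeds(embeds):
        return embeds


class PlainBridge:
    """Stand-in for a bridge with neither hook."""
```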