Skip to content

Add Xiaomi MiMo-V2.5 MoE model support#2791

Draft
hallerite wants to merge 1 commit into
mainfrom
feat/mimo-v2-5-model
Draft

Add Xiaomi MiMo-V2.5 MoE model support#2791
hallerite wants to merge 1 commit into
mainfrom
feat/mimo-v2-5-model

Conversation

@hallerite

Copy link
Copy Markdown
Member

Adds the text backbone of Xiaomi MiMo-V2.5 (310B MoE, 15B activated/token, 256 experts top-8) as a custom trainer model.

What's in here

  • Modeling (src/prime_rl/trainer/models/mimo_v2/): the architecture has several features outside the shared layer primitives, so attention is implemented model-locally:
    • Hybrid attention: per-layer full vs sliding-window (128) via hybrid_layer_pattern, with different KV head counts per type (4 full / 8 SWA) and per-type RoPE bases (1e7 full / 1e4 SWA, partial rotary factor 0.334).
    • Asymmetric head dims (qk 192 / v 128): FA2 requires equal head dims, so the flash path runs with V zero-padded to the qk dim and the output sliced back — mathematically exact, verified.
    • Attention sinks: SWA layers carry a frozen per-head attention_sink_bias (an extra softmax column). Flash attention can't express the sink and its returned LSE is not differentiable, so SWA layers use an exact chunked windowed implementation: each query block materializes logits only against its (window+chunk)-sized key slice — O(total·window) memory, packing-aware via cu_seqlens, fp32 softmax.
    • attention_value_scale=0.707 applied to V pre-attention, matching the reference.
    • Fused qkv_proj kept as the module layout so hub checkpoint keys map 1:1 (no shape inference needed in conversion; split sizes come from config at init).
    • MoE: noaux_tc sigmoid routing with n_group=1 degenerates to exactly TokenChoiceTopKRouter semantics (bias affects selection only, unbiased weights, top-k normalization); gate.e_score_correction_bias folds into the expert_bias buffer. No shared experts. Layer 0 dense via moe_layer_freq.
  • Conversion: dequantizes the hub's block-128 FP8 (weight_scale_inv, new dequantize_fp8_blockwise in fp8.py) to bf16, stacks per-expert weights, renames the router/bias, and drops the MTP layers (model.mtp.*) plus the vision/audio/speech towers. convert_to_hf emits bf16 hub-format keys.
  • mini_moe: mimo_v2 preset verifying against the Xiaomi remote code, plus two small generic hooks — post_create_fn (the remote code leaves sink/correction biases as torch.empty; the preset fills them with deterministic non-trivial values) and per-preset attn_implementation (the remote code is eager-only under transformers 5.x).

Validation

mini_moe create + verify (fp32, CUDA, vs XiaomiMiMo/MiMo-V2.5 remote code, eager, sinks + window binding at seq 64 > window 32):

HF vs PrimeRL max logits diff: 0.000002
HF -> PrimeRL -> HF weight roundtrip verified

FA2 path cross-check (bf16): packed two-sample input through the FA2 path (padded-V full attention + chunked windowed sink attention) vs per-sample SDPA — max logit diff 0.0156 at mean |logit| 0.36, i.e. plain bf16 kernel-order noise; packing isolation holds (cross-sample leakage would dominate).

SFT smoke on the FA2 path (reverse-text, 200 steps): loss 11.93 → 3.28, ~15.7k tokens/s on a single RTX PRO 6000 — exercises the windowed sink attention with gradients end-to-end.

Why there is no KL table yet

vLLM serves MiMo-V2's sinks only via FA3 (Hopper SM90) or FA4 (SM100): on this dev box (RTX PRO 6000, SM 12.0) get_flash_attn_version falls back to FA2 and the engine asserts Sinks are only supported in FlashAttention 3. The model cannot be served by vLLM on this machine at all, so the 20-step math KL table needs a Hopper/SM100 node — to be added before undrafting.

Follow-ups

  • Weight broadcast to an FP8-initialized vLLM engine needs convert_layer_to_vllm_kernel with block-FP8 re-quantization (glm_moe_dsa precedent); until then serve the dequantized bf16 checkpoint.
  • MTP modules (3×, used by MiMo for speculative decoding and RL efficiency upstream) are dropped — same trade-off as other models: spec-decode acceptance may drift over training.
  • No MiMo renderer in deps/renderers — use [orchestrator.renderer] name = "default".

🤖 Generated with Claude Code

Implement the text backbone of MiMo-V2.5 (310B MoE, 15B activated/token):
modeling code, HF<->PrimeRL weight conversion with block-wise FP8
dequantization, registration, and mini-model verification against the hub
remote code.

Key points:
- Hybrid attention: per-layer full vs sliding-window (128) via
  hybrid_layer_pattern, with different KV head counts per type
- Asymmetric head dims (qk 192 / v 128): flash attention runs with the value
  zero-padded to the qk head dim and the output sliced back (exact)
- SWA layers carry a frozen per-head attention sink bias; since flash
  attention cannot express the sink and its LSE is not differentiable, SWA
  layers use an exact chunked windowed implementation (O(total*window)
  memory, packing-aware via cu_seqlens)
- Partial RoPE (factor 0.334) with per-type theta (1e7 full / 1e4 SWA),
  value scale 0.707
- noaux_tc sigmoid routing with n_group=1 maps onto TokenChoiceTopKRouter
  (e_score_correction_bias folds into the expert_bias buffer)
- Fused qkv_proj kept as the module layout so hub checkpoint keys map 1:1
- Conversion dequantizes block-128 FP8 (weight_scale_inv) to bf16 and drops
  the MTP layers (model.mtp.*) and vision/audio/speech towers
- mini_moe: add post_create_fn and per-preset attn_implementation hooks
  (remote code is eager-only and leaves sink/correction biases empty)

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant