Add Xiaomi MiMo-V2.5 MoE model support#2791
Draft
hallerite wants to merge 1 commit into
Draft
Conversation
Implement the text backbone of MiMo-V2.5 (310B MoE, 15B activated/token): modeling code, HF<->PrimeRL weight conversion with block-wise FP8 dequantization, registration, and mini-model verification against the hub remote code. Key points: - Hybrid attention: per-layer full vs sliding-window (128) via hybrid_layer_pattern, with different KV head counts per type - Asymmetric head dims (qk 192 / v 128): flash attention runs with the value zero-padded to the qk head dim and the output sliced back (exact) - SWA layers carry a frozen per-head attention sink bias; since flash attention cannot express the sink and its LSE is not differentiable, SWA layers use an exact chunked windowed implementation (O(total*window) memory, packing-aware via cu_seqlens) - Partial RoPE (factor 0.334) with per-type theta (1e7 full / 1e4 SWA), value scale 0.707 - noaux_tc sigmoid routing with n_group=1 maps onto TokenChoiceTopKRouter (e_score_correction_bias folds into the expert_bias buffer) - Fused qkv_proj kept as the module layout so hub checkpoint keys map 1:1 - Conversion dequantizes block-128 FP8 (weight_scale_inv) to bf16 and drops the MTP layers (model.mtp.*) and vision/audio/speech towers - mini_moe: add post_create_fn and per-preset attn_implementation hooks (remote code is eager-only and leaves sink/correction biases empty) Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Adds the text backbone of Xiaomi MiMo-V2.5 (310B MoE, 15B activated/token, 256 experts top-8) as a custom trainer model.
What's in here
src/prime_rl/trainer/models/mimo_v2/): the architecture has several features outside the shared layer primitives, so attention is implemented model-locally:hybrid_layer_pattern, with different KV head counts per type (4 full / 8 SWA) and per-type RoPE bases (1e7 full / 1e4 SWA, partial rotary factor 0.334).attention_sink_bias(an extra softmax column). Flash attention can't express the sink and its returned LSE is not differentiable, so SWA layers use an exact chunked windowed implementation: each query block materializes logits only against its (window+chunk)-sized key slice — O(total·window) memory, packing-aware viacu_seqlens, fp32 softmax.attention_value_scale=0.707applied to V pre-attention, matching the reference.qkv_projkept as the module layout so hub checkpoint keys map 1:1 (no shape inference needed in conversion; split sizes come from config at init).noaux_tcsigmoid routing withn_group=1degenerates to exactlyTokenChoiceTopKRoutersemantics (bias affects selection only, unbiased weights, top-k normalization);gate.e_score_correction_biasfolds into theexpert_biasbuffer. No shared experts. Layer 0 dense viamoe_layer_freq.weight_scale_inv, newdequantize_fp8_blockwiseinfp8.py) to bf16, stacks per-expert weights, renames the router/bias, and drops the MTP layers (model.mtp.*) plus the vision/audio/speech towers.convert_to_hfemits bf16 hub-format keys.mimo_v2preset verifying against the Xiaomi remote code, plus two small generic hooks —post_create_fn(the remote code leaves sink/correction biases astorch.empty; the preset fills them with deterministic non-trivial values) and per-presetattn_implementation(the remote code is eager-only under transformers 5.x).Validation
mini_moecreate + verify (fp32, CUDA, vsXiaomiMiMo/MiMo-V2.5remote code, eager, sinks + window binding at seq 64 > window 32):FA2 path cross-check (bf16): packed two-sample input through the FA2 path (padded-V full attention + chunked windowed sink attention) vs per-sample SDPA — max logit diff 0.0156 at mean |logit| 0.36, i.e. plain bf16 kernel-order noise; packing isolation holds (cross-sample leakage would dominate).
SFT smoke on the FA2 path (reverse-text, 200 steps): loss 11.93 → 3.28, ~15.7k tokens/s on a single RTX PRO 6000 — exercises the windowed sink attention with gradients end-to-end.
Why there is no KL table yet
vLLM serves MiMo-V2's sinks only via FA3 (Hopper SM90) or FA4 (SM100): on this dev box (RTX PRO 6000, SM 12.0)
get_flash_attn_versionfalls back to FA2 and the engine assertsSinks are only supported in FlashAttention 3. The model cannot be served by vLLM on this machine at all, so the 20-stepmathKL table needs a Hopper/SM100 node — to be added before undrafting.Follow-ups
convert_layer_to_vllm_kernelwith block-FP8 re-quantization (glm_moe_dsa precedent); until then serve the dequantized bf16 checkpoint.deps/renderers— use[orchestrator.renderer] name = "default".🤖 Generated with Claude Code