Add Tencent Hy3 (HYV3) MoE model support #2789

Draft: hallerite wants to merge 1 commit into main from feat/hy3-model

@hallerite

Adds Tencent Hy3-preview (295B MoE, 21B activated/token, 192 experts top-8 + 1 shared) as a custom trainer model.

What's in here

  • Modeling (src/prime_rl/trainer/models/hy_v3/): glm4_moe-style decoder built from existing layer primitives — per-head QK norm (Apertus-style), sigmoid router with e_score_correction_bias-based selection + top-k normalization + router_scaling_factor=2.826 (maps 1:1 onto TokenChoiceTopKRouter semantics), 1 dense + 79 sparse layers, shared expert via BCFeedForward.
  • Conversion: handles both source formats — the hub checkpoint (per-expert experts, mlp.router.gate.weight, mlp.expert_bias, mlp.shared_mlp.*) and the transformers 5.x in-memory format (fused gate_up_proj, mlp.gate.weight, mlp.e_score_correction_bias, mlp.shared_experts.*). convert_to_hf emits the hub format, which both vLLM and transformers load natively. The MTP layer (model.layers.80, speculative decoding only) is dropped on load.
  • Config: parses the hub config.json field names directly (first_k_dense_replace, qk_norm, route_norm, ...), derives mlp_layer_types for compatibility with transformers' native HYV3ForCausalLM, and emits expert_hidden_dim so checkpoints saved by the trainer stay loadable by vLLM. The hub's use_grouped_mm: false (which refers to HF's eager experts impl) is harmlessly overridden by the trainer from ModelConfig.moe_use_grouped_mm.
  • mini_moe preset (hy_v3), verified against transformers' native implementation.
  • Docs: fixed the stale SFT warm-up command in docs/development.md (--data.type sft, --ckpt.weights.save-sharded false).
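As a rough illustration of the router semantics named in the Modeling bullet, here is a minimal pure-Python sketch for a single token. It is not the code in this PR: `route_token` is a hypothetical helper, and it assumes the common convention that `e_score_correction_bias` only influences which experts are selected, while the mixing weights come from the uncorrected sigmoid scores.

```python
import math


def route_token(logits, correction_bias, top_k, scaling_factor, normalize=True):
    """Hypothetical single-token router sketch (not the PR's implementation).

    logits and correction_bias hold one float per expert.
    Returns (chosen expert indices, per-expert mixing weights).
    """
    # Sigmoid scoring instead of a softmax over experts.
    scores = [1.0 / (1.0 + math.exp(-x)) for x in logits]
    # Selection uses the bias-corrected scores.
    biased = [s + b for s, b in zip(scores, correction_bias)]
    chosen = sorted(range(len(scores)), key=lambda i: biased[i], reverse=True)[:top_k]
    # Assumption: mixing weights use the uncorrected scores of the chosen experts.
    weights = [scores[i] for i in chosen]
    if normalize:
        # Top-k normalization: the selected weights sum to 1 before scaling.
        total = sum(weights)
        weights = [w / total for w in weights]
    # Final scaling by the router scaling factor (2.826 in the PR description).
    weights = [w * scaling_factor for w in weights]
    return chosen, weights


chosen, weights = route_token(
    logits=[0.0, 2.0, -1.0, 1.0],
    correction_bias=[0.0, 0.0, 5.0, 0.0],
    top_k=2,
    scaling_factor=2.826,
)
```

In this toy call the large bias on expert 2 pulls it into the top-2 even though its raw score is the lowest, but its mixing weight stays smaller than that of expert 1 because the weights ignore the bias.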

No inference-side changes needed: vLLM 0.22.0 (our pin) already ships HYV3ForCausalLM, the hy_v3 reasoning/tool parsers, and MTP speculative decoding; transformers 5.6.2 has the architecture natively (no trust_remote_code).
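The name differences between the two checkpoint layouts listed in the Conversion bullet can be pictured as a small rename table. The sketch below is an illustration only: `rename_hub_keys` and `HUB_TO_TRANSFORMERS` are hypothetical names, the `model.layers.N.` prefixes in the example are assumed, and the per-expert versus fused `gate_up_proj` handling is deliberately left out.

```python
# Name fragments taken from the PR description: hub checkpoint -> transformers in-memory.
HUB_TO_TRANSFORMERS = [
    ("mlp.router.gate.weight", "mlp.gate.weight"),
    ("mlp.expert_bias", "mlp.e_score_correction_bias"),
    ("mlp.shared_mlp.", "mlp.shared_experts."),
]

# The MTP layer is only used for speculative decoding and is dropped on load.
MTP_LAYER_PREFIX = "model.layers.80."


def rename_hub_keys(state_dict):
    """Hypothetical helper: drop MTP weights and rename hub-style keys."""
    renamed = {}
    for key, value in state_dict.items():
        if key.startswith(MTP_LAYER_PREFIX):
            continue  # skip everything that belongs to the MTP layer
        for old, new in HUB_TO_TRANSFORMERS:
            if old in key:
                key = key.replace(old, new)
                break  # at most one fragment applies per key
        renamed[key] = value
    return renamed


example = {
    "model.layers.1.mlp.router.gate.weight": "router",
    "model.layers.1.mlp.expert_bias": "bias",
    "model.layers.1.mlp.shared_mlp.up_proj.weight": "shared",
    "model.layers.80.mlp.expert_bias": "mtp",
}
converted = rename_hub_keys(example)
```

Running the table in the opposite direction would give the renames needed to emit the hub format again, which is the direction the description attributes to convert_to_hf.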

Validation

mini_moe create + verify (fp32, CUDA):

HF vs PrimeRL max logits diff: 0.000001
HF -> PrimeRL -> HF weight roundtrip verified

SFT warm-up on reverse-text (200 steps, loss 11.75 → 4.27), then the full RL stack (configs/ci/integration/reverse_text_moe/start.toml, batch 128, 20 steps, FA2 trainer vs vLLM triton MoE backend):

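| Step | Mismatch KL | Step | Mismatch KL |
|------|-------------|------|-------------|
| 0 | 0.0000 | 10 | 0.0000 |
| 1 | 0.0000 | 11 | 0.0001 |
| 2 | 0.0001 | 12 | 0.0000 |
| 3 | 0.0000 | 13 | 0.0000 |
| 4 | 0.0000 | 14 | 0.0001 |
| 5 | 0.0000 | 15 | 0.0000 |
| 6 | 0.0001 | 16 | 0.0000 |
| 7 | 0.0001 | 17 | 0.0001 |
| 8 | 0.0001 | 18 | 0.0000 |
| 9 | 0.0001 | 19 | 0.0001 |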

All entries are well under the 0.015 bar (the largest observed value is 0.0001). Note: this run used the reverse-text smoke env on the mini model (2× RTX PRO 6000); the documented math + batch_size=64 table on the real checkpoint still needs a node that can hold the 295B model.
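For anyone re-running the smoke test, the pass criterion can be written down as a small check. This is only a sketch: the values are copied from the table above, and `MISMATCH_KL`, `THRESHOLD`, and `steps_over_threshold` are hypothetical names, not part of the trainer or the CI config.

```python
# Mismatch KL per step (steps 0-19), copied from the table above.
MISMATCH_KL = [
    0.0000, 0.0000, 0.0001, 0.0000, 0.0000,
    0.0000, 0.0001, 0.0001, 0.0001, 0.0001,
    0.0000, 0.0001, 0.0000, 0.0000, 0.0001,
    0.0000, 0.0000, 0.0001, 0.0000, 0.0001,
]

# The bar quoted in the PR description.
THRESHOLD = 0.015


def steps_over_threshold(values, threshold):
    """Return the indices of steps whose mismatch KL reaches the threshold."""
    return [step for step, kl in enumerate(values) if kl >= threshold]


worst = max(MISMATCH_KL)
failing = steps_over_threshold(MISMATCH_KL, THRESHOLD)
print(f"worst mismatch KL: {worst:.4f}, failing steps: {failing}")
```

Listing the failing steps, rather than asserting only on the maximum, makes it easier to see whether a regression is a single outlier or a trend.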

Known limitations / follow-ups

  • Renderer: no Hunyuan renderer in deps/renderers yet — runs need [orchestrator.renderer] name = "default" (chat-template fallback). A dedicated renderer handling Hy3's reasoning_effort kwarg + tool format is a follow-up (renderers repo).
  • MTP head stays frozen during RL (dropped from the trainer); spec-decode acceptance may drift as the policy trains — disable speculative decoding for RL serving if it matters.
  • torch.compile hits a pre-existing Inductor unbacked-symint assertion on this dev box; it reproduces identically with glm4_moe, so it is not specific to this model and appears to be environmental.
  • Blackwell (SM 12.0) note: vLLM's default FlashInfer CUTLASS MoE backend requires CUDA ≥ 12.9 — pass --vllm-extra.moe_backend triton on such boxes.

🤖 Generated with Claude Code

Implement full pipeline for Hy3-preview (295B MoE, 21B activated/token):
modeling code, HF<->PrimeRL weight conversion, registration, and mini-model
verification.

Key points:
- Sigmoid router with e_score_correction_bias selection, route norm, and
  router_scaling_factor=2.826, mapped onto TokenChoiceTopKRouter semantics
- Per-head QK norm (Apertus-style) via existing AttentionConfig
- Conversion handles both the hub checkpoint format (per-expert experts,
  mlp.router.gate, mlp.expert_bias, mlp.shared_mlp) and the transformers
  in-memory format (fused gate_up_proj, mlp.gate, e_score_correction_bias,
  mlp.shared_experts); convert_to_hf emits the hub format which both vLLM
  and transformers load natively
- MTP layer (model.layers.80, speculative decoding only) dropped on load
- Config parses the hub config.json field names directly; emits
  expert_hidden_dim for vLLM compatibility
- Fix stale SFT warm-up command in docs/development.md

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>