From 7baf689239677c3af8ce2be216ad82a6cbf6fdd3 Mon Sep 17 00:00:00 2001
From: r-cloudforge <266043551+r-cloudforge@users.noreply.github.com>
Date: Tue, 14 Apr 2026 12:04:39 +0200
Subject: [PATCH] [Feature] add MiniCPM4/4.1 model support
---
docs/best_practices/MiniCPM4-8B.md | 104 ++++
docs/supported_models.md | 1 +
fastdeploy/model_executor/models/minicpm4.py | 516 +++++++++++++++++++
tests/model_executor/test_minicpm4.py | 514 ++++++++++++++++++
4 files changed, 1135 insertions(+)
create mode 100644 docs/best_practices/MiniCPM4-8B.md
create mode 100644 fastdeploy/model_executor/models/minicpm4.py
create mode 100644 tests/model_executor/test_minicpm4.py
diff --git a/docs/best_practices/MiniCPM4-8B.md b/docs/best_practices/MiniCPM4-8B.md
new file mode 100644
index 00000000000..6d758bb1c45
--- /dev/null
+++ b/docs/best_practices/MiniCPM4-8B.md
@@ -0,0 +1,104 @@
+# MiniCPM4/4.1-8B
+
+## I. Environment Preparation
+
+### 1.1 Hardware Requirements
+The minimum number of GPUs required to deploy `MiniCPM4.1-8B` on the following hardware for each quantization is as follows:
+
+| | BF16 | WINT8 | WINT4 | FP8 |
+|-----|-----|-----|-----|-----|
+|H800 80GB| 1 | 1 | 1 | 1 |
+|A800 80GB| 1 | 1 | 1 | / |
+|H20 96GB| 1 | 1 | 1 | 1 |
+|L20 48GB| 1 | 1 | 1 | 1 |
+|A30 40GB| / | 1 | 1 | / |
+|A10 24GB| / | 1 | 1 | / |
+|V100 32GB| / | 1 | 1 | / |
+
+**Tips:**
+1. MiniCPM4.1-8B is a dense 8B model — a single GPU is sufficient for inference at all supported quantization levels.
+2. For hardware not listed in the table, you can estimate whether it can be deployed based on the GPU memory. BF16 requires ~16GB, WINT8 ~8GB, WINT4 ~4GB.
+
+### 1.2 Install FastDeploy
+- Installation: For details, please refer to [FastDeploy Installation](../get_started/installation/README.md).
+- Model Download: For details, please refer to [Supported Models](../supported_models.md).
+
+## II. How to Use
+
+### 2.1 Basic: Launching the Service
+
+**Example 1:** Deploying MiniCPM4.1-8B with WINT4 quantization
+
+```bash
+python -m fastdeploy.entrypoints.openai.api_server \
+ --model openbmb/MiniCPM4.1-8B \
+ --tensor-parallel-size 1 \
+ --quantization wint4 \
+ --max-model-len 32768 \
+ --max-num-seqs 128
+```
+
+**Example 2:** Deploying MiniCPM4.1-8B with BF16 (full precision)
+
+```bash
+python -m fastdeploy.entrypoints.openai.api_server \
+ --model openbmb/MiniCPM4.1-8B \
+ --tensor-parallel-size 1 \
+ --max-model-len 32768 \
+ --max-num-seqs 64
+```
+
+- `--quantization`: Quantization strategy. Options: `wint8` / `wint4` / `block_wise_fp8` (Hopper required). Omit for BF16.
+- `--max-model-len`: Maximum number of tokens for the deployed service. MiniCPM4.1 supports up to 65,536 tokens with LongRoPE, but larger values increase GPU memory usage.
+
+For more parameter meanings and default settings, see [FastDeploy Parameter Documentation](../parameters.md).
+
+### 2.2 Sending Requests
+
+After the service starts, send requests via the OpenAI-compatible API:
+
+```bash
+curl http://localhost:8180/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "openbmb/MiniCPM4.1-8B",
+ "messages": [{"role": "user", "content": "What is the capital of France?"}],
+ "max_tokens": 512
+ }'
+```
+
+### 2.3 Advanced: How to Get Better Performance
+
+#### 2.3.1 Correctly Set Parameters That Match the Application Scenario
+Evaluate average input length, average output length, and maximum context length.
+- Set `--max-model-len` according to the maximum context length. For example, if the average input length is 1000 and the output length is 4000, then it is recommended to set it to 8192.
+
+#### 2.3.2 Prefix Caching
+**Idea:** The core idea of Prefix Caching is to avoid repeated calculations by caching the intermediate calculation results of the input sequence (KV Cache), thereby speeding up the response speed of multiple requests with the same prefix. For details, refer to [prefix-cache](../features/prefix_caching.md).
+
+**How to enable:**
+Since version 2.2 (including the develop branch), Prefix Caching has been enabled by default.
+
+#### 2.3.3 Chunked Prefill
+**Idea:** This strategy splits the prefill stage request into small-scale sub-chunks, and executes them in batches mixed with the decode request. For details, please refer to [Chunked Prefill](../features/chunked_prefill.md).
+
+**How to enable:**
+Since version 2.2 (including the develop branch), Chunked Prefill has been enabled by default.
+
+#### 2.3.4 CudaGraph
+**Idea:** CUDAGraph encapsulates GPU computing and memory operations into a re-executable graph, reducing CPU-GPU communication overhead and improving computing performance.
+
+**How to enable:**
+CUDAGraph has been enabled by default since version 2.3.
+
+## Model Architecture Notes
+
+MiniCPM4.1-8B uses μP (Maximal Update Parametrization) for training stability:
+- **Embedding scaling**: Output scaled by `scale_emb` (12×)
+- **Residual scaling**: Connections scaled by `scale_depth / √num_hidden_layers`
+- **LM head scaling**: Input scaled by `hidden_size / dim_model_base`
+
+These scaling factors are automatically read from the model's `config.json` and require no user configuration.
+
+## FAQ
+If you encounter any problems during use, please refer to [FAQ](./FAQ.md).
diff --git a/docs/supported_models.md b/docs/supported_models.md
index b0684affc11..caf1cf54f82 100644
--- a/docs/supported_models.md
+++ b/docs/supported_models.md
@@ -40,6 +40,7 @@ These models accept text input.
|⭐DEEPSEEK|BF16/WINT4|unsloth/DeepSeek-V3.1-BF16;
unsloth/DeepSeek-V3-0324-BF16;
unsloth/DeepSeek-R1-BF16, etc.|
|⭐GPT-OSS|BF16/WINT8|unsloth/gpt-oss-20b-BF16, etc.|
|⭐GLM-4.5/4.6|BF16/wfp8afp8|zai-org/GLM-4.5-Air;
zai-org/GLM-4.6
[最佳实践](./best_practices/GLM-4-MoE-Text.md) etc.|
+|MINICPM4|BF16/WINT8/WINT4/FP8|[openbmb/MiniCPM4.1-8B](./best_practices/MiniCPM4-8B.md);
openbmb/MiniCPM4-8B|
## Multimodal Language Models
diff --git a/fastdeploy/model_executor/models/minicpm4.py b/fastdeploy/model_executor/models/minicpm4.py
new file mode 100644
index 00000000000..96a4d86b1ab
--- /dev/null
+++ b/fastdeploy/model_executor/models/minicpm4.py
@@ -0,0 +1,516 @@
+"""
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from __future__ import annotations
+
+import math
+import re
+from functools import partial
+from typing import Dict
+
+import paddle
+from paddle import nn
+from paddleformers.transformers import PretrainedModel
+from paddleformers.utils.log import logger
+
+from fastdeploy.config import FDConfig, ModelConfig
+from fastdeploy.model_executor.forward_meta import ForwardMeta
+from fastdeploy.model_executor.graph_optimization.decorator import (
+ support_graph_optimization,
+)
+from fastdeploy.model_executor.layers.activation import SiluAndMul
+from fastdeploy.model_executor.layers.attention.attention import Attention
+from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
+from fastdeploy.model_executor.layers.linear import (
+ MergedColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear,
+)
+from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
+from fastdeploy.model_executor.layers.normalization import RMSNorm
+from fastdeploy.model_executor.models.model_base import (
+ ModelCategory,
+ ModelForCasualLM,
+ ModelRegistry,
+)
+from fastdeploy.model_executor.utils import (
+ WeightsMapper,
+ default_weight_loader,
+ process_weights_after_loading,
+ process_weights_before_loading,
+)
+
+
+class MiniCPM4MLP(nn.Layer):
+ """ """
+
+ def __init__(
+ self,
+ fd_config: FDConfig,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.up_gate_proj = MergedColumnParallelLinear(
+ fd_config=fd_config,
+ prefix=f"{prefix}.up_gate_proj",
+ input_size=fd_config.model_config.hidden_size,
+ output_size=fd_config.model_config.intermediate_size * 2,
+ with_bias=False,
+ activation=fd_config.model_config.hidden_act,
+ )
+
+ self.down_proj = RowParallelLinear(
+ fd_config=fd_config,
+ prefix=f"{prefix}.down_proj",
+ input_size=fd_config.model_config.intermediate_size,
+ output_size=fd_config.model_config.hidden_size,
+ with_bias=False,
+ )
+
+ self.act_fn = SiluAndMul(
+ fd_config=fd_config,
+ bias=getattr(self.up_gate_proj, "bias", None),
+ act_method=fd_config.model_config.hidden_act,
+ )
+
+ def load_state_dict(self, state_dict):
+ """ """
+ self.up_gate_proj.load_state_dict(state_dict)
+ self.down_proj.load_state_dict(state_dict)
+
+ def forward(self, x, forward_meta):
+ """ """
+ gate_up_out = self.up_gate_proj(x)
+ act_out = self.act_fn(gate_up_out)
+ down_out = self.down_proj(act_out)
+ return down_out
+
+
+class MiniCPM4Attention(nn.Layer):
+ """ """
+
+ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None:
+ super().__init__()
+
+ self.qkv_proj = QKVParallelLinear(fd_config=fd_config, prefix=f"{prefix}.qkv_proj", with_bias=False)
+
+ self.o_proj = RowParallelLinear(
+ fd_config=fd_config,
+ prefix=f"{prefix}.o_proj",
+ input_size=fd_config.model_config.hidden_size,
+ output_size=fd_config.model_config.hidden_size,
+ )
+
+ self.attn = Attention(
+ fd_config=fd_config,
+ layer_id=layer_id,
+ prefix=prefix,
+ use_neox_rotary_style=True,
+ )
+
+ def load_state_dict(self, state_dict):
+ """ """
+ self.qkv_proj.load_state_dict(state_dict)
+ self.o_proj.load_state_dict(state_dict)
+ self.attn.load_state_dict(state_dict)
+
+ def forward(
+ self,
+ forward_meta: ForwardMeta,
+ hidden_states: paddle.Tensor,
+ ):
+ """ """
+ qkv_out = self.qkv_proj(hidden_states)
+
+ atten_out = self.attn(
+ qkv=qkv_out,
+ forward_meta=forward_meta,
+ )
+ output = self.o_proj(atten_out)
+ return output
+
+
+class MiniCPM4DecoderLayer(nn.Layer):
+ """MiniCPM4 decoder layer with μP residual scaling."""
+
+ def __init__(
+ self,
+ fd_config: FDConfig,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ layer_id = int(prefix.split(sep=".")[-1])
+
+ self.self_attn = MiniCPM4Attention(
+ fd_config=fd_config,
+ layer_id=layer_id,
+ prefix=f"{prefix}.self_attn",
+ )
+
+ self.mlp = MiniCPM4MLP(
+ fd_config=fd_config,
+ prefix=f"{prefix}.mlp",
+ )
+
+ self.input_layernorm = RMSNorm(
+ fd_config,
+ hidden_size=fd_config.model_config.hidden_size,
+ eps=fd_config.model_config.rms_norm_eps,
+ prefix=f"{prefix}.input_layernorm",
+ )
+
+ self.post_attention_layernorm = RMSNorm(
+ fd_config,
+ hidden_size=fd_config.model_config.hidden_size,
+ eps=fd_config.model_config.rms_norm_eps,
+ prefix=f"{prefix}.post_attention_layernorm",
+ layer_id=layer_id,
+ )
+
+ # μP residual scaling: scale_depth / sqrt(num_hidden_layers)
+ scale_depth = getattr(fd_config.model_config, "scale_depth", 1.0)
+ num_hidden_layers = fd_config.model_config.num_hidden_layers
+ self.residual_scale = scale_depth / math.sqrt(num_hidden_layers)
+
+ def load_state_dict(self, state_dict):
+ """ """
+ self.self_attn.load_state_dict(state_dict)
+ self.mlp.load_state_dict(state_dict)
+ self.input_layernorm.load_state_dict(state_dict)
+ self.post_attention_layernorm.load_state_dict(state_dict)
+
+ def forward(
+ self,
+ forward_meta: ForwardMeta,
+ hidden_states: paddle.Tensor,
+ residual: paddle.Tensor = None,
+ ):
+ """ """
+ # Self Attention
+ hidden_states, residual = self.input_layernorm(
+ hidden_states, residual_input=residual, forward_meta=forward_meta
+ )
+
+ hidden_states = self.self_attn(
+ hidden_states=hidden_states,
+ forward_meta=forward_meta,
+ )
+
+ # μP: scale attention output before residual add
+ hidden_states = hidden_states * self.residual_scale
+
+ # Fully Connected
+ hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+
+ hidden_states = self.mlp(hidden_states, forward_meta)
+
+ # μP: scale MLP output before residual add
+ hidden_states = hidden_states * self.residual_scale
+
+ return hidden_states, residual
+
+
+@support_graph_optimization
+class MiniCPM4Model(nn.Layer):
+    """ """
+
+    def __init__(
+        self,
+        fd_config: FDConfig = None,
+    ):
+        super().__init__()
+
+        self.num_layers = fd_config.model_config.num_hidden_layers
+        fd_config.model_config.pretrained_config.prefix_name = "minicpm4"
+
+        # μP embedding scaling factor
+        self.scale_emb = getattr(fd_config.model_config, "scale_emb", 1)
+
+        self.embed_tokens = VocabParallelEmbedding(
+            fd_config=fd_config,
+            num_embeddings=fd_config.model_config.vocab_size,
+            embedding_dim=fd_config.model_config.hidden_size,
+            params_dtype=paddle.get_default_dtype(),
+            prefix=(f"{fd_config.model_config.pretrained_config.prefix_name}.embed_tokens"),
+        )
+
+        self.layers = nn.LayerList(
+            [
+                MiniCPM4DecoderLayer(
+                    fd_config=fd_config,
+                    prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}",
+                )
+                for i in range(self.num_layers)
+            ]
+        )
+
+        self.norm = RMSNorm(
+            fd_config,
+            hidden_size=fd_config.model_config.hidden_size,
+            eps=fd_config.model_config.rms_norm_eps,
+            prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.norm",
+        )
+
+    def load_state_dict(self, state_dict):
+        """
+        Load model parameters from a given state dictionary.
+
+        Args:
+            state_dict (dict[str, np.ndarray | paddle.Tensor]):
+                A dictionary containing model parameters, where keys are parameter names
+                and values are NumPy arrays or PaddlePaddle tensors.
+        """
+        self.embed_tokens.load_state_dict(state_dict)
+        self.norm.load_state_dict(state_dict)
+        for i in range(self.num_layers):
+            logger.info(f"Start load layer {i}")
+            self.layers[i].load_state_dict(state_dict)
+
+    def forward(
+        self,
+        ids_remove_padding: paddle.Tensor,
+        forward_meta: ForwardMeta,
+    ):
+
+        hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta)
+
+        # μP: scale embeddings
+        if self.scale_emb != 1:
+            hidden_states = hidden_states * self.scale_emb
+
+        residual = None
+
+        for i in range(self.num_layers):
+            hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
+
+        out = self.norm(hidden_states, residual)[0]
+
+        return out
+
+
+@ModelRegistry.register_model_class(
+ architecture="MiniCPMForCausalLM",
+ module_name="minicpm4",
+ category=[ModelCategory.TEXT_GENERATION],
+ primary_use=ModelCategory.TEXT_GENERATION,
+)
+class MiniCPM4ForCausalLM(ModelForCasualLM):
+ """
+ MiniCPM4ForCausalLM — supports MiniCPM4 and MiniCPM4.1 series models.
+
+ Key differences from Qwen2:
+ - μP (Maximal Update Parametrization) scaling:
+ * Embedding output scaled by `scale_emb` (default: 12)
+ * Residual connections scaled by `scale_depth / sqrt(num_hidden_layers)` (default: 1.4)
+ * LM head input scaled by `hidden_size / dim_model_base` (default: 4096/256 = 16)
+ - No QKV bias (attention_bias=false)
+ - LongRoPE position encoding
+ """
+
+ def __init__(self, fd_config: FDConfig):
+ super(MiniCPM4ForCausalLM, self).__init__(fd_config)
+
+ self.fd_config = fd_config
+ self.minicpm4 = MiniCPM4Model(fd_config=fd_config)
+
+ self.ori_vocab_size = fd_config.model_config.ori_vocab_size
+ self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings
+ self.lm_head = ParallelLMHead(
+ fd_config=fd_config,
+ embedding_dim=fd_config.model_config.hidden_size,
+ num_embeddings=fd_config.model_config.vocab_size,
+ prefix="lm_head",
+ )
+
+ # μP: lm_head input scaling factor = hidden_size / dim_model_base
+ dim_model_base = getattr(fd_config.model_config, "dim_model_base", None)
+ hidden_size = fd_config.model_config.hidden_size
+ if dim_model_base is not None and dim_model_base > 0:
+ self.lm_head_scale = hidden_size / dim_model_base
+ else:
+ self.lm_head_scale = 1.0
+
+ self.process_weights_before_loading_fn = process_weights_before_loading(
+ mapper=(
+ WeightsMapper(orig_to_new_prefix={"model.": "minicpm4."})
+ if self.fd_config.model_config.model_format == "torch"
+ else None
+ ),
+ )
+
+ @paddle.no_grad()
+ def load_weights(self, weights_iterator) -> None:
+ """
+ Load model parameters from a given weights_iterator object.
+
+ Args:
+ weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
+ """
+
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ("up_gate_proj", "gate_proj", "gate"),
+ ("up_gate_proj", "up_proj", "up"),
+ ("embed_tokens.embeddings", "embed_tokens", None),
+ ("lm_head.linear", "lm_head", None),
+ ]
+
+ params_dict = dict(self.named_parameters())
+ process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()), self.fd_config)
+ for loaded_weight_name, loaded_weight in weights_iterator:
+ logger.debug(f"Loading weight: {loaded_weight_name}")
+ loaded_weight_name = (
+ self.process_weights_before_loading_fn(loaded_weight_name)
+ if getattr(self, "process_weights_before_loading_fn", None)
+ else loaded_weight_name
+ )
+ if loaded_weight_name is None:
+ continue
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ if weight_name not in loaded_weight_name:
+ continue
+ model_param_name = loaded_weight_name.replace(weight_name, param_name)
+ if model_param_name not in params_dict:
+ continue
+ param = params_dict[model_param_name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ model_param_name = loaded_weight_name
+ if model_param_name not in params_dict:
+ continue
+ param = params_dict[model_param_name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
+ weight_loader(param, loaded_weight)
+ model_sublayer_name = re.sub(r"\.(weight)$", "", model_param_name)
+ process_weights_after_loading_fn(model_sublayer_name, param)
+ if getattr(self, "tie_word_embeddings", False):
+ self.lm_head.linear.weight.set_value(
+ self.minicpm4.embed_tokens.embeddings.weight.transpose([1, 0]).astype(self.lm_head.linear.weight.dtype)
+ )
+
+    @classmethod
+    def name(cls):
+        """Return the registered architecture name."""
+        return "MiniCPMForCausalLM"
+
+ @paddle.no_grad()
+ def set_state_dict(self, state_dict):
+ """
+ Load model parameters from a given state dictionary.
+
+ Args:
+ state_dict (dict[str, np.ndarray | paddle.Tensor]):
+ A dictionary containing model parameters, where keys are parameter names
+ and values are NumPy arrays or PaddlePaddle tensors.
+ """
+ self.minicpm4.load_state_dict(state_dict)
+ self.lm_head.load_state_dict(state_dict)
+
+ def compute_logits(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta = None):
+ """ """
+ # μP: scale hidden states before lm_head
+ if self.lm_head_scale != 1.0:
+ hidden_states = hidden_states / self.lm_head_scale
+ logits = self.lm_head(hidden_states)
+ logits = logits.astype(paddle.float32)
+ logits[:, self.ori_vocab_size :] = -float("inf")
+
+ return logits
+
+ def forward(
+ self,
+ inputs: Dict,
+ forward_meta: ForwardMeta,
+ ):
+ ids_remove_padding = inputs["ids_remove_padding"]
+ hidden_states = self.minicpm4(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta)
+
+ return hidden_states
+
+ def clear_grpah_opt_backend(self):
+ """Clear graph optimization backend, the captured cuda graph will be cleaned"""
+ self.minicpm4.clear_grpah_opt_backend(fd_config=self.fd_config)
+
+
+class MiniCPM4PretrainedModel(PretrainedModel):
+ """
+ MiniCPM4PretrainedModel
+ """
+
+ config_class = FDConfig
+
+ def _init_weight(self, layer):
+ """
+ _init_weight
+ """
+ return None
+
+    @classmethod
+    def arch_name(cls):
+        return "MiniCPMForCausalLM"
+
+ @classmethod
+ def _get_tensor_parallel_mappings(cls, config: ModelConfig, is_split=True):
+
+ from paddleformers.transformers.conversion_utils import split_or_merge_func
+
+ fn = split_or_merge_func(
+ is_split=is_split,
+ tensor_model_parallel_size=config.tensor_model_parallel_size,
+ tensor_parallel_rank=config.tensor_parallel_rank,
+ num_attention_heads=config.num_attention_heads,
+ )
+
+ def get_tensor_parallel_split_mappings(num_layers):
+ final_actions = {}
+
+ base_actions = {
+ "lm_head.weight": partial(fn, is_column=True),
+ # Row Linear
+ "embed_tokens.weight": partial(fn, is_column=False),
+ "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
+ "layers.0.mlp.down_proj.weight": partial(fn, is_column=False),
+ }
+
+ # Column Linear
+ if config.fuse_attention_qkv:
+ base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True)
+ else:
+ base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
+ # MiniCPM4 has no QKV bias, only need weight splits
+ if config.num_key_value_heads % config.tensor_model_parallel_size == 0:
+ base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
+ base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)
+
+ base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True)
+ base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True)
+
+ for key, action in base_actions.items():
+ if "layers.0." in key:
+ for i in range(num_layers):
+ final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
+ final_actions[key] = action
+
+ return final_actions
+
+ mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)
+
+ return mappings
diff --git a/tests/model_executor/test_minicpm4.py b/tests/model_executor/test_minicpm4.py
new file mode 100644
index 00000000000..4d97424ad1a
--- /dev/null
+++ b/tests/model_executor/test_minicpm4.py
@@ -0,0 +1,514 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from __future__ import annotations
+
+import math
+from types import SimpleNamespace
+
+import numpy as np
+import paddle
+import pytest
+
+# Patch paddle.compat before importing fastdeploy (beta-2 compat)
+if not hasattr(paddle, "compat"):
+
+ class _PaddleCompat:
+ @staticmethod
+ def enable_torch_proxy(scope=None):
+ return None
+
+ paddle.compat = _PaddleCompat()
+
+from fastdeploy.model_executor.models import minicpm4
+
+# ── Stubs ───────────────────────────────────────────────────────────────────
+
+
+class _StubLinear(paddle.nn.Layer):
+ """Stub for MergedColumnParallelLinear / QKVParallelLinear / RowParallelLinear."""
+
+ def __init__(self, *a, **kw):
+ super().__init__()
+ self.load_state_dict_called = False
+
+ def forward(self, x):
+ return x
+
+ def load_state_dict(self, _sd):
+ self.load_state_dict_called = True
+
+
+class _StubActivation(paddle.nn.Layer):
+ """Stub for SiluAndMul — identity pass-through."""
+
+ def __init__(self, *a, **kw):
+ super().__init__()
+
+ def forward(self, x):
+ return x
+
+
+class _StubRMSNorm(paddle.nn.Layer):
+ """Stub for RMSNorm — returns (hidden, residual) pair."""
+
+ def __init__(self, *a, **kw):
+ super().__init__()
+ self.load_state_dict_called = False
+
+ def forward(self, x, *args, **kwargs):
+ return x, x
+
+ def load_state_dict(self, _sd):
+ self.load_state_dict_called = True
+
+
+class _StubAttention(paddle.nn.Layer):
+ """Stub for Attention — identity on qkv input."""
+
+ def __init__(self, *a, **kw):
+ super().__init__()
+ self.load_state_dict_called = False
+
+ def forward(self, qkv=None, forward_meta=None):
+ return qkv
+
+ def load_state_dict(self, _sd):
+ self.load_state_dict_called = True
+
+
+class _StubEmbedding(paddle.nn.Layer):
+ """Stub for VocabParallelEmbedding."""
+
+ def __init__(self, *a, **kw):
+ super().__init__()
+ self.load_state_dict_called = False
+ self._h = kw.get("embedding_dim", 4)
+
+ def forward(self, ids_remove_padding=None, forward_meta=None):
+ return paddle.ones([ids_remove_padding.shape[0], self._h], dtype="float32")
+
+ def load_state_dict(self, _sd):
+ self.load_state_dict_called = True
+
+
+class _StubLMHead(paddle.nn.Layer):
+ """Stub for ParallelLMHead — projects to vocab_size."""
+
+ def __init__(self, *a, **kw):
+ super().__init__()
+ self.load_state_dict_called = False
+ self._vocab = kw.get("num_embeddings", 128)
+
+ def forward(self, x):
+ return paddle.ones([x.shape[0], self._vocab], dtype=x.dtype)
+
+ def load_state_dict(self, _sd):
+ self.load_state_dict_called = True
+
+
+# ── Config helper ───────────────────────────────────────────────────────────
+
+# Tiny stand-in dimensions for fast CPU-only unit tests (not real MiniCPM4.1-8B sizes)
+_HIDDEN = 4
+_INTERMEDIATE = 8
+_LAYERS = 2
+_HEADS = 4
+_KV_HEADS = 2
+_HEAD_DIM = 2
+_VOCAB = 128
+_ORI_VOCAB = 100
+
+
+def _make_fd_config(
+ hidden_size=_HIDDEN,
+ num_layers=_LAYERS,
+ scale_emb=12,
+ scale_depth=1.4,
+ dim_model_base=256,
+):
+ mc = SimpleNamespace(
+ hidden_size=hidden_size,
+ intermediate_size=_INTERMEDIATE,
+ num_hidden_layers=num_layers,
+ num_attention_heads=_HEADS,
+ num_key_value_heads=_KV_HEADS,
+ head_dim=_HEAD_DIM,
+ vocab_size=_VOCAB,
+ ori_vocab_size=_ORI_VOCAB,
+ rms_norm_eps=1e-5,
+ hidden_act="silu",
+ scale_emb=scale_emb,
+ scale_depth=scale_depth,
+ dim_model_base=dim_model_base,
+ tie_word_embeddings=False,
+ model_format="torch",
+ is_quantized=False,
+ pretrained_config=SimpleNamespace(prefix_name="minicpm4"),
+ fuse_attention_qkv=True,
+ tensor_model_parallel_size=1,
+ tensor_parallel_rank=0,
+ moe_layer_start_index=0,
+ )
+ return SimpleNamespace(
+ model_config=mc,
+ parallel_config=SimpleNamespace(
+ tensor_parallel_size=1,
+ tensor_parallel_rank=0,
+ tp_group=None,
+ expert_parallel_size=1,
+ use_sequence_parallel_moe=False,
+ ),
+ graph_opt_config=SimpleNamespace(graph_opt_level=0, use_cudagraph=False),
+ scheduler_config=SimpleNamespace(splitwise_role="prefill", max_num_seqs=1),
+ load_config=SimpleNamespace(
+ dynamic_load_weight=False,
+ load_choices="default_v0",
+ is_pre_sharded=False,
+ ),
+ quant_config=None,
+ )
+
+
+# ── Fixture ─────────────────────────────────────────────────────────────────
+
+
+@pytest.fixture()
+def mod(monkeypatch):
+ """Inject stubs into minicpm4 module for CPU-safe testing."""
+ monkeypatch.setattr(minicpm4, "MergedColumnParallelLinear", _StubLinear)
+ monkeypatch.setattr(minicpm4, "QKVParallelLinear", _StubLinear)
+ monkeypatch.setattr(minicpm4, "RowParallelLinear", _StubLinear)
+ monkeypatch.setattr(minicpm4, "SiluAndMul", _StubActivation)
+ monkeypatch.setattr(minicpm4, "RMSNorm", _StubRMSNorm)
+ monkeypatch.setattr(minicpm4, "Attention", _StubAttention)
+ monkeypatch.setattr(minicpm4, "VocabParallelEmbedding", _StubEmbedding)
+ monkeypatch.setattr(minicpm4, "ParallelLMHead", _StubLMHead)
+ return minicpm4
+
+
+# ── MLP tests ───────────────────────────────────────────────────────────────
+
+
+def test_mlp_forward(mod):
+ """MLP: up_gate_proj -> act_fn -> down_proj pass-through."""
+ fd = _make_fd_config()
+ mlp = mod.MiniCPM4MLP(fd_config=fd, prefix="minicpm4.layers.0.mlp")
+ x = paddle.ones([2, _HIDDEN], dtype="float32")
+ out = mlp.forward(x, forward_meta=None)
+ assert out.shape == [2, _HIDDEN]
+
+
+def test_mlp_load_state_dict(mod):
+ """MLP load_state_dict delegates to sub-layers."""
+ fd = _make_fd_config()
+ mlp = mod.MiniCPM4MLP(fd_config=fd, prefix="minicpm4.layers.0.mlp")
+ mlp.load_state_dict({})
+ assert mlp.up_gate_proj.load_state_dict_called
+ assert mlp.down_proj.load_state_dict_called
+
+
+# ── Attention tests ─────────────────────────────────────────────────────────
+
+
+def test_attention_forward(mod):
+ """Attention: qkv_proj -> attn -> o_proj pass-through."""
+ fd = _make_fd_config()
+ attn = mod.MiniCPM4Attention(fd_config=fd, layer_id=0, prefix="minicpm4.layers.0.self_attn")
+ x = paddle.ones([2, _HIDDEN], dtype="float32")
+ meta = SimpleNamespace()
+ out = attn.forward(forward_meta=meta, hidden_states=x)
+ assert out.shape == [2, _HIDDEN]
+
+
+def test_attention_load_state_dict(mod):
+ """Attention load_state_dict delegates to sub-layers."""
+ fd = _make_fd_config()
+ attn = mod.MiniCPM4Attention(fd_config=fd, layer_id=0, prefix="minicpm4.layers.0.self_attn")
+ attn.load_state_dict({})
+ assert attn.qkv_proj.load_state_dict_called
+ assert attn.o_proj.load_state_dict_called
+ assert attn.attn.load_state_dict_called
+
+
+# ── DecoderLayer tests ──────────────────────────────────────────────────────
+
+
+def test_decoder_layer_residual_scale(mod):
+ """DecoderLayer computes muP residual_scale = scale_depth / sqrt(N)."""
+ fd = _make_fd_config(scale_depth=1.4, num_layers=32)
+ layer = mod.MiniCPM4DecoderLayer(fd_config=fd, prefix="minicpm4.layers.0")
+ expected = 1.4 / math.sqrt(32)
+ assert abs(layer.residual_scale - expected) < 1e-10
+
+
+def test_decoder_layer_forward(mod):
+ """DecoderLayer forward applies muP scaling to both attn and MLP outputs."""
+ fd = _make_fd_config(scale_depth=1.4, num_layers=32)
+ layer = mod.MiniCPM4DecoderLayer(fd_config=fd, prefix="minicpm4.layers.0")
+ x = paddle.full([2, _HIDDEN], 2.0, dtype="float32")
+ meta = SimpleNamespace()
+ hidden, residual = layer.forward(forward_meta=meta, hidden_states=x, residual=None)
+
+ # Output should be scaled by residual_scale twice (attn + mlp)
+ scale = 1.4 / math.sqrt(32)
+ # With identity stubs: norm returns (x, x), attn/mlp are identity
+ # Path: norm(x,None)->(x,x), attn(x)->x, x*scale, norm(x*scale,x)->(x*scale,x*scale),
+ # mlp(x*scale)->x*scale, x*scale*scale
+ expected_hidden = 2.0 * scale * scale
+ np.testing.assert_allclose(hidden.numpy().mean(), expected_hidden, rtol=1e-5)
+ assert residual is not None
+
+
+def test_decoder_layer_load_state_dict(mod):
+ """DecoderLayer load_state_dict delegates to all sub-layers."""
+ fd = _make_fd_config()
+ layer = mod.MiniCPM4DecoderLayer(fd_config=fd, prefix="minicpm4.layers.0")
+ layer.load_state_dict({})
+ assert layer.self_attn.qkv_proj.load_state_dict_called
+ assert layer.mlp.up_gate_proj.load_state_dict_called
+ assert layer.input_layernorm.load_state_dict_called
+ assert layer.post_attention_layernorm.load_state_dict_called
+
+
+# ── Model tests ─────────────────────────────────────────────────────────────
+
+
+def test_model_forward_with_embedding_scale(mod):
+ """MiniCPM4Model applies scale_emb to embedding output."""
+ fd = _make_fd_config(scale_emb=12, num_layers=1)
+ model = mod.MiniCPM4Model(fd_config=fd)
+ ids = paddle.to_tensor([0, 1, 2], dtype="int64")
+ meta = SimpleNamespace()
+ out = model.forward(ids_remove_padding=ids, forward_meta=meta)
+ assert out.shape == [3, _HIDDEN]
+ # Embedding returns ones, scaled by 12, then through decoder layers and final norm
+ assert paddle.isfinite(out).all()
+
+
+def test_model_no_embedding_scale(mod):
+ """When scale_emb=1, no embedding scaling applied."""
+ fd = _make_fd_config(scale_emb=1, num_layers=1)
+ model = mod.MiniCPM4Model(fd_config=fd)
+ ids = paddle.to_tensor([0, 1], dtype="int64")
+ meta = SimpleNamespace()
+ out = model.forward(ids_remove_padding=ids, forward_meta=meta)
+ assert out.shape == [2, _HIDDEN]
+
+
+def test_model_load_state_dict(mod):
+ """Model load_state_dict delegates to embed_tokens, norm, and layers."""
+ fd = _make_fd_config(num_layers=2)
+ model = mod.MiniCPM4Model(fd_config=fd)
+ model.load_state_dict({})
+ assert model.embed_tokens.load_state_dict_called
+ assert model.norm.load_state_dict_called
+ for layer in model.layers:
+ assert layer.self_attn.qkv_proj.load_state_dict_called
+
+
+# ── CausalLM tests ──────────────────────────────────────────────────────────
+
+
+def test_causallm_forward(mod):
+ """CausalLM full forward: ids -> model -> hidden_states."""
+ fd = _make_fd_config(num_layers=1)
+ model = mod.MiniCPM4ForCausalLM(fd_config=fd)
+ ids = paddle.to_tensor([0, 1, 2], dtype="int64")
+ meta = SimpleNamespace()
+ inputs = {"ids_remove_padding": ids}
+ hidden = model.forward(inputs=inputs, forward_meta=meta)
+ assert hidden.shape == [3, _HIDDEN]
+
+
+def test_causallm_compute_logits_mup_scaling(mod):
+ """compute_logits applies muP scaling: hidden /= (hidden_size / dim_model_base)."""
+ fd = _make_fd_config(dim_model_base=2) # lm_head_scale = 4/2 = 2.0
+ model = mod.MiniCPM4ForCausalLM(fd_config=fd)
+ assert model.lm_head_scale == 2.0
+
+ hidden = paddle.full([2, _HIDDEN], 4.0, dtype="float32")
+ logits = model.compute_logits(hidden, forward_meta=None)
+ assert logits.dtype == paddle.float32
+ assert logits.shape == [2, _VOCAB]
+
+
+def test_causallm_compute_logits_vocab_mask(mod):
+ """compute_logits masks extended vocab positions to -inf."""
+ fd = _make_fd_config()
+ model = mod.MiniCPM4ForCausalLM(fd_config=fd)
+ hidden = paddle.ones([2, _HIDDEN], dtype="float32")
+ logits = model.compute_logits(hidden, forward_meta=None)
+
+ # Valid vocab
+ assert paddle.isfinite(logits[:, :_ORI_VOCAB]).all()
+ # Extended vocab -> -inf
+ assert paddle.isinf(logits[:, _ORI_VOCAB:]).all()
+ assert (logits[:, _ORI_VOCAB:] < 0).all()
+
+
+def test_causallm_lm_head_scale_fallback(mod):
+ """When dim_model_base is None, lm_head_scale defaults to 1.0."""
+ fd = _make_fd_config(dim_model_base=None)
+ fd.model_config.dim_model_base = None
+ model = mod.MiniCPM4ForCausalLM(fd_config=fd)
+ assert model.lm_head_scale == 1.0
+
+
+def test_causallm_set_state_dict(mod):
+ """set_state_dict delegates to model and lm_head."""
+ fd = _make_fd_config(num_layers=1)
+ model = mod.MiniCPM4ForCausalLM(fd_config=fd)
+ model.set_state_dict({})
+ assert model.minicpm4.embed_tokens.load_state_dict_called
+ assert model.lm_head.load_state_dict_called
+
+
+def test_causallm_name(mod):
+ """Class name method returns 'MiniCPMForCausalLM'."""
+ assert mod.MiniCPM4ForCausalLM.name() == "MiniCPMForCausalLM"
+
+
+def test_causallm_tie_word_embeddings(mod):
+ """When tie_word_embeddings=True, load_weights sets lm_head from embed."""
+ fd = _make_fd_config()
+ fd.model_config.tie_word_embeddings = True
+ model = mod.MiniCPM4ForCausalLM(fd_config=fd)
+ # tie_word_embeddings flag is read
+ assert model.tie_word_embeddings is True
+
+
+# ── Weight loading & mapping tests ──────────────────────────────────────────
+
+
+def test_weights_mapper_prefix_rename():
+ """WeightsMapper renames 'model.' prefix to 'minicpm4.' for torch format."""
+ mapper = minicpm4.WeightsMapper(orig_to_new_prefix={"model.": "minicpm4."})
+ assert mapper.apply("model.layers.0.self_attn.q_proj.weight") == "minicpm4.layers.0.self_attn.q_proj.weight"
+ assert mapper.apply("model.embed_tokens.weight") == "minicpm4.embed_tokens.weight"
+ # lm_head has no 'model.' prefix -- unchanged
+ assert mapper.apply("lm_head.weight") == "lm_head.weight"
+
+
+def test_stacked_params_qkv():
+ """q_proj, k_proj, v_proj map to qkv_proj with correct shard_id."""
+ stacked = [
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ("up_gate_proj", "gate_proj", "gate"),
+ ("up_gate_proj", "up_proj", "up"),
+ ("embed_tokens.embeddings", "embed_tokens", None),
+ ("lm_head.linear", "lm_head", None),
+ ]
+ qkv = {wn: (pn, sid) for pn, wn, sid in stacked if sid in ("q", "k", "v")}
+ assert qkv["q_proj"] == ("qkv_proj", "q")
+ assert qkv["k_proj"] == ("qkv_proj", "k")
+ assert qkv["v_proj"] == ("qkv_proj", "v")
+
+
+def test_stacked_params_gate_up():
+ """gate_proj, up_proj map to up_gate_proj."""
+ stacked = [
+ ("up_gate_proj", "gate_proj", "gate"),
+ ("up_gate_proj", "up_proj", "up"),
+ ]
+ gu = {wn: (pn, sid) for pn, wn, sid in stacked}
+ assert gu["gate_proj"] == ("up_gate_proj", "gate")
+ assert gu["up_proj"] == ("up_gate_proj", "up")
+
+
+# ── TP mappings tests ───────────────────────────────────────────────────────
+
+
+def test_tp_mappings_split_keys():
+ """TP mapping generates correct layer-indexed split actions."""
+ cfg = SimpleNamespace(
+ tensor_model_parallel_size=2,
+ tensor_parallel_rank=0,
+ num_attention_heads=_HEADS,
+ num_key_value_heads=_KV_HEADS,
+ hidden_size=_HIDDEN,
+ num_hidden_layers=2,
+ fuse_attention_qkv=True,
+ moe_layer_start_index=0,
+ )
+ mappings = minicpm4.MiniCPM4PretrainedModel._get_tensor_parallel_mappings(cfg, is_split=True)
+
+ # Should have per-layer keys
+ assert "layers.0.self_attn.qkv_proj.weight" in mappings
+ assert "layers.1.self_attn.qkv_proj.weight" in mappings
+ assert "layers.0.mlp.gate_proj.weight" in mappings
+ assert "lm_head.weight" in mappings
+
+
+def test_tp_mappings_non_fused_qkv():
+ """TP mapping handles unfused q/k/v separate weights."""
+ cfg = SimpleNamespace(
+ tensor_model_parallel_size=2,
+ tensor_parallel_rank=0,
+ num_attention_heads=_HEADS,
+ num_key_value_heads=_KV_HEADS,
+ hidden_size=_HIDDEN,
+ num_hidden_layers=1,
+ fuse_attention_qkv=False,
+ moe_layer_start_index=0,
+ )
+ mappings = minicpm4.MiniCPM4PretrainedModel._get_tensor_parallel_mappings(cfg, is_split=True)
+
+ assert "layers.0.self_attn.q_proj.weight" in mappings
+ assert "layers.0.self_attn.k_proj.weight" in mappings
+ assert "layers.0.self_attn.v_proj.weight" in mappings
+
+
+def test_tp_mappings_round_trip():
+ """Split then merge round-trip for QKV weight."""
+ cfg = SimpleNamespace(
+ tensor_model_parallel_size=2,
+ tensor_parallel_rank=None,
+ num_attention_heads=_HEADS,
+ num_key_value_heads=_KV_HEADS,
+ hidden_size=_HIDDEN,
+ num_hidden_layers=1,
+ fuse_attention_qkv=True,
+ moe_layer_start_index=0,
+ )
+ split_map = minicpm4.MiniCPM4PretrainedModel._get_tensor_parallel_mappings(cfg, is_split=True)
+ merge_map = minicpm4.MiniCPM4PretrainedModel._get_tensor_parallel_mappings(cfg, is_split=False)
+
+ key = "layers.0.mlp.gate_proj.weight"
+ w = np.arange(32, dtype=np.float32).reshape(8, _HIDDEN)
+ parts = split_map[key](w)
+ assert len(parts) == 2
+ merged = merge_map[key](parts)
+ np.testing.assert_array_equal(merged, w)
+
+
+# ── Registration test ───────────────────────────────────────────────────────
+
+
+def test_registration_architecture():
+ """MiniCPM4ForCausalLM is registered as 'MiniCPMForCausalLM'."""
+ from fastdeploy.model_executor.models.model_base import ModelRegistry
+
+ registry = ModelRegistry()
+ model_info, arch = registry.inspect_model_cls(["MiniCPMForCausalLM"])
+ assert arch == "MiniCPMForCausalLM"
+ assert model_info.module_path == "minicpm4"
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])