diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index dc03ba82f..2a90caac6 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -2597,7 +2599,12 @@
                 if self.model.config.model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH
                 else seq_len
             )
-            kv_cache_shape[2] = seq_len + self.model.config.sliding_window if enable_chunking else seq_len
+            kv_cache_shape[2] = (
+                seq_len + (0 if self.model.config.sliding_window is None else self.model.config.sliding_window)
+                if enable_chunking
+                else seq_len
+            )
         else:
             self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False))
             self.hash_params.pop("prefill_only", None)
@@ -2605,8 +2612,11 @@
             self.hash_params.pop("NUM_FFN_BLOCKS", None)
             self.hash_params.pop("ENABLE_OPT_SWA", None)
             self.hash_params.pop("chunking", None)
             if kwargs.get("retain_full_kv", False):
-                kv_cache_shape[2] = seq_len + self.model.config.sliding_window
+                kv_cache_shape[2] = seq_len + (
+                    0 if self.model.config.sliding_window is None else self.model.config.sliding_window
+                )
                 self.hash_params["retain_full_kv"] = True
 
         example_inputs = {
diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py
index 9e021851b..a3f903b40 100644
--- a/QEfficient/transformers/models/pytorch_transforms.py
+++ b/QEfficient/transformers/models/pytorch_transforms.py
@@ -415,6 +415,7 @@
     QEffQwen3Model,
 )
 from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import (
+    QEffPrefillQwen3MoeSparseMoeBlock,
     QEffQwen3MoeAttention,
     QEffQwen3MoeDecoderLayer,
     QEffQwen3MoeForCausalLM,
@@ -650,32 +651,41 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
 
 
 class PrefillOnlyTransform(ModuleMappingTransform):
     _module_mapping = {
         QEffGptOssModel: QEffPrefillOnlyGptOssModel,
         QEffGptOssAttention: QEffPrefillOnlyGptOssAttention,
         QEffGptOssMLP: QEffPrefillOnlyGptOssMLP,
     }
 
 
 class PrefillOnlyChunkedTransform(ModuleMappingTransform):
     _module_mapping = {
         QEffGptOssModel: QEffPrefillOnlyGptOssModel,
         QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
         QEffGptOssMLP: QEffPrefillOnlyChunkedGptOssMLP,
+        QEffQwen3MoeSparseMoeBlock: QEffPrefillQwen3MoeSparseMoeBlock,
     }
 
 
 class RevertPrefillKeepAttentionTransform(ModuleMappingTransform):
     _module_mapping = {
         QEffGptOssModel: QEffPrefillOnlyGptOssModel,
         QEffPrefillOnlyGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
         QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
         QEffPrefillOnlyGptOssMLP: QEffGptOssMLP,
         QEffPrefillOnlyChunkedGptOssMLP: QEffGptOssMLP,
+        QEffPrefillQwen3MoeSparseMoeBlock: QEffQwen3MoeSparseMoeBlock,
     }
 
 
 class RevertPrefillOnlyTransform(ModuleMappingTransform):
     _module_mapping = {
         **{v: k for k, v in PrefillOnlyTransform._module_mapping.items()},
         **{v: k for k, v in PrefillOnlyChunkedTransform._module_mapping.items()},
diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py
index cbd80d8ca..441d4446c 100644
--- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py
+++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py
@@ -118,25 +118,13 @@ def eager_attention_forward(
     return attn_output, attn_weights
 
 
-class QEffQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock):
-    def __qeff_init__(self):
-        self.gate_proj_w = []
-        self.up_proj_w = []
-        self.down_proj_w = []
-        with torch.no_grad():
-            for e in range(self.num_experts):
-                self.gate_proj_w.append(self.experts[e].gate_proj.weight.T)
-                self.up_proj_w.append(self.experts[e].up_proj.weight.T)
-                self.down_proj_w.append(self.experts[e].down_proj.weight.T)
-        self.gate_proj_w = torch.stack(self.gate_proj_w)
-        self.up_proj_w = torch.stack(self.up_proj_w)
-        self.down_proj_w = torch.stack(self.down_proj_w)
-
-    def alt_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+class QEffPrefillQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock):
+    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         B, S, H = hidden_states.shape
         T = B * S
         x = hidden_states.view(T, H)
 
         router_logits = self.gate(x)  # [T, E]
         prob = F.softmax(router_logits, -1, dtype=torch.float)
         top_w, top_i = torch.topk(prob, self.top_k, -1)
@@ -145,27 +145,39 @@ def alt_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         top_w = top_w.to(x.dtype)
         masked_logits = torch.zeros_like(router_logits)
         masked_logits.scatter_(1, top_i, top_w)
-
         # Routing weights for each expert  [T, E]
         routing_weights = masked_logits
-
         # ────────────────── allocate the output tensor ─────
         expert_out = x.new_zeros((T, H))  # accumulation buffer
-
         # ───────────────────────── Expert computation loop ─────────────────────────────
         for e in range(self.num_experts):
             routing_weight = routing_weights[:, e].unsqueeze(-1)  # [T, 1]
-            W_g, W_u = self.experts[e].gate_proj, self.experts[e].up_proj  # [H, I], [H, I]
-            W_d = self.experts[e].down_proj  # [I, H]
-            gate = W_g(x)  # [T, I]
-            up = W_u(x)  # [T, I]
-            down = W_d(up * self.experts[e].act_fn(gate))  # [T, H]
-
+            W_g, W_u = self.experts[e].gate_proj.weight.T, self.experts[e].up_proj.weight.T  # [H, I], [H, I]
+            W_d = self.experts[e].down_proj.weight.T  # [I, H]
+            gate = x @ W_g  # [T, I]
+            up = x @ W_u  # [T, I]
+            down = (up * self.experts[e].act_fn(gate)) @ W_d  # [T, H]
+            # Each expert runs on every token; contributions from tokens not routed to expert e are zeroed below.
             masked_down = torch.where(routing_weight > 0, down * routing_weight, torch.zeros_like(expert_out))
             expert_out += masked_down
 
         return expert_out.view(B, S, H), router_logits
+
+
+class QEffQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock):
+    def __qeff_init__(self):
+        self.gate_proj_w = []
+        self.up_proj_w = []
+        self.down_proj_w = []
+        with torch.no_grad():
+            for e in range(self.num_experts):
+                self.gate_proj_w.append(self.experts[e].gate_proj.weight.T)
+                self.up_proj_w.append(self.experts[e].up_proj.weight.T)
+                self.down_proj_w.append(self.experts[e].down_proj.weight.T)
+        self.gate_proj_w = torch.stack(self.gate_proj_w)
+        self.up_proj_w = torch.stack(self.up_proj_w)
+        self.down_proj_w = torch.stack(self.down_proj_w)
+
+    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         B, S, H = hidden_states.shape
         T = B * S
         hidden_states = hidden_states.view(T, H)
@@ -240,6 +240,56 @@ def forward(
     return attn_output, attn_weights
 
 
+class QEffPrefillOnlyQwen3MoeAttention(Qwen3MoeAttention):
+    def __qeff_init__(self):
+        self.rotary_emb = QEffQwen3MoeRotaryEmbedding(config=self.config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
+        batch_index: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        if past_key_value is not None:
+            cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
+            if comp_ctx_lengths is not None:
+                attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
+                cache_kwargs["CCL"] = attention_mask.shape[-1]
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        attention_interface = eager_attention_forward
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            scaling=self.scaling,
+        )
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, attn_weights
+
+
 class QEffQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
     def forward(
         self,
@@ -370,6 +420,77 @@ def forward(
         )
 
 
+class QEffPrefillOnlyQwen3MoeModel(Qwen3MoeModel):
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        batch_index: Optional[torch.LongTensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> MoeModelOutputWithPast:
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        past_key_values_length = 0
+        if past_key_values is not None:
+            past_key_values_length = past_key_values[0][0].shape[2]
+
+        past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_key_values_length)
+
+        hidden_states = inputs_embeds
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+
+        for decoder_layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            hidden_states = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                comp_ctx_lengths=comp_ctx_lengths,
+                batch_index=batch_index,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        past_key_values = past_key_values.to_legacy_cache()
+
+        return MoeModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+            hidden_states=all_hidden_states,
+        )
+
+
 class QEffQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
     def forward(
         self,
diff --git a/examples/qwen3moe_disagg_mode_with_chunking.py b/examples/qwen3moe_disagg_mode_with_chunking.py
new file mode 100644
index 000000000..19fba549f
--- /dev/null
+++ b/examples/qwen3moe_disagg_mode_with_chunking.py
@@ -0,0 +1,132 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import time
+
+import numpy as np
+import torch
+from transformers import AutoConfig, AutoTokenizer
+
+from QEfficient import QEFFAutoModelForCausalLM
+from QEfficient.generation.cloud_infer import QAICInferenceSession
+
+model_id = "Qwen/Qwen3-30B-A3B-Instruct-2507"  # weights do not need to be converted to fp32
+prompt = """
+Explain quantum computing in simple terms.
+"""
+config = AutoConfig.from_pretrained(model_id)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+PREFILL_SEQ_LEN = 128
+CTX_LEN = 128 * 3
+
+qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id)
+decode_qpc_path = qeff_model.compile(
+    prefill_seq_len=1,
+    ctx_len=CTX_LEN,
+    num_cores=16,
+    mxfp6_matmul=True,
+    mxint8_kv_cache=True,
+    num_devices=1,
+    mos=1,
+    aic_enable_depth_first=True,
+    num_speculative_tokens=None,
+    offload_pt_weights=False,  # Need the weights in memory for prefill-model export/compilation in the next step
+    retain_full_kv=True,
+)
+
+# The following compile call errors out by default; run the printed command, then pass the generated QPC path as prefill_qpc_path and comment out lines 55-68.
+# prefill_qpc_path = "/home/dipankar/.cache/qeff_models/Qwen3MoeForCausalLM/Qwen3MoeForCausalLM-2fff95dd3d8e1907/qpc-0d9874dc75da1555/qpc"
+
+prefill_qpc_path = qeff_model.compile(
+    prefill_seq_len=PREFILL_SEQ_LEN,
+    ctx_len=CTX_LEN,
+    num_cores=16,
+    mxfp6_matmul=True,
+    mxint8_kv_cache=True,
+    num_devices=2,
+    split_retained_state_io=True,
+    mos=1,
+    aic_enable_depth_first=True,
+    num_speculative_tokens=None,
+    prefill_only=True,
+    enable_chunking=True,
+    # use_onnx_subfunctions=True,
+)
+
+
+inputs = tokenizer(prompt, return_tensors="np", padding=True)
+position_ids = inputs["attention_mask"].sum(1, keepdims=True)
+generation_len = CTX_LEN - position_ids.max()
+padded_len = inputs["input_ids"].shape[1]
+num_chunks = -(padded_len // -PREFILL_SEQ_LEN)  # ceil divide without float
+padded_len = num_chunks * PREFILL_SEQ_LEN  # Convert to a multiple of prompt_len
+inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
+inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
+inputs.pop("token_type_ids", None)
+inputs = {k: torch.from_numpy(v) for k, v in inputs.items()}
+inputs.pop("past_key_values", None)
+inputs = {k: v.detach().numpy() for k, v in inputs.items()}
+
+
+prefill_session = QAICInferenceSession(prefill_qpc_path)
+decode_session = QAICInferenceSession(decode_qpc_path)
+
+all_outputs = []
+for i in range(num_chunks):
+    chunk_inputs = inputs.copy()
+    chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
+    chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
+    ins = time.time()
+    qpc_out = prefill_session.run(chunk_inputs)
+    print(f"time for this run={time.time() - ins}")
+    for j in range(config.num_hidden_layers):
+        inputs[f"past_key.{j}"] = qpc_out[f"past_key.{j}_RetainedState"]
+        inputs[f"past_value.{j}"] = qpc_out[f"past_value.{j}_RetainedState"]
+
+all_outputs.append(np.argmax(qpc_out["logits"]))
+
+decode_inputs = {
+    "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1),
+    "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1,
+}
+for i in range(config.num_hidden_layers):
+    decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"]
decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] + +st = time.time() +decode_out = decode_session.run(decode_inputs) +print(f"time for first run of decode with KV as input = {time.time() - st} sec\n") +all_outputs.append(np.argmax(decode_out["logits"])) +pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1 +loop_decode_inputs = { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, +} + +for i in range(config.num_hidden_layers): + loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] + loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] + +st = time.time() +for i in range(generation_len - 2): + decode_out = decode_session.run(loop_decode_inputs) + all_outputs.append(np.argmax(decode_out["logits"])) + pos_id += 1 + for i in range(config.num_hidden_layers): + loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] + loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] + + loop_decode_inputs.update( + { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, + } + ) +ft = time.time() + +print(f"decode tok/sec={(generation_len - 2) / (ft - st)}") +print(f"input\n{prompt}\noutput\n{tokenizer.decode(all_outputs)}")