15 changes: 13 additions & 2 deletions QEfficient/transformers/models/modeling_auto.py
@@ -2583,7 +2583,9 @@ def export(
self.model.config, fbs if self.continuous_batching else bs, seq_len
)
enable_chunking = kwargs.get("enable_chunking", False)
if prefill_only:
if not enable_chunking and self.continuous_batching:
raise NotImplementedError(
"Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!"
@@ -2597,16 +2599,24 @@
if self.model.config.model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH
else seq_len
)
kv_cache_shape[2] = seq_len + self.model.config.sliding_window if enable_chunking else seq_len
kv_cache_shape[2] = (
seq_len + (0 if self.model.config.sliding_window is None else self.model.config.sliding_window)
if enable_chunking
else seq_len
)
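# NOTE (editorial, assumed rationale): with chunked prefill on a sliding-window model,
# the cache length axis holds the current chunk plus the retained window from earlier
# chunks, hence seq_len + sliding_window; models without a sliding window keep seq_len.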
else:
self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False))
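# NOTE (editorial, assumed intent): the keys popped below are specific to the
# prefill-only graph, so dropping them lets this export hash like a regular graph.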
self.hash_params.pop("prefill_only", None)
self.hash_params.pop("NUM_Q_BLOCKS", None)
self.hash_params.pop("NUM_FFN_BLOCKS", None)
self.hash_params.pop("ENABLE_OPT_SWA", None)
self.hash_params.pop("chunking", None)
if kwargs.get("retain_full_kv", False):
kv_cache_shape[2] = seq_len + self.model.config.sliding_window
kv_cache_shape[2] = seq_len + (
0 if self.model.config.sliding_window is None else self.model.config.sliding_window
)
self.hash_params["retain_full_kv"] = True

example_inputs = {
@@ -2695,6 +2705,7 @@ def export(
vocab_size=self.model.config.vocab_size,
qaic_config=self.model.qaic_config,
)
return self._export(
example_inputs,
output_names,
10 changes: 10 additions & 0 deletions QEfficient/transformers/models/pytorch_transforms.py
@@ -415,6 +415,7 @@
QEffQwen3Model,
)
from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import (
QEffPrefillQwen3MoeSparseMoeBlock,
QEffQwen3MoeAttention,
QEffQwen3MoeDecoderLayer,
QEffQwen3MoeForCausalLM,
@@ -650,32 +651,41 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:


class PrefillOnlyTransform(ModuleMappingTransform):
_module_mapping = {
QEffGptOssModel: QEffPrefillOnlyGptOssModel,
QEffGptOssAttention: QEffPrefillOnlyGptOssAttention,
QEffGptOssMLP: QEffPrefillOnlyGptOssMLP,
# QEffQwen3MoeSparseMoeBlock: QEffPrefillQwen3MoeSparseMoeBlock,
# QEffQwen3MoeModel: QEffPrefillOnlyQwen3MoeModel,
# QEffQwen3MoeAttention: QEffPrefillOnlyQwen3MoeAttention,
}


class PrefillOnlyChunkedTransform(ModuleMappingTransform):
_module_mapping = {
QEffGptOssModel: QEffPrefillOnlyGptOssModel,
QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
QEffGptOssMLP: QEffPrefillOnlyChunkedGptOssMLP,
QEffQwen3MoeSparseMoeBlock: QEffPrefillQwen3MoeSparseMoeBlock,
}


class RevertPrefillKeepAttentionTransform(ModuleMappingTransform):
_module_mapping = {
QEffGptOssModel: QEffPrefillOnlyGptOssModel,
QEffPrefillOnlyGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
QEffPrefillOnlyGptOssMLP: QEffGptOssMLP,
QEffPrefillOnlyChunkedGptOssMLP: QEffGptOssMLP,
QEffPrefillQwen3MoeSparseMoeBlock: QEffQwen3MoeSparseMoeBlock,
}


class RevertPrefillOnlyTransform(ModuleMappingTransform):
_module_mapping = {
**{v: k for k, v in PrefillOnlyTransform._module_mapping.items()},
**{v: k for k, v in PrefillOnlyChunkedTransform._module_mapping.items()},
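For context between the two files: the classes above only declare which module classes to swap; the swap itself is done by ModuleMappingTransform.apply, whose signature appears in the hunk header above. Below is a minimal sketch of that mechanism, reconstructed from the pattern visible here; treat details such as the __qeff_init__ hook as assumptions rather than the library's verbatim implementation.

from typing import Dict, Tuple, Type

import torch.nn as nn


class ModuleMappingTransformSketch:
    """Sketch of a module-mapping transform: rebind module classes per _module_mapping."""

    _module_mapping: Dict[Type[nn.Module], Type[nn.Module]] = {}

    @classmethod
    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
        transformed = False
        for module in model.modules():
            repl_cls = cls._module_mapping.get(type(module))
            if repl_cls is not None:
                # Rebinding the class keeps parameters/buffers in place but swaps forward().
                module.__class__ = repl_cls
                # Assumption: classes that define __qeff_init__ (e.g. the MoE block below,
                # which stacks expert weights) get it invoked right after the swap.
                if hasattr(module, "__qeff_init__"):
                    module.__qeff_init__()
                transformed = True
        return model, transformed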
173 changes: 147 additions & 26 deletions QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py
@@ -101,10 +101,10 @@ def eager_attention_forward(
attention_mask: Optional[torch.Tensor],
scaling: float,
):
key_states = repeat_kv(key, module.num_key_value_groups)

value_states = repeat_kv(value, module.num_key_value_groups)

attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(
@@ -118,25 +118,13 @@ def eager_attention_forward(
return attn_output, attn_weights


class QEffQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock):
def __qeff_init__(self):
self.gate_proj_w = []
self.up_proj_w = []
self.down_proj_w = []
with torch.no_grad():
for e in range(self.num_experts):
self.gate_proj_w.append(self.experts[e].gate_proj.weight.T)
self.up_proj_w.append(self.experts[e].up_proj.weight.T)
self.down_proj_w.append(self.experts[e].down_proj.weight.T)
self.gate_proj_w = torch.stack(self.gate_proj_w)
self.up_proj_w = torch.stack(self.up_proj_w)
self.down_proj_w = torch.stack(self.down_proj_w)

def alt_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
class QEffPrefillQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock):
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
B, S, H = hidden_states.shape
T = B * S
x = hidden_states.view(T, H)

router_logits = self.gate(x) # [T, E]
prob = F.softmax(router_logits, -1, dtype=torch.float)
top_w, top_i = torch.topk(prob, self.top_k, -1)
@@ -145,27 +133,39 @@ def alt_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
top_w = top_w.to(x.dtype)
masked_logits = torch.zeros_like(router_logits)
masked_logits.scatter_(1, top_i, top_w)

# Routing weights for each expert [T, E]
routing_weights = masked_logits

# ────────────────── allocate the output tensor ─────
expert_out = x.new_zeros((T, H)) # accumulation buffer

# ───────────────────────── Expert computation loop ─────────────────────────────
for e in range(self.num_experts):
routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1]
W_g, W_u = self.experts[e].gate_proj, self.experts[e].up_proj # [H, I], [H, I]
W_d = self.experts[e].down_proj # [I, H]
gate = W_g(x) # [T, I]
up = W_u(x) # [T, I]
down = W_d(up * self.experts[e].act_fn(gate)) # [T, H]

W_g, W_u = self.experts[e].gate_proj.weight.T, self.experts[e].up_proj.weight.T # [H, I], [H, I]
W_d = self.experts[e].down_proj.weight.T # [I, H]
gate = x @ W_g # [T, I]
up = x @ W_u # [T, I]
down = (up * self.experts[e].act_fn(gate)) @ W_d # [T, H]
masked_down = torch.where(routing_weight > 0, down * routing_weight, torch.zeros_like(expert_out))
expert_out += masked_down
return expert_out.view(B, S, H), router_logits
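# Editorial note (assumed rationale): this prefill variant evaluates every expert on
# every token and zeroes out non-routed contributions via the routing-weight mask,
# trading extra FLOPs for a shape-static graph with no data-dependent gather/scatter,
# which is generally easier to export and compile ahead of time.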


class QEffQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock):
def __qeff_init__(self):
self.gate_proj_w = []
self.up_proj_w = []
self.down_proj_w = []
with torch.no_grad():
for e in range(self.num_experts):
self.gate_proj_w.append(self.experts[e].gate_proj.weight.T)
self.up_proj_w.append(self.experts[e].up_proj.weight.T)
self.down_proj_w.append(self.experts[e].down_proj.weight.T)
self.gate_proj_w = torch.stack(self.gate_proj_w)
self.up_proj_w = torch.stack(self.up_proj_w)
self.down_proj_w = torch.stack(self.down_proj_w)
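# Editorial note (assumption): the stacked [num_experts, H, I] / [num_experts, I, H]
# weight tensors are meant to let the forward below (truncated in this diff) select
# expert weights by index rather than looping over Python-level expert modules.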

def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
B, S, H = hidden_states.shape
T = B * S
hidden_states = hidden_states.view(T, H)
@@ -240,6 +240,56 @@ def forward(
return attn_output, attn_weights


class QEffPrefillOnlyQwen3MoeAttention(Qwen3MoeAttention):
def __qeff_init__(self):
self.rotary_emb = QEffQwen3MoeRotaryEmbedding(config=self.config)

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)

query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
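# Editorial note: when comp_ctx_lengths is provided, the attention mask is trimmed to
# that compute-context length and the cache update is told the trimmed length via the
# "CCL" kwarg.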
if comp_ctx_lengths is not None:
attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
cache_kwargs["CCL"] = attention_mask.shape[-1]
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

attention_interface = eager_attention_forward

attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
scaling=self.scaling,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)

return attn_output, attn_weights


class QEffQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
def forward(
self,
@@ -370,6 +420,77 @@ def forward(
)


class QEffPrefillOnlyQwen3MoeModel(Qwen3MoeModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
comp_ctx_lengths: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
batch_index: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> MoeModelOutputWithPast:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache

if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]

past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_key_values_length)
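# Editorial note: the causal mask is rebuilt here from position_ids and the cache
# length, so the user-supplied attention_mask argument is not consulted in this
# prefill-only path.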

hidden_states = inputs_embeds

# decoder layers
all_hidden_states = () if output_hidden_states else None

for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)

hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
use_cache=use_cache,
cache_position=cache_position,
)

hidden_states = self.norm(hidden_states)

# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)

past_key_values = past_key_values.to_legacy_cache()

return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
)


class QEffQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
def forward(
self,
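For orientation, a hedged usage sketch of the export path these diffs touch. The kwarg names (prefill_only, enable_chunking, retain_full_kv) come from the code above, but the exact export() signature, the from_pretrained options, and the checkpoint name are assumptions, not a confirmed recipe.

from QEfficient import QEFFAutoModelForCausalLM

# Placeholder checkpoint name; substitute any Qwen3-MoE / GPT-OSS model QEfficient supports.
model = QEFFAutoModelForCausalLM.from_pretrained(
    "org/qwen3-moe-checkpoint",  # hypothetical
    continuous_batching=True,
)

# Prefill-only graph with chunked prefill (kwargs assumed to reach the hunks above via **kwargs).
model.export(prefill_only=True, enable_chunking=True)

# Companion export that keeps the full KV length for sliding-window layers (assumed usage).
model.export(retain_full_kv=True)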