Skip to content

Commit 2f5c3e3

Browse files
committed
Cleaning of QwenMoe chunking script
Signed-off-by: Dipankar Sarkar <[email protected]>
1 parent 26061fb commit 2f5c3e3

File tree

1 file changed

+7
-14
lines changed

1 file changed

+7
-14
lines changed

examples/qwen3moe_disagg_mode_with_chunking.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,28 +18,25 @@
1818
prompt = """
1919
Explain quantum computing in simple terms.
2020
"""
21-
config = AutoConfig.from_pretrained(model_id, num_hidden_layers=2)
22-
tokenizer = AutoTokenizer.from_pretrained(model_id, num_hidden_layers=2)
21+
config = AutoConfig.from_pretrained(model_id)
22+
tokenizer = AutoTokenizer.from_pretrained(model_id)
2323
PREFILL_SEQ_LEN = 128
2424
CTX_LEN = 128 * 3
2525

26-
qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2)
27-
breakpoint()
26+
qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id)
2827
decode_qpc_path = qeff_model.compile(
2928
prefill_seq_len=1,
3029
ctx_len=CTX_LEN,
3130
num_cores=16,
3231
mxfp6_matmul=True,
3332
mxint8_kv_cache=True,
34-
num_devices=2,
35-
split_retained_state_io=True,
33+
num_devices=1,
3634
mos=1,
3735
aic_enable_depth_first=True,
3836
num_speculative_tokens=None,
3937
offload_pt_weights=False, # Need the weights in memory for prefill-model export/compilation in the next step
4038
retain_full_kv=True,
4139
)
42-
breakpoint()
4340

4441
# The following command errors out by default. The user is expected to run the printed command, provide the generated qpc path as prefill_qpc_path, and comment out lines 55-68.
4542
# prefill_qpc_path = "/home/dipankar/.cache/qeff_models/Qwen3MoeForCausalLM/Qwen3MoeForCausalLM-2fff95dd3d8e1907/qpc-0d9874dc75da1555/qpc"
@@ -60,7 +57,7 @@
6057
# use_onnx_subfunctions=True,
6158
)
6259

63-
breakpoint()
60+
6461
inputs = tokenizer(prompt, return_tensors="np", padding=True)
6562
position_ids = inputs["attention_mask"].sum(1, keepdims=True)
6663
generation_len = CTX_LEN - position_ids.max()
@@ -74,9 +71,9 @@
7471
inputs.pop("past_key_values", None)
7572
inputs = {k: v.detach().numpy() for k, v in inputs.items()}
7673

77-
breakpoint()
7874

7975
prefill_session = QAICInferenceSession(prefill_qpc_path)
76+
decode_session = QAICInferenceSession(decode_qpc_path)
8077

8178
all_outputs = []
8279
for i in range(num_chunks):
@@ -86,16 +83,12 @@
8683
ins = time.time()
8784
qpc_out = prefill_session.run(chunk_inputs)
8885
print(f"time for this run={time.time() - ins}")
89-
breakpoint()
9086
for i in range(config.num_hidden_layers):
9187
inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"]
9288
inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"]
9389

9490
all_outputs.append(np.argmax(qpc_out["logits"]))
95-
prefill_session.deactivate()
96-
decode_session = QAICInferenceSession(decode_qpc_path)
97-
breakpoint()
98-
# decode_session.activate()
91+
9992
decode_inputs = {
10093
"input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1),
10194
"position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1,

0 commit comments

Comments
 (0)