Commit c75cc2d

Refactor wan
- Addressed spec file issue
- Adding compile-only flag by default
- Refactor of common utils

Signed-off-by: vtirumal <[email protected]>
1 parent 9602612 · commit c75cc2d

File tree

15 files changed: +943 -857 lines changed


QEfficient/base/modeling_qeff.py

Lines changed: 10 additions & 15 deletions
@@ -526,23 +526,18 @@ class FeatureNotAvailableError(Exception):
         )
         try:
             subprocess.run(command, capture_output=True, check=True)
-        # TODO: remove once compiler fix exit code (failing with error: Benchmark run failed, exit code 1)
         except subprocess.CalledProcessError as e:
-            # Check if exit code is 1 and programqpc.bin exists in qpc_path
-            if e.returncode == 1 and qpc_path and (qpc_path / "programqpc.bin").is_file():
-                logger.warning("Compiler exited with code 1, but programqpc.bin exists. Continuing...")
-            else:
-                raise RuntimeError(
-                    "\n".join(
-                        [
-                            "Compilation failed!",
-                            f"Compiler command: {e.cmd}",
-                            f"Compiler exitcode: {e.returncode}",
-                            "Compiler stderr:",
-                            e.stderr.decode(),
-                        ]
-                    )
+            raise RuntimeError(
+                "\n".join(
+                    [
+                        "Compilation failed!",
+                        f"Compiler command: {e.cmd}",
+                        f"Compiler exitcode: {e.returncode}",
+                        "Compiler stderr:",
+                        e.stderr.decode(),
+                    ]
                 )
+            )
         # Dump JSON file with hashed parameters
         hashed_compile_params_path = compile_dir / "hashed_compile_params.json"
         create_json(hashed_compile_params_path, compile_hash_params)
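
For reference, the change above drops the exit-code-1 workaround in favor of letting check=True surface every nonzero exit as a single readable RuntimeError. A minimal, self-contained sketch of that pattern; the command below is a placeholder, not the repository's actual compiler invocation:

import subprocess

def run_compile(command):
    """Run a compiler command; raise a readable error on any failure."""
    try:
        # check=True raises CalledProcessError on any nonzero exit code
        subprocess.run(command, capture_output=True, check=True)
    except subprocess.CalledProcessError as e:
        # Same error surface as the refactored code above: the command,
        # its exit code, and the captured stderr joined into one message.
        raise RuntimeError(
            "\n".join(
                [
                    "Compilation failed!",
                    f"Compiler command: {e.cmd}",
                    f"Compiler exitcode: {e.returncode}",
                    "Compiler stderr:",
                    e.stderr.decode(),
                ]
            )
        )

run_compile(["true"])  # placeholder command that exits 0 on POSIX systems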

QEfficient/diffusers/models/transformers/transformer_wan.py

Lines changed: 5 additions & 13 deletions
@@ -26,7 +26,7 @@
 )
 from diffusers.utils import set_weights_and_activate_adapters
 
-from QEfficient.diffusers.pipelines.pipeline_utils import (
+from QEfficient.diffusers.pipelines.modeling_utils import (
     compute_blocked_attention,
     get_attention_blocking_config,
 )
@@ -226,7 +226,7 @@ def forward(
         1. Patch embedding of input
         2. Rotary embedding preparation
         3. Cross-attention with encoder states
-        4. Transformer block processing (with optional gradient checkpointing)
+        4. Transformer block processing
         5. Output normalization and projection
 
         Args:
@@ -254,17 +254,9 @@ def forward(
         if encoder_hidden_states_image is not None:
             encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
 
-        # Process through transformer blocks
-        if torch.is_grad_enabled() and self.gradient_checkpointing:
-            # Use gradient checkpointing to save memory during training
-            for block in self.blocks:
-                hidden_states = self._gradient_checkpointing_func(
-                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
-                )
-        else:
-            # Standard forward pass
-            for block in self.blocks:
-                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+        # Standard forward pass
+        for block in self.blocks:
+            hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
 
         # Output normalization, projection & unpatchify
         if temb.ndim == 3:
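
With the gradient-checkpointing branch removed, the forward reduces to sequentially applying each block to hidden_states. A toy sketch of that loop shape, where ToyBlock is a hypothetical stand-in (the real blocks also take encoder states, timestep projections, and rotary embeddings):

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Hypothetical stand-in for a Wan transformer block."""

    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states):
        return hidden_states + self.proj(hidden_states)

class ToyTransformer(nn.Module):
    def __init__(self, dim=64, depth=4):
        super().__init__()
        self.blocks = nn.ModuleList([ToyBlock(dim) for _ in range(depth)])

    def forward(self, hidden_states):
        # Standard forward pass: each block transforms hidden_states in
        # turn, mirroring the simplified loop in the diff above.
        for block in self.blocks:
            hidden_states = block(hidden_states)
        return hidden_states

x = torch.randn(1, 128, 64)       # (batch, tokens, dim)
print(ToyTransformer()(x).shape)  # torch.Size([1, 128, 64])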
Lines changed: 39 additions & 36 deletions
@@ -1,43 +1,46 @@
 {
-    "description": "Default configuration for Wan pipeline",
-    "model_type": "wan",
+    "description": "Default configuration for Wan unified transformer",
     "modules": {
         "transformer": {
             "specializations": [
-                {
-                    "batch_size": "1",
-                    "num_channels": "16",
-                    "num_frames": "21",
-                    "latent_height": "24",
-                    "latent_width": "40",
-                    "steps": "1",
-                    "sequence_length": "512",
-                    "cl": "5040",
-                    "model_type": 1,
-                },
-                {
-                    "batch_size": "1",
-                    "num_channels": "16",
-                    "num_frames": "21",
-                    "latent_height": "24",
-                    "latent_width": "40",
-                    "steps": "1",
-                    "sequence_length": "512",
-                    "cl": "5040",
-                    "model_type": 2,
-                },
-            ],
-            "compilation": {
-                "onnx_path": null,
-                "compile_dir": null,
-                "mdp_ts_num_devices": 16,
-                "mxfp6_matmul": true,
-                "convert_to_fp16": true,
-                "aic_num_cores": 16,
-                "mos": 1,
-                "mdts_mos": 1,
-            },
-            "execute": {"device_ids": null},
+                {
+                    "batch_size": "1",
+                    "num_channels": "16",
+                    "num_frames": "21",
+                    "latent_height": "24",
+                    "latent_width": "40",
+                    "steps": "1",
+                    "sequence_length": "512",
+                    "cl": "5040",
+                    "model_type": 1
+                },
+                {
+                    "batch_size": "1",
+                    "num_channels": "16",
+                    "num_frames": "21",
+                    "latent_height": "24",
+                    "latent_width": "40",
+                    "steps": "1",
+                    "sequence_length": "512",
+                    "cl": "5040",
+                    "model_type": 2
+                }
+            ],
+            "compilation":
+            {
+                "onnx_path": null,
+                "compile_dir": null,
+                "mdp_ts_num_devices": 16,
+                "mxfp6_matmul": true,
+                "convert_to_fp16": true,
+                "aic_num_cores": 16,
+                "mos": 1,
+                "mdts_mos": 1
+            },
+            "execute":
+            {
+                "device_ids": null
+            }
         }
     },
 }
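
As a consumption sketch, a config with this shape reads naturally with Python's json module. The snippet below inlines a trimmed copy of the structure above rather than guessing the file's real path; the key names are taken directly from the diff:

import json

# Trimmed inline copy of the config shape above; the real file keeps the
# full specializations (batch_size, num_frames, latent sizes, etc.).
CONFIG = """
{
    "modules": {
        "transformer": {
            "specializations": [
                {"sequence_length": "512", "cl": "5040", "model_type": 1},
                {"sequence_length": "512", "cl": "5040", "model_type": 2}
            ],
            "compilation": {"mdp_ts_num_devices": 16, "aic_num_cores": 16},
            "execute": {"device_ids": null}
        }
    }
}
"""

transformer = json.loads(CONFIG)["modules"]["transformer"]

# Shape values are JSON strings while model_type is an integer; cast as needed.
for spec in transformer["specializations"]:
    print(spec["model_type"], int(spec["cl"]))

device_ids = transformer["execute"]["device_ids"]  # JSON null -> Python None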
