Commit a0a0417

Refactored to address review comments
Signed-off-by: vtirumal <[email protected]>
1 parent 9e43a58

14 files changed, +254 -222 lines changed


QEfficient/diffusers/pipelines/modeling_utils.py renamed to QEfficient/diffusers/models/modeling_utils.py

Lines changed: 18 additions & 8 deletions
@@ -40,6 +40,7 @@ def apply_head_blocking(
     q: torch.FloatTensor,
     k: torch.FloatTensor,
     v: torch.FloatTensor,
+    head_block_size: int,
     attention_mask: Optional[torch.FloatTensor] = None,
 ) -> torch.FloatTensor:
     """
@@ -62,7 +63,6 @@ def apply_head_blocking(
     scale_factor = 1.0 / math.sqrt(DH)
 
     # Get head blocking configuration
-    _, head_block_size, _, _ = get_attention_blocking_config()
     head_block_size = head_block_size or NH
     num_head_blocks = math.ceil(NH / head_block_size)
 
@@ -107,6 +107,8 @@ def apply_kv_blocking(
     q: torch.FloatTensor,
     k: torch.FloatTensor,
     v: torch.FloatTensor,
+    head_block_size: int,
+    num_kv_blocks: int,
     attention_mask: Optional[torch.FloatTensor] = None,
 ) -> torch.FloatTensor:
     """
@@ -129,7 +131,6 @@ def apply_kv_blocking(
     scale_factor = 1.0 / math.sqrt(DH)
 
     # Get blocking configuration
-    _, head_block_size, num_kv_blocks, _ = get_attention_blocking_config()
     head_block_size = head_block_size or NH
     num_kv_blocks = num_kv_blocks or CL
     num_head_blocks = math.ceil(NH / head_block_size)
@@ -210,6 +211,8 @@ def apply_q_blocking(
     q: torch.FloatTensor,
     k: torch.FloatTensor,
     v: torch.FloatTensor,
+    head_block_size: int,
+    num_q_blocks: int,
     attention_mask: Optional[torch.FloatTensor] = None,
 ) -> torch.FloatTensor:
     """
@@ -232,7 +235,6 @@ def apply_q_blocking(
     scale_factor = 1.0 / math.sqrt(DH)
 
     # Get blocking configuration
-    _, head_block_size, _, num_q_blocks = get_attention_blocking_config()
     head_block_size = head_block_size or NH
     num_q_blocks = num_q_blocks or CL
     num_head_blocks = math.ceil(NH / head_block_size)
@@ -292,6 +294,9 @@ def apply_qkv_blocking(
     q: torch.FloatTensor,
     k: torch.FloatTensor,
     v: torch.FloatTensor,
+    head_block_size: int,
+    num_kv_blocks: int,
+    num_q_blocks: int,
     attention_mask: Optional[torch.FloatTensor] = None,
 ) -> torch.FloatTensor:
     """
@@ -313,7 +318,6 @@ def apply_qkv_blocking(
     scale_factor = 1.0 / math.sqrt(DH)
 
     # Get blocking configuration from environment variables
-    _, head_block_size, num_kv_blocks, num_q_blocks = get_attention_blocking_config()
     head_block_size = head_block_size or NH
     num_kv_blocks = num_kv_blocks or CL
     num_q_blocks = num_q_blocks or CL
@@ -420,6 +424,9 @@ def compute_blocked_attention(
     q: torch.FloatTensor,
     k: torch.FloatTensor,
     v: torch.FloatTensor,
+    head_block_size: int,
+    num_kv_blocks: int,
+    num_q_blocks: int,
     blocking_mode: str = "default",
     attention_mask: Optional[torch.FloatTensor] = None,
 ) -> torch.FloatTensor:
@@ -430,17 +437,20 @@ def compute_blocked_attention(
         q (torch.FloatTensor): Query tensor of shape (BS, NH, CL, DH)
         k (torch.FloatTensor): Key tensor of shape (BS, NH, CL, DH)
         v (torch.FloatTensor): Value tensor of shape (BS, NH, CL, DH)
+        head_block_size (int): Head blocking size
+        num_kv_blocks (int): Number of KV blocks
+        num_q_blocks (int): Number of Q blocks
         blocking_mode (str): Blocking strategy ('kv', 'q', 'qkv', 'default')
         attention_mask (Optional[torch.FloatTensor]): Attention mask tensor
 
     Returns:
         torch.FloatTensor: Attention output of shape (BS, NH, CL, DH)
     """
     if blocking_mode == "kv":
-        return apply_kv_blocking(q, k, v, attention_mask)
+        return apply_kv_blocking(q, k, v, head_block_size, num_kv_blocks, attention_mask)
     elif blocking_mode == "q":
-        return apply_q_blocking(q, k, v, attention_mask)
+        return apply_q_blocking(q, k, v, head_block_size, num_q_blocks, attention_mask)
     elif blocking_mode == "qkv":
-        return apply_qkv_blocking(q, k, v, attention_mask)
+        return apply_qkv_blocking(q, k, v, head_block_size, num_kv_blocks, num_q_blocks, attention_mask)
     else:  # default
-        return apply_head_blocking(q, k, v, attention_mask)
+        return apply_head_blocking(q, k, v, head_block_size, attention_mask)
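
After this refactor, callers pass the blocking parameters explicitly instead of each helper re-reading the global configuration. A minimal sketch of the resulting call pattern, assuming only the signatures shown in this diff; the tensor sizes below are arbitrary placeholders:

import torch

from QEfficient.diffusers.models.modeling_utils import (
    compute_blocked_attention,
    get_attention_blocking_config,
)

# Placeholder tensors in the documented (BS, NH, CL, DH) layout
q = torch.rand(1, 12, 5040, 128)
k = torch.rand(1, 12, 5040, 128)
v = torch.rand(1, 12, 5040, 128)

# The config helper returns (blocking_mode, head_block_size, num_kv_blocks, num_q_blocks),
# matching the unpacking at the transformer_wan.py call site below.
blocking_mode, head_block_size, num_kv_blocks, num_q_blocks = get_attention_blocking_config()

out = compute_blocked_attention(
    q,
    k,
    v,
    head_block_size,
    num_kv_blocks,
    num_q_blocks,
    blocking_mode=blocking_mode,
    attention_mask=None,
)
assert out.shape == q.shape  # (BS, NH, CL, DH), per the docstring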

QEfficient/diffusers/models/transformers/transformer_wan.py

Lines changed: 5 additions & 2 deletions
@@ -26,7 +26,7 @@
 )
 from diffusers.utils import set_weights_and_activate_adapters
 
-from QEfficient.diffusers.pipelines.modeling_utils import (
+from QEfficient.diffusers.models.modeling_utils import (
     compute_blocked_attention,
     get_attention_blocking_config,
 )
@@ -113,12 +113,15 @@ def apply_rotary_emb(
         key = apply_rotary_emb(key, *rotary_emb)
 
         # Get blocking configuration
-        blocking_mode, _, _, _ = get_attention_blocking_config()
+        blocking_mode, head_block_size, num_kv_blocks, num_q_blocks = get_attention_blocking_config()
         # Apply blocking using pipeline_utils
         hidden_states = compute_blocked_attention(
            query.transpose(1, 2),
            key.transpose(1, 2),
            value.transpose(1, 2),
+           head_block_size,
+           num_kv_blocks,
+           num_q_blocks,
            blocking_mode=blocking_mode,
            attention_mask=attention_mask,
         )
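
get_attention_blocking_config itself is not touched by this commit; the removed comment "Get blocking configuration from environment variables" only hints at how it works. A purely hypothetical sketch of such a helper follows; the environment-variable names are invented, and only the 4-tuple ordering (blocking_mode, head_block_size, num_kv_blocks, num_q_blocks) is taken from the call sites in this diff:

import os
from typing import Optional, Tuple


def get_attention_blocking_config() -> Tuple[str, Optional[int], Optional[int], Optional[int]]:
    """Hypothetical sketch: read attention blocking settings from environment variables."""

    def _read_int(name: str) -> Optional[int]:
        value = os.environ.get(name)
        return int(value) if value else None  # None lets the helpers fall back to NH / CL

    blocking_mode = os.environ.get("QEFF_ATTN_BLOCKING_MODE", "default")  # 'kv', 'q', 'qkv', 'default'
    head_block_size = _read_int("QEFF_ATTN_HEAD_BLOCK_SIZE")
    num_kv_blocks = _read_int("QEFF_ATTN_NUM_KV_BLOCKS")
    num_q_blocks = _read_int("QEFF_ATTN_NUM_Q_BLOCKS")
    return blocking_mode, head_block_size, num_kv_blocks, num_q_blocks
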
Lines changed: 35 additions & 45 deletions
@@ -1,46 +1,36 @@
 {
-  "description": "Default configuration for Wan unified transformer",
-  "modules": {
-    "transformer": {
-      "specializations": [
-        {
-          "batch_size": "1",
-          "num_channels": "16",
-          "num_frames": "21",
-          "latent_height": "24",
-          "latent_width": "40",
-          "steps": "1",
-          "sequence_length": "512",
-          "cl": "5040",
-          "model_type": 1
-        },
-        {
-          "batch_size": "1",
-          "num_channels": "16",
-          "num_frames": "21",
-          "latent_height": "24",
-          "latent_width": "40",
-          "steps": "1",
-          "sequence_length": "512",
-          "cl": "5040",
-          "model_type": 2
-        }
-      ],
-      "compilation":
-      {
-        "onnx_path": null,
-        "compile_dir": null,
-        "mdp_ts_num_devices": 16,
-        "mxfp6_matmul": true,
-        "convert_to_fp16": true,
-        "aic_num_cores": 16,
-        "mos": 1,
-        "mdts_mos": 1
-      },
-      "execute":
-      {
-        "device_ids": null
-      }
-    }
-  },
-}
+  "description": "Default configuration for Wan pipeline with unified transformer (model_type: 1 for high noise; model_type: 2 for low noise)",
+  "modules": {
+    "transformer": {
+      "specializations": [
+        {
+          "batch_size": "1",
+          "num_channels": "16",
+          "steps": "1",
+          "sequence_length": "512",
+          "model_type": 1
+        },
+        {
+          "batch_size": "1",
+          "num_channels": "16",
+          "steps": "1",
+          "sequence_length": "512",
+          "model_type": 2
+        }
+      ],
+      "compilation": {
+        "onnx_path": null,
+        "compile_dir": null,
+        "mdp_ts_num_devices": 16,
+        "mxfp6_matmul": true,
+        "convert_to_fp16": true,
+        "aic_num_cores": 16,
+        "mos": 1,
+        "mdts_mos": 1
+      },
+      "execute": {
+        "device_ids": null
+      }
+    }
+  }
+}
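
The per-specialization dimension fields (num_frames, latent_height, latent_width, cl) are dropped from the config, which is consistent with deriving them at runtime from the requested video size. As a worked check of the removed hard-coded value, assuming a 2x2 spatial patch size and treating the configured num_frames as a latent frame count (both assumptions, not stated in this diff):

# Removed specialization values
num_frames, latent_height, latent_width = 21, 24, 40

# Assumed 2x2 spatial patching (hypothetical)
patch_height = patch_width = 2

cl = num_frames * (latent_height // patch_height) * (latent_width // patch_width)
print(cl)  # 21 * 12 * 20 = 5040, matching the removed "cl": "5040"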

QEfficient/diffusers/pipelines/pipeline_module.py

Lines changed: 3 additions & 16 deletions
@@ -458,7 +458,7 @@ def export(
         """
 
         if use_onnx_subfunctions:
-            export_kwargs = {"export_modules_as_functions": {QEffFluxTransformerBlock, QEffFluxSingleTransformerBlock}}
+            export_kwargs = {"export_modules_as_functions": {QEffFluxTransformerBlock, QEffFluxSingleTransformerBlock}, "use_onnx_subfunctions": True}
 
         # Sort _use_default_values in config to ensure consistent hash generation during export
         self.model.config["_use_default_values"].sort()
@@ -591,7 +591,7 @@ def export(
         output_names: List[str],
         dynamic_axes: Dict,
         export_dir: str = None,
-        export_kwargs: Dict = None,
+        export_kwargs: Dict = {},
         use_onnx_subfunctions: bool = False,
     ) -> str:
         """Export the Wan transformer model to ONNX format.
@@ -607,14 +607,8 @@ def export(
         Returns:
             str: Path to the exported ONNX model
         """
-        if export_kwargs is None:
-            export_kwargs = {}
-
         if use_onnx_subfunctions:
-            export_kwargs = {"export_modules_as_functions": {WanTransformerBlock}}
-
-            # torch patch to export onnx with subfunction
-            apply_torch_patches()  # TODO: Moving to _export is better
+            export_kwargs = {"export_modules_as_functions": {WanTransformerBlock}, "use_onnx_subfunctions": True}
 
         return self._export(
             example_inputs=inputs,
@@ -634,10 +628,3 @@ def compile(self, specializations, **compiler_options) -> None:
             **compiler_options: Additional compiler options (e.g., num_cores, aic_num_of_activations)
         """
         self._compile(specializations=specializations, **compiler_options)
-
-    @property
-    def model_name(self) -> str:
-        mname = self.model.__class__.__name__
-        if mname.startswith("QEff") or mname.startswith("QEFF"):
-            mname = mname[4:]
-        return mname
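
Both export overrides now place "use_onnx_subfunctions": True into export_kwargs instead of the Wan branch calling apply_torch_patches() directly, in line with the removed TODO "Moving to _export is better". The shared _export helper is not part of this diff; a purely hypothetical sketch of how it might consume the flag is shown below (all names except apply_torch_patches and the "use_onnx_subfunctions" key are assumptions):

from typing import Dict, Optional


def apply_torch_patches() -> None:
    """Placeholder for the real torch patch helper referenced in the removed lines."""


def consume_export_kwargs(export_kwargs: Optional[Dict] = None) -> Dict:
    """Hypothetical: pop the subfunction flag in one shared place before ONNX export."""
    export_kwargs = dict(export_kwargs or {})
    if export_kwargs.pop("use_onnx_subfunctions", False):
        apply_torch_patches()  # torch patch to export ONNX with subfunctions
    return export_kwargs  # e.g. {"export_modules_as_functions": {WanTransformerBlock}}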

QEfficient/diffusers/pipelines/pipeline_utils.py

Lines changed: 10 additions & 5 deletions
@@ -39,7 +39,7 @@ def calculate_compressed_latent_dimension(height: int, width: int, vae_scale_fac
     return cl, latent_height, latent_width
 
 
-def calculate_latent_dimensions(
+def calculate_latent_dimensions_with_frames(
     height: int,
     width: int,
     num_frames: int,
@@ -49,22 +49,27 @@
     patch_width: int,
 ) -> int:
     """
-    Calculate the latent dimensions, Compressed latent dimension (cl) for transformer buffer allocation.
+    Calculate the latent dimensions for video generation models.
 
-    This method computes the compressed sequence length (cl) that the transformer
-    will process, based on the target video dimensions, VAE scale factors, and
-    patch sizes. This is crucial for proper buffer allocation in QAIC inference.
+    This method computes the compressed sequence length (cl), latent height,
+    latent width, and latent frames based on the target video dimensions,
+    VAE scale factors, and patch sizes.
 
     Args:
         height (int): Target video height in pixels
         width (int): Target video width in pixels
         num_frames (int): Target video frames in pixels
+        vae_scale_factor_spatial (int): Spatial VAE scale factor from the model config
+        vae_scale_factor_temporal (int): Temporal VAE scale factor from the model config
+        patch_height (int): Patch height from the model config
+        patch_width (int): Patch width from the model config
 
     Returns:
         tuple: (cl, latent_height, latent_width)
             - cl (int): Compressed latent dimension for transformer input
             - latent_height (int): Height in latent space
             - latent_width (int): Width in latent space
+            - latent_frames (int): Frames in latent space
 
     Mathematical Formula:
         latent_height = height // vae_scale_factor_spatial
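
The hunk ends inside the "Mathematical Formula" section, so only the latent_height line is visible. A sketch of the full computation follows; everything past that first line is an assumption chosen to reproduce the values removed from the config above (latent 24x40, 21 latent frames, cl = 5040), not the function's actual body:

def calculate_latent_dimensions_with_frames_sketch(
    height: int,
    width: int,
    num_frames: int,
    vae_scale_factor_spatial: int,
    vae_scale_factor_temporal: int,
    patch_height: int,
    patch_width: int,
):
    # Shown in the docstring above
    latent_height = height // vae_scale_factor_spatial
    # The remaining steps are assumptions, consistent with the removed config values
    latent_width = width // vae_scale_factor_spatial
    latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1
    cl = latent_frames * (latent_height // patch_height) * (latent_width // patch_width)
    return cl, latent_height, latent_width, latent_frames


# Example with assumed Wan-style factors: a 192x320 video with 81 frames,
# spatial factor 8, temporal factor 4, 2x2 patches
# -> latent 24x40, 21 latent frames, cl = 21 * 12 * 20 = 5040
print(calculate_latent_dimensions_with_frames_sketch(192, 320, 81, 8, 4, 2, 2))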
