diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py index b507363c3..3c9f68efd 100644 --- a/QEfficient/__init__.py +++ b/QEfficient/__init__.py @@ -29,6 +29,7 @@ ) from QEfficient.compile.compile_helper import compile from QEfficient.diffusers.pipelines.flux.pipeline_flux import QEffFluxPipeline +from QEfficient.diffusers.pipelines.wan.pipeline_wan import QEffWanPipeline from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv from QEfficient.peft import QEffAutoPeftModelForCausalLM @@ -55,6 +56,7 @@ "QEFFAutoModelForSpeechSeq2Seq", "QEFFCommonLoader", "QEffFluxPipeline", + "QEffWanPipeline", ] diff --git a/QEfficient/diffusers/README.md b/QEfficient/diffusers/README.md index 40d45e984..4777d48fb 100644 --- a/QEfficient/diffusers/README.md +++ b/QEfficient/diffusers/README.md @@ -62,6 +62,7 @@ pip install dist/qefficient-0.0.1.dev0-py3-none-any.whl ## šŸŽÆ Supported Models - āœ… [`black-forest-labs/FLUX.1-schnell`](https://huggingface.co/black-forest-labs/FLUX.1-schnell) +- āœ… [`lightx2v/Wan2.2-Lightning`](https://huggingface.co/lightx2v/Wan2.2-Lightning) --- @@ -83,7 +84,6 @@ We welcome contributions! Please see our [Contributing Guide](../../CONTRIBUTING ## šŸ™ Acknowledgments - **HuggingFace Diffusers**: For the excellent foundation library -- **Stability AI**: For the amazing Stable Diffusion models --- ## šŸ“ž Support diff --git a/QEfficient/diffusers/models/modeling_utils.py b/QEfficient/diffusers/models/modeling_utils.py new file mode 100644 index 000000000..59727be2d --- /dev/null +++ b/QEfficient/diffusers/models/modeling_utils.py @@ -0,0 +1,456 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import math +import os +from typing import Optional + +import torch + + +def get_attention_blocking_config(): + """ + Get attention blocking configuration from environment variables. + + Returns: + tuple: (blocking_mode, head_block_size, num_kv_blocks, num_q_blocks) + - blocking_mode (str): The blocking strategy ('kv', 'q', 'qkv', 'default') + - head_block_size (int or None): Number of attention heads per block + - num_kv_blocks (int or None): Number of key-value blocks + - num_q_blocks (int or None): Number of query blocks + """ + mode = os.environ.get("ATTENTION_BLOCKING_MODE", "default").lower() + head_block_size = int(os.environ.get("head_block_size", 0)) or None + num_kv_blocks = int(os.environ.get("num_kv_blocks", 0)) or None + num_q_blocks = int(os.environ.get("num_q_blocks", 0)) or None + + # Validate blocking mode + valid_modes = ["kv", "qkv", "q", "default"] + if mode not in valid_modes: + raise ValueError(f"Invalid ATTENTION_BLOCKING_MODE: {mode}. Must be one of {valid_modes}") + + return mode, head_block_size, num_kv_blocks, num_q_blocks + + +def apply_head_blocking( + q: torch.FloatTensor, + k: torch.FloatTensor, + v: torch.FloatTensor, + head_block_size: int, + attention_mask: Optional[torch.FloatTensor] = None, +) -> torch.FloatTensor: + """ + Forward pass with head-only blocking (default mode). + + This method processes attention heads in blocks while computing full attention + matrices for each head block. It's less memory-efficient than other blocking + modes but simpler and faster for moderate sequence lengths. 
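# Illustrative sketch (not part of the patch above): the blocking strategy in
# get_attention_blocking_config() is driven entirely by environment variables,
# so it can be switched without code changes. The values below are examples,
# not tuned defaults.
import os

os.environ["ATTENTION_BLOCKING_MODE"] = "qkv"  # one of: "kv", "q", "qkv", "default"
os.environ["head_block_size"] = "4"            # attention heads per block
os.environ["num_kv_blocks"] = "8"              # number of key/value blocks
os.environ["num_q_blocks"] = "8"               # number of query blocks

from QEfficient.diffusers.models.modeling_utils import get_attention_blocking_config

mode, head_block_size, num_kv_blocks, num_q_blocks = get_attention_blocking_config()
print(mode, head_block_size, num_kv_blocks, num_q_blocks)  # qkv 4 8 8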
+ + Args: + q (torch.FloatTensor): Query tensor of shape (BS, NH, CL, DH) + k (torch.FloatTensor): Key tensor of shape (BS, NH, CL, DH) + v (torch.FloatTensor): Value tensor of shape (BS, NH, CL, DH) + attention_mask (Optional[torch.FloatTensor]): Attention mask tensor + + Returns: + torch.FloatTensor: Attention output of shape (BS, NH, CL, DH) + """ + BS, NH, CL, DH = q.shape + scale_factor = 1.0 / math.sqrt(DH) + + # Get head blocking configuration + head_block_size = head_block_size or NH + num_head_blocks = math.ceil(NH / head_block_size) + + # Optimization: Handle small sequences with standard attention + BS, NH, K_CL, DH = k.shape + if K_CL <= 512: + scores = torch.matmul(q, k.transpose(-2, -1)) * scale_factor + if attention_mask is not None: + scores = torch.where(attention_mask, scores, torch.tensor(-1e4, dtype=scores.dtype, device=scores.device)) + probs = torch.softmax(scores, dim=-1) + out = torch.matmul(probs, v) + return out + + outputs = [] + + # Process each head block independently + for head_block_idx in range(num_head_blocks): + h_start = head_block_idx * head_block_size + h_end = min(h_start + head_block_size, NH) + + # Extract head blocks + q_g = q[:, h_start:h_end, :, :] + k_g = k[:, h_start:h_end, :, :] + v_g = v[:, h_start:h_end, :, :] + + # Compute full attention matrix for this head block + qkblock = torch.matmul(q_g, k_g.transpose(-2, -1)) * scale_factor + + # Standard softmax computation + probs = torch.softmax(qkblock, dim=-1) + + # Compute attention output + output_blocks = torch.matmul(probs, v_g) + outputs.append(output_blocks) + + # Concatenate all head blocks along head dimension + out = torch.cat(outputs, dim=1) # (BS, NH, CL, DH) + return out + + +def apply_kv_blocking( + q: torch.FloatTensor, + k: torch.FloatTensor, + v: torch.FloatTensor, + head_block_size: int, + num_kv_blocks: int, + attention_mask: Optional[torch.FloatTensor] = None, +) -> torch.FloatTensor: + """ + Forward pass with Key-Value blocking and head blocking. + + This method processes key-value pairs in blocks while keeping queries intact. + It uses online softmax to maintain numerical stability and reduce memory usage + compared to computing full attention matrices. 
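# Illustrative sketch (not part of the patch above): the online-softmax
# identity the KV loop relies on. Per-block statistics (running max, running
# sum of exponentials) can be merged so that the combined result equals a
# softmax taken over the full concatenated scores.
import torch

s1, s2 = torch.randn(5), torch.randn(5)        # two score "blocks"
m1 = s1.max()
sum1 = torch.exp(s1 - m1).sum()                # block-1 statistics
m = torch.maximum(m1, s2.max())                # updated running max
total = sum1 * torch.exp(m1 - m) + torch.exp(s2 - m).sum()
online = torch.cat([torch.exp(s1 - m), torch.exp(s2 - m)]) / total
print(torch.allclose(online, torch.softmax(torch.cat([s1, s2]), dim=0)))  # True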
+ + Args: + q (torch.FloatTensor): Query tensor of shape (BS, NH, CL, DH) + k (torch.FloatTensor): Key tensor of shape (BS, NH, CL, DH) + v (torch.FloatTensor): Value tensor of shape (BS, NH, CL, DH) + attention_mask (Optional[torch.FloatTensor]): Attention mask tensor + + Returns: + torch.FloatTensor: Attention output of shape (BS, NH, CL, DH) + """ + BS, NH, CL, DH = q.shape + scale_factor = 1.0 / math.sqrt(DH) + + # Get blocking configuration + head_block_size = head_block_size or NH + num_kv_blocks = num_kv_blocks or CL + num_head_blocks = math.ceil(NH / head_block_size) + block_positions = [(i * CL) // num_kv_blocks for i in range(num_kv_blocks)] + + # Handle small sequences with standard attention + BS, NH, K_CL, DH = k.shape + if K_CL <= 512: + scores = torch.matmul(q, k.transpose(-2, -1)) * scale_factor + if attention_mask is not None: + scores = torch.where(attention_mask, scores, torch.tensor(-1e4, dtype=scores.dtype, device=scores.device)) + probs = torch.softmax(scores, dim=-1) + out = torch.matmul(probs, v) + return out + + head_outputs = [] + + # Process each head block + for head_block_idx in range(num_head_blocks): + h_start = head_block_idx * head_block_size + h_end = min(h_start + head_block_size, NH) + num_h = h_end - h_start + + q_g = q[:, h_start:h_end, :, :] + k_g = k[:, h_start:h_end, :, :] + v_g = v[:, h_start:h_end, :, :] + + # Initialize online softmax statistics + running_exp_sum = torch.zeros((BS, num_h, CL), device=q.device, dtype=q.dtype) + running_max = torch.full((BS, num_h, CL), float("-inf"), device=q.device, dtype=q.dtype) + output_blocks = torch.zeros_like(q_g) + + # Process K,V in blocks using online softmax + for kv_block_idx in range(num_kv_blocks): + ki = block_positions[kv_block_idx] + + # Calculate KV block size + if kv_block_idx == num_kv_blocks - 1: + real_kv_len = CL - ki + else: + real_kv_len = block_positions[kv_block_idx + 1] - ki + + k_block = k_g[:, :, ki : ki + real_kv_len, :] + v_block = v_g[:, :, ki : ki + real_kv_len, :] + + # Compute attention scores for current KV block + qkblock = torch.matmul(q_g, k_block.transpose(-2, -1)) * scale_factor + + # Online softmax: Update running maximum + prev_max = running_max.clone() + running_max = torch.maximum(prev_max, torch.max(qkblock, dim=-1)[0]) + + # Calculate numerical stability adjustments + delta_max = prev_max - running_max + curr_exp = torch.exp(qkblock - running_max.unsqueeze(-1)) + + # Update running sum of exponentials + prev_exp_sum = running_exp_sum.clone() + curr_exp_sum = torch.einsum("bhqk->bhq", curr_exp) + running_exp_sum = prev_exp_sum * torch.exp(delta_max) + curr_exp_sum + + # Compute normalized attention weights + inv_running_exp_sum = 1.0 / running_exp_sum + softmax_qkblock = curr_exp * inv_running_exp_sum.unsqueeze(-1) + + # Update output with rescaling + prev_out = output_blocks.clone() + rescale_factor = (prev_exp_sum * inv_running_exp_sum) * torch.exp(delta_max) + output_blocks = rescale_factor.unsqueeze(-1) * prev_out + torch.matmul(softmax_qkblock, v_block) + + head_outputs.append(output_blocks) + + out = torch.cat(head_outputs, dim=1) # (BS, NH, CL, DH) + return out + + +def apply_q_blocking( + q: torch.FloatTensor, + k: torch.FloatTensor, + v: torch.FloatTensor, + head_block_size: int, + num_q_blocks: int, + attention_mask: Optional[torch.FloatTensor] = None, +) -> torch.FloatTensor: + """ + Forward pass with Query blocking and head blocking. + + This method processes query tokens in blocks while keeping key-value pairs intact. 
+ It's useful when the sequence length is large but memory constraints are primarily + due to the query dimension. + + Args: + q (torch.FloatTensor): Query tensor of shape (BS, NH, CL, DH) + k (torch.FloatTensor): Key tensor of shape (BS, NH, CL, DH) + v (torch.FloatTensor): Value tensor of shape (BS, NH, CL, DH) + attention_mask (Optional[torch.FloatTensor]): Attention mask tensor + + Returns: + torch.FloatTensor: Attention output of shape (BS, NH, CL, DH) + """ + BS, NH, CL, DH = q.shape + scale_factor = 1.0 / math.sqrt(DH) + + # Get blocking configuration + head_block_size = head_block_size or NH + num_q_blocks = num_q_blocks or CL + num_head_blocks = math.ceil(NH / head_block_size) + q_block_positions = [(i * CL) // num_q_blocks for i in range(num_q_blocks)] + + # Handle small sequences with standard attention + BS, NH, K_CL, DH = k.shape + if K_CL <= 512: + scores = torch.matmul(q, k.transpose(-2, -1)) * scale_factor + if attention_mask is not None: + scores = torch.where(attention_mask, scores, torch.tensor(-1e4, dtype=scores.dtype, device=scores.device)) + probs = torch.softmax(scores, dim=-1) + out = torch.matmul(probs, v) + return out + + head_outputs = [] + + # Process each head block + for head_block_idx in range(num_head_blocks): + h_start = head_block_idx * head_block_size + h_end = min(h_start + head_block_size, NH) + + q_g = q[:, h_start:h_end, :, :] + k_g = k[:, h_start:h_end, :, :] + v_g = v[:, h_start:h_end, :, :] + + q_output_list = [] + + # Process queries in blocks + for q_block_idx in range(num_q_blocks): + qi = q_block_positions[q_block_idx] + + # Calculate Q block size + if q_block_idx == num_q_blocks - 1: + real_q_len = CL - qi + else: + real_q_len = q_block_positions[q_block_idx + 1] - qi + + q_block = q_g[:, :, qi : qi + real_q_len, :] + + # Compute attention for this query block against all keys + scores = torch.matmul(q_block, k_g.transpose(-2, -1)) * scale_factor + probs = torch.softmax(scores, dim=-1) + out_block = torch.matmul(probs, v_g) + + q_output_list.append(out_block) + + # Concatenate query blocks + head_output = torch.cat(q_output_list, dim=2) + head_outputs.append(head_output) + + out = torch.cat(head_outputs, dim=1) # (BS, NH, CL, DH) + return out + + +def apply_qkv_blocking( + q: torch.FloatTensor, + k: torch.FloatTensor, + v: torch.FloatTensor, + head_block_size: int, + num_kv_blocks: int, + num_q_blocks: int, + attention_mask: Optional[torch.FloatTensor] = None, +) -> torch.FloatTensor: + """ + Forward pass with combined Query, Key, Value blocking and head blocking. + + This method implements the most memory-efficient attention computation by blocking + along all three dimensions: heads, queries, and key-values. 
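# Illustrative sketch (not part of the patch above): a quick numerical check
# that the blocked paths reproduce plain softmax attention. The blocked
# branches are only taken when the key length exceeds 512, hence CL = 1024.
import math
import torch
from QEfficient.diffusers.models.modeling_utils import compute_blocked_attention

BS, NH, CL, DH = 1, 8, 1024, 64
q, k, v = (torch.randn(BS, NH, CL, DH) for _ in range(3))
reference = torch.softmax((q @ k.transpose(-2, -1)) / math.sqrt(DH), dim=-1) @ v

for mode in ("default", "kv", "q", "qkv"):
    blocked = compute_blocked_attention(
        q, k, v, head_block_size=4, num_kv_blocks=8, num_q_blocks=8, blocking_mode=mode
    )
    # max absolute deviation should be float32 rounding noise (~1e-6)
    print(mode, (blocked - reference).abs().max())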
+ + Args: + q (torch.FloatTensor): Query tensor of shape (BS, NH, CL, DH) + k (torch.FloatTensor): Key tensor of shape (BS, NH, CL, DH) + v (torch.FloatTensor): Value tensor of shape (BS, NH, CL, DH) + attention_mask (Optional[torch.FloatTensor]): Attention mask tensor + + Returns: + torch.FloatTensor: Attention output of shape (BS, NH, CL, DH) + """ + BS, NH, CL, DH = q.shape + scale_factor = 1.0 / math.sqrt(DH) + + # Get blocking configuration from environment variables + head_block_size = head_block_size or NH + num_kv_blocks = num_kv_blocks or CL + num_q_blocks = num_q_blocks or CL + num_head_blocks = math.ceil(NH / head_block_size) + + # Calculate block positions for even distribution + kv_block_positions = [(i * CL) // num_kv_blocks for i in range(num_kv_blocks)] + q_block_positions = [(i * CL) // num_q_blocks for i in range(num_q_blocks)] + + # Optimization: Use standard attention for small sequences + BS, NH, K_CL, DH = k.shape + if K_CL <= 512: + scores = torch.matmul(q, k.transpose(-2, -1)) * scale_factor + if attention_mask is not None: + scores = torch.where(attention_mask, scores, torch.tensor(-1e4, dtype=scores.dtype, device=scores.device)) + probs = torch.softmax(scores, dim=-1) + out = torch.matmul(probs, v) + return out + + head_outputs = [] + + # Process attention heads in blocks to reduce memory usage + for head_block_idx in range(num_head_blocks): + h_start = head_block_idx * head_block_size + h_end = min(h_start + head_block_size, NH) + num_h = h_end - h_start + + # Extract current head block + q_g = q[:, h_start:h_end, :, :] + k_g = k[:, h_start:h_end, :, :] + v_g = v[:, h_start:h_end, :, :] + q_output_list = [] + + # Process queries in blocks within each head block + for q_block_idx in range(num_q_blocks): + qi = q_block_positions[q_block_idx] + + # Calculate actual Q block size (handle remainder for last block) + if q_block_idx == num_q_blocks - 1: + real_q_len = CL - qi + else: + real_q_len = q_block_positions[q_block_idx + 1] - qi + + q_block = q_g[:, :, qi : qi + real_q_len, :] + + # Initialize online softmax statistics for this Q block + running_exp_sum = torch.zeros((BS, num_h, real_q_len), device=q.device, dtype=q.dtype) + running_max = torch.full((BS, num_h, real_q_len), float("-inf"), device=q.device, dtype=q.dtype) + output_blocks = torch.zeros((BS, num_h, real_q_len, DH), device=q.device, dtype=q.dtype) + + # Process K,V in blocks for this Q block (online softmax) + for kv_block_idx in range(num_kv_blocks): + ki = kv_block_positions[kv_block_idx] + + # Calculate actual KV block size + if kv_block_idx == num_kv_blocks - 1: + real_kv_len = CL - ki + else: + real_kv_len = kv_block_positions[kv_block_idx + 1] - ki + + k_block = k_g[:, :, ki : ki + real_kv_len, :] + v_block = v_g[:, :, ki : ki + real_kv_len, :] + + # Compute attention scores for current Q-K block + qkblock = torch.matmul(q_block, k_block.transpose(-2, -1)) * scale_factor + + # Online softmax: Update running maximum + prev_max = running_max.clone() + if qkblock.shape[-1] == 0: + running_max = prev_max + else: + running_max = torch.maximum(prev_max, torch.max(qkblock, dim=-1)[0]) + + # Calculate adjustment factor for numerical stability + delta_max = prev_max - running_max + curr_exp = torch.exp(qkblock - running_max.unsqueeze(-1)) + + # Online softmax: Update running sum of exponentials + prev_exp_sum = running_exp_sum.clone() + curr_exp_sum = torch.einsum("bhqk->bhq", curr_exp) + running_exp_sum = prev_exp_sum * torch.exp(delta_max) + curr_exp_sum + + # Compute normalized attention weights for 
this block + inv_running_exp_sum = 1.0 / running_exp_sum + softmax_qkblock = curr_exp * inv_running_exp_sum.unsqueeze(-1) + + # Online softmax: Update output with rescaling of previous blocks + prev_out = output_blocks.clone() + rescale_factor = (prev_exp_sum * inv_running_exp_sum) * torch.exp(delta_max) + output_blocks = rescale_factor.unsqueeze(-1) * prev_out + torch.matmul(softmax_qkblock, v_block) + + q_output_list.append(output_blocks) + + # Concatenate all Q blocks for this head block + head_output = torch.cat(q_output_list, dim=2) + head_outputs.append(head_output) + + # Concatenate all head blocks + out = torch.cat(head_outputs, dim=1) + return out + + +def compute_blocked_attention( + q: torch.FloatTensor, + k: torch.FloatTensor, + v: torch.FloatTensor, + head_block_size: int, + num_kv_blocks: int, + num_q_blocks: int, + blocking_mode: str = "default", + attention_mask: Optional[torch.FloatTensor] = None, +) -> torch.FloatTensor: + """ + Main dispatcher function for different attention blocking strategies. + + Args: + q (torch.FloatTensor): Query tensor of shape (BS, NH, CL, DH) + k (torch.FloatTensor): Key tensor of shape (BS, NH, CL, DH) + v (torch.FloatTensor): Value tensor of shape (BS, NH, CL, DH) + head_block_size (int) : Head blocking size + num_kv_blocks (int) : Number of KV blocks + num_q_blocks (int) : Number of Q blocks + blocking_mode (str): Blocking strategy ('kv', 'q', 'qkv', 'default') + attention_mask (Optional[torch.FloatTensor]): Attention mask tensor + + Returns: + torch.FloatTensor: Attention output of shape (BS, NH, CL, DH) + """ + if blocking_mode == "kv": + return apply_kv_blocking(q, k, v, head_block_size, num_kv_blocks, attention_mask) + elif blocking_mode == "q": + return apply_q_blocking(q, k, v, head_block_size, num_q_blocks, attention_mask) + elif blocking_mode == "qkv": + return apply_qkv_blocking(q, k, v, head_block_size, num_kv_blocks, num_q_blocks, attention_mask) + else: # default + return apply_head_blocking(q, k, v, head_block_size, attention_mask) diff --git a/QEfficient/diffusers/models/pytorch_transforms.py b/QEfficient/diffusers/models/pytorch_transforms.py index d3c84ee63..4fb5c3f12 100644 --- a/QEfficient/diffusers/models/pytorch_transforms.py +++ b/QEfficient/diffusers/models/pytorch_transforms.py @@ -13,6 +13,7 @@ FluxTransformer2DModel, FluxTransformerBlock, ) +from diffusers.models.transformers.transformer_wan import WanAttention, WanAttnProcessor, WanTransformer3DModel from torch import nn from QEfficient.base.pytorch_transforms import ModuleMappingTransform @@ -29,6 +30,11 @@ QEffFluxTransformer2DModel, QEffFluxTransformerBlock, ) +from QEfficient.diffusers.models.transformers.transformer_wan import ( + QEffWanAttention, + QEffWanAttnProcessor, + QEffWanTransformer3DModel, +) class CustomOpsTransform(ModuleMappingTransform): @@ -45,6 +51,9 @@ class AttentionTransform(ModuleMappingTransform): FluxTransformer2DModel: QEffFluxTransformer2DModel, FluxAttention: QEffFluxAttention, FluxAttnProcessor: QEffFluxAttnProcessor, + WanAttnProcessor: QEffWanAttnProcessor, + WanAttention: QEffWanAttention, + WanTransformer3DModel: QEffWanTransformer3DModel, } diff --git a/QEfficient/diffusers/models/transformers/transformer_wan.py b/QEfficient/diffusers/models/transformers/transformer_wan.py new file mode 100644 index 000000000..31d3be2ce --- /dev/null +++ b/QEfficient/diffusers/models/transformers/transformer_wan.py @@ -0,0 +1,291 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm 
Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- +""" +QEfficient WAN Transformer Implementation + +This module provides optimized implementations of WAN transformers +with various attention blocking strategies for memory efficiency and performance optimization. +The implementation includes multiple blocking modes: head-only, KV-blocking, Q-blocking, +and combined QKV-blocking. +""" + +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from diffusers.loaders.peft import _SET_ADAPTER_SCALE_FN_MAPPING +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.transformers.transformer_wan import ( + WanAttention, + WanAttnProcessor, + WanTransformer3DModel, + _get_qkv_projections, +) +from diffusers.utils import set_weights_and_activate_adapters + +from QEfficient.diffusers.models.modeling_utils import ( + compute_blocked_attention, + get_attention_blocking_config, +) + + +class QEffWanAttnProcessor(WanAttnProcessor): + """ + QEfficient WAN Attention Processor with Memory-Efficient Blocking Strategies. + + This processor implements multiple attention blocking modes to reduce memory usage + and enable processing of longer sequences. It supports: + - Head blocking: Process attention heads in chunks + - KV blocking: Process key-value pairs in blocks + - Q blocking: Process query tokens in blocks + - QKV blocking: Combined query, key, and value blocking + + Environment Variables: + ATTENTION_BLOCKING_MODE: Controls blocking strategy ('kv', 'q', 'qkv', 'default') + head_block_size: Number of attention heads to process per block + num_kv_blocks: Number of blocks for key-value processing + num_q_blocks: Number of blocks for query processing + """ + + def __call__( + self, + attn: "WanAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + """ + Main attention processing pipeline with support for multiple blocking strategies. + + This method orchestrates the complete attention computation including: + 1. QKV projection and normalization + 2. Rotary position embedding application + 3. Attention computation with selected blocking strategy + 4. 
Output projection + + Args: + attn (WanAttention): The attention module instance + hidden_states (torch.Tensor): Input hidden states + encoder_hidden_states (Optional[torch.Tensor]): Cross-attention encoder states + attention_mask (Optional[torch.Tensor]): Attention mask + rotary_emb (Optional[Tuple[torch.Tensor, torch.Tensor]]): Rotary embeddings (cos, sin) + + Returns: + torch.Tensor: Processed hidden states after attention + """ + # Project inputs to query, key, value + query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states) + + # Apply layer normalization to queries and keys + query = attn.norm_q(query) + key = attn.norm_k(key) + + # Reshape for multi-head attention: (batch, seq, dim) -> (batch, seq, heads, head_dim) + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + # Apply rotary position embeddings if provided + if rotary_emb is not None: + + def apply_rotary_emb( + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ): + """Apply rotary position embeddings to the input tensor.""" + # Split into real and imaginary parts for complex rotation + x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + cos = freqs_cos[..., 0::2].type_as(hidden_states) + sin = freqs_sin[..., 1::2].type_as(hidden_states) + + # Apply rotation: (x1 + ix2) * (cos + isin) = (x1*cos - x2*sin) + i(x1*sin + x2*cos) + real = x1 * cos - x2 * sin + img = x1 * sin + x2 * cos + x_rot = torch.stack([real, img], dim=-1) + return x_rot.flatten(-2).type_as(hidden_states) + + query = apply_rotary_emb(query, *rotary_emb) + key = apply_rotary_emb(key, *rotary_emb) + + # Get blocking configuration + blocking_mode, head_block_size, num_kv_blocks, num_q_blocks = get_attention_blocking_config() + # Apply blocking using pipeline_utils + hidden_states = compute_blocked_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + head_block_size, + num_kv_blocks, + num_q_blocks, + blocking_mode=blocking_mode, + attention_mask=attention_mask, + ) + + # Reshape back to original format + hidden_states = hidden_states.transpose(1, 2) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + # Apply output projection layers + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class QEffWanAttention(WanAttention): + """ + QEfficient WAN Attention module with optimized processor. + + This class extends the base WanAttention with QEfficient optimizations, + automatically setting up the QEffWanAttnProcessor for memory-efficient + attention computation. + """ + + def __qeff_init__(self): + """Initialize the QEfficient attention processor.""" + processor = QEffWanAttnProcessor() + self.processor = processor + + +class QEffWanTransformer3DModel(WanTransformer3DModel): + """ + QEfficient 3D WAN Transformer Model with adapter support. + + This model extends the base WanTransformer3DModel with QEfficient optimizations. + """ + + def set_adapters( + self, + adapter_names: Union[List[str], str], + weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None, + ): + """ + Set the currently active adapters for use in the diffusion network. + + This method manages PEFT adapters, allowing for efficient fine-tuning + and model customization without modifying the base model parameters. 
+ + Args: + adapter_names (Union[List[str], str]): Names of adapters to activate + weights (Optional[Union[float, Dict, List[float], List[Dict], List[None]]]): + Weights for each adapter. Can be: + - Single float: Applied to all adapters + - List of floats: One weight per adapter + - Dict: Detailed weight configuration + - None: Uses default weight of 1.0 + + Raises: + ValueError: If adapter names and weights lists have different lengths + + Note: + - Adapters enable parameter-efficient fine-tuning + - Multiple adapters can be active simultaneously with different weights + - Weights control the influence of each adapter on the model output + """ + # Normalize adapter names to list format + adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names + + # Expand weights into a list, one entry per adapter + # Examples for 2 adapters: [{...}, 7] -> [7,7] ; None -> [None, None] + if not isinstance(weights, list): + weights = [weights] * len(adapter_names) + + if len(adapter_names) != len(weights): + raise ValueError( + f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}." + ) + + # Set None values to default of 1.0 + # e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0] + weights = [w if w is not None else 1.0 for w in weights] + + # Expand weights using model-specific scaling function + # e.g. [{...}, 7] -> [{expanded dict...}, 7] + scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING[ + self.config._class_name + ] # updated to use WanTransformer3DModel + weights = scale_expansion_fn(self, weights) + set_weights_and_activate_adapters(self, adapter_names, weights) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + rotary_emb: torch.Tensor, + temb: torch.Tensor, + timestep_proj: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Forward pass of the 3D WAN Transformer. + + This method implements the complete forward pass including: + 1. Patch embedding of input + 2. Rotary embedding preparation + 3. Cross-attention with encoder states + 4. Transformer block processing + 5. 
Output normalization and projection + + Args: + hidden_states (torch.Tensor): Input tensor to transform + encoder_hidden_states (torch.Tensor): Cross-attention encoder states + rotary_emb (torch.Tensor): Rotary position embeddings + temb (torch.Tensor): Time embedding for diffusion process + timestep_proj (torch.Tensor): Projected timestep embeddings + encoder_hidden_states_image (Optional[torch.Tensor]): Image encoder states for I2V + return_dict (bool): Whether to return a dictionary or tuple + attention_kwargs (Optional[Dict[str, Any]]): Additional attention arguments + + Returns: + Union[torch.Tensor, Dict[str, torch.Tensor]]: + Transformed hidden states, either as tensor or in a dictionary + """ + # Prepare rotary embeddings by splitting along batch dimension + rotary_emb = torch.split(rotary_emb, 1, dim=0) + + # Apply patch embedding and reshape for transformer processing + hidden_states = self.patch_embedding(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) # (B, H*W, C) + + # Concatenate image and text encoder states if image conditioning is present + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + + # Standard forward pass + for block in self.blocks: + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + + # Output normalization, projection & unpatchify + if temb.ndim == 3: + # Handle 3D time embeddings: batch_size, seq_len, inner_dim (WAN 2.2 T2V) + shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2) + shift = shift.squeeze(2) + scale = scale.squeeze(2) + else: + # Handle 2D time embeddings: batch_size, inner_dim + shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) + + # Ensure tensors are on the same device as hidden_states + shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + + # Apply adaptive layer normalization with time conditioning + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + + # Final output projection + hidden_states = self.proj_out(hidden_states) + + # Store output for return (compiler optimization) + output = hidden_states + + # Return in requested format + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/QEfficient/diffusers/pipelines/configs/wan_config.json b/QEfficient/diffusers/pipelines/configs/wan_config.json new file mode 100644 index 000000000..3f5edce07 --- /dev/null +++ b/QEfficient/diffusers/pipelines/configs/wan_config.json @@ -0,0 +1,36 @@ +{ + "description": "Default configuration for Wan pipeline with unified transformer (model_type: 1 for high noise; model_type:2 for low noise)", + "modules": { + "transformer": { + "specializations": [ + { + "batch_size": "1", + "num_channels": "16", + "steps": "1", + "sequence_length": "512", + "model_type": 1 + }, + { + "batch_size": "1", + "num_channels": "16", + "steps": "1", + "sequence_length": "512", + "model_type": 2 + } + ], + "compilation": { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 16, + "mxfp6_matmul": true, + "convert_to_fp16": true, + "aic_num_cores": 16, + "mos": 1, + "mdts_mos": 1 + }, + "execute": { + "device_ids": null + } + } + } +} \ No newline at end of file diff --git a/QEfficient/diffusers/pipelines/pipeline_module.py b/QEfficient/diffusers/pipelines/pipeline_module.py index 6d9243fdc..19e7701d4 100644 --- 
a/QEfficient/diffusers/pipelines/pipeline_module.py +++ b/QEfficient/diffusers/pipelines/pipeline_module.py @@ -9,6 +9,7 @@ import torch import torch.nn as nn +from diffusers.models.transformers.transformer_wan import WanTransformerBlock from QEfficient.base.modeling_qeff import QEFFBaseModel from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform @@ -360,8 +361,6 @@ def __init__(self, model: nn.Module) -> None: Args: model (nn.Module): The Flux transformer model to wrap - use_onnx_subfunctions (bool): Whether to export transformer blocks as ONNX functions - for better modularity and potential optimization """ super().__init__(model) @@ -450,13 +449,18 @@ def export( dynamic_axes (Dict): Specification of dynamic dimensions export_dir (str, optional): Directory to save ONNX model export_kwargs (Dict, optional): Additional export arguments (e.g., export_modules_as_functions) + use_onnx_subfunctions (bool): Whether to export transformer blocks as ONNX functions + for better modularity and potential optimization Returns: str: Path to the exported ONNX model """ if use_onnx_subfunctions: - export_kwargs = {"export_modules_as_functions": {QEffFluxTransformerBlock, QEffFluxSingleTransformerBlock}} + export_kwargs = { + "export_modules_as_functions": {QEffFluxTransformerBlock, QEffFluxSingleTransformerBlock}, + "use_onnx_subfunctions": True, + } # Sort _use_default_values in config to ensure consistent hash generation during export self.model.config["_use_default_values"].sort() @@ -479,3 +483,150 @@ def compile(self, specializations: List[Dict], **compiler_options) -> None: **compiler_options: Additional compiler options (e.g., num_cores, aic_num_of_activations) """ self._compile(specializations=specializations, **compiler_options) + + +class QEffWanUnifiedTransformer(QEFFBaseModel): + """ + Wrapper for WAN Unified Transformer with ONNX export and QAIC compilation capabilities. + + This class handles the unified WAN transformer model that combines high and low noise transformers + into a single model for efficient deployment. Based on the timestep shape, the model dynamically + selects between high and low noise transformers during inference. + + The wrapper applies specific transformations and optimizations for efficient inference on + Qualcomm AI hardware, particularly for video diffusion models. + + Attributes: + model (nn.Module): The QEffWanUnifiedWrapper model that combines high/low noise transformers + _pytorch_transforms (List): PyTorch transformations applied before ONNX export + _onnx_transforms (List): ONNX transformations applied after export + """ + + _pytorch_transforms = [AttentionTransform, CustomOpsTransform, NormalizationTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + + def __init__(self, unified_transformer): + """ + Initialize the Wan unified transformer. + + Args: + model (nn.Module): Wan unified transformer model + """ + super().__init__(unified_transformer) + self.model = unified_transformer + + @property + def get_model_config(self) -> Dict: + """ + Get the model configuration as a dictionary. + + Returns: + Dict: The configuration dictionary of the underlying Wan transformer model + """ + return self.model.config.__dict__ + + def get_onnx_params(self): + """ + Generate ONNX export configuration for the Wan transformer. 
+ + Creates example inputs for all Wan-specific inputs including hidden states, + text embeddings, timestep conditioning, + Returns: + Tuple containing: + - example_inputs (Dict): Sample inputs for ONNX export + - dynamic_axes (Dict): Specification of dynamic dimensions + - output_names (List[str]): Names of model outputs + """ + batch_size = constants.WAN_ONNX_EXPORT_BATCH_SIZE + example_inputs = { + # hidden_states = [ bs, in_channels, frames, latent_height, latent_width] + "hidden_states": torch.randn( + batch_size, + self.model.config.in_channels, + constants.WAN_ONNX_EXPORT_LATENT_FRAMES, + constants.WAN_ONNX_EXPORT_LATENT_HEIGHT_180P, + constants.WAN_ONNX_EXPORT_LATENT_WIDTH_180P, + dtype=torch.float32, + ), + # encoder_hidden_states = [BS, seq len , text dim] + "encoder_hidden_states": torch.randn( + batch_size, constants.WAN_ONNX_EXPORT_SEQ_LEN, constants.WAN_TEXT_EMBED_DIM, dtype=torch.float32 + ), + # Rotary position embeddings: [2, context_length, 1, rotary_dim]; 2 is from tuple of cos, sin freqs + "rotary_emb": torch.randn( + 2, constants.WAN_ONNX_EXPORT_CL_180P, 1, constants.WAN_ONNX_EXPORT_ROTARY_DIM, dtype=torch.float32 + ), + # Timestep embeddings: [batch_size=1, embedding_dim] + "temb": torch.randn(batch_size, constants.WAN_TEXT_EMBED_DIM, dtype=torch.float32), + # Projected timestep embeddings: [batch_size=1, projection_dim, embedding_dim] + "timestep_proj": torch.randn( + batch_size, + constants.WAN_PROJECTION_DIM, + constants.WAN_TEXT_EMBED_DIM, + dtype=torch.float32, + ), + # Timestep parameter: Controls high/low noise transformer selection based on shape + "tsp": torch.ones(1, dtype=torch.int64), + } + + output_names = ["output"] + + dynamic_axes = { + "hidden_states": { + 0: "batch_size", + 1: "num_channels", + 2: "num_frames", + 3: "latent_height", + 4: "latent_width", + }, + "timestep": {0: "steps"}, + "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"}, + "rotary_emb": {1: "cl"}, + "tsp": {0: "model_type"}, + } + + return example_inputs, dynamic_axes, output_names + + def export( + self, + inputs: Dict, + output_names: List[str], + dynamic_axes: Dict, + export_dir: str = None, + export_kwargs: Dict = {}, + use_onnx_subfunctions: bool = False, + ) -> str: + """Export the Wan transformer model to ONNX format. + + Args: + inputs (Dict): Example inputs for ONNX export + output_names (List[str]): Names of model outputs + dynamic_axes (Dict): Specification of dynamic dimensions + export_dir (str, optional): Directory to save ONNX model + export_kwargs (Dict, optional): Additional export arguments (e.g., export_modules_as_functions) + use_onnx_subfunctions (bool): Whether to export transformer blocks as ONNX functions + for better modularity and potential optimization + Returns: + str: Path to the exported ONNX model + """ + if use_onnx_subfunctions: + export_kwargs = {"export_modules_as_functions": {WanTransformerBlock}, "use_onnx_subfunctions": True} + + return self._export( + example_inputs=inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + export_dir=export_dir, + offload_pt_weights=True, + **export_kwargs, + ) + + def compile(self, specializations, **compiler_options) -> None: + """ + Compile the ONNX model for Qualcomm AI hardware. 
+ + Args: + specializations (List[Dict]): Model specialization configurations + **compiler_options: Additional compiler options (e.g., num_cores, aic_num_of_activations) + """ + self._compile(specializations=specializations, **compiler_options) diff --git a/QEfficient/diffusers/pipelines/pipeline_utils.py b/QEfficient/diffusers/pipelines/pipeline_utils.py index 24eb36f53..4bb305447 100644 --- a/QEfficient/diffusers/pipelines/pipeline_utils.py +++ b/QEfficient/diffusers/pipelines/pipeline_utils.py @@ -5,6 +5,7 @@ # # ---------------------------------------------------------------------------- +import math import os from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass @@ -12,6 +13,8 @@ import numpy as np import PIL.Image +import torch +import torch.nn as nn from tqdm import tqdm from QEfficient.utils._utils import load_json @@ -36,6 +39,53 @@ def calculate_compressed_latent_dimension(height: int, width: int, vae_scale_fac return cl, latent_height, latent_width +def calculate_latent_dimensions_with_frames( + height: int, + width: int, + num_frames: int, + vae_scale_factor_spatial: int, + vae_scale_factor_temporal: int, + patch_height: int, + patch_width: int, +) -> int: + """ + Calculate the latent dimensions for video generation models. + + This method computes the compressed sequence length (cl), + Latent height, Latent width , Latent frames based on the + target video dimensions, VAE scale factors, and patch sizes. + + Args: + height (int): Target video height in pixels + width (int): Target video width in pixels + num_frames (int): Target video frames in pixels + vae_scale_factor_spatial (int): spatial vae_scale_factor from model config + vae_scale_factor_temporal (int): temporal vae_scale_factor from model config + patch_height (int): patch_height from model config + patch_width (int): patch_width from model config + + Returns: + tuple: (cl, latent_height, latent_width) + - cl (int): Compressed latent dimension for transformer input + - latent_height (int): Height in latent space + - latent_width (int): Width in latent space + - latent_frames (int): frames in latent space + + Mathematical Formula: + latent_height = height // vae_scale_factor_spatial + latent_width = width // vae_scale_factor_spatial + latent_frames = math.ceil(num_frames / vae_scale_factor_temporal) + cl = (latent_height // patch_height) * (latent_width // patch_width) * latent_frames + + """ + # Calculate latent space dimensions after VAE encoding + latent_height = height // vae_scale_factor_spatial + latent_width = width // vae_scale_factor_spatial + latent_frames = math.ceil(num_frames / vae_scale_factor_temporal) + cl = (latent_height // patch_height * latent_width // patch_width) * latent_frames + return cl, latent_height, latent_width, latent_frames + + def config_manager(cls, config_source: Optional[str] = None): """ JSON-based compilation configuration manager for diffusion pipelines. 
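# Illustrative worked example (not part of the patch above) for the helper
# added in this hunk, using assumed typical Wan 2.2 T2V values: spatial VAE
# factor 8, temporal factor 4, 2x2 spatial patches, for a 480x832, 81-frame
# request. The resulting cl feeds the transformer specialization.
from QEfficient.diffusers.pipelines.pipeline_utils import calculate_latent_dimensions_with_frames

cl, latent_h, latent_w, latent_frames = calculate_latent_dimensions_with_frames(
    height=480, width=832, num_frames=81,
    vae_scale_factor_spatial=8, vae_scale_factor_temporal=4,
    patch_height=2, patch_width=2,
)
print(latent_h, latent_w, latent_frames, cl)  # 60 104 21 32760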
@@ -92,10 +142,19 @@ def _prepare_and_compile(module_name: str, module_obj: Any) -> None: specializations = config["modules"][module_name]["specializations"].copy() compile_kwargs = config["modules"][module_name]["compilation"] - if specialization_updates and module_name in specialization_updates: - specializations.update(specialization_updates[module_name]) - - module_obj.compile(specializations=[specializations], **compile_kwargs) + if ( + specialization_updates and module_name in specialization_updates + ): # Apply specialization updates if available + if isinstance(specializations, list): # for unified models spec will be [{high_noise}, {low_noise}] + for i, spec in enumerate(specializations): + spec.update(specialization_updates[module_name][i]) + else: + specializations.update(specialization_updates[module_name]) + specializations = [specializations] + else: + specializations = [specializations] + # Compile with prepared specializations + module_obj.compile(specializations=specializations, **compile_kwargs) # Execute compilations in parallel with ThreadPoolExecutor(max_workers=len(modules)) as executor: @@ -134,12 +193,19 @@ def compile_modules_sequential( specializations = module_config[module_name]["specializations"].copy() compile_kwargs = module_config[module_name]["compilation"] - # Apply dynamic specialization updates if provided - if specialization_updates and module_name in specialization_updates: - specializations.update(specialization_updates[module_name]) - - # Compile the module to QPC format - module_obj.compile(specializations=[specializations], **compile_kwargs) + if ( + specialization_updates and module_name in specialization_updates + ): # Apply specialization updates if available + if isinstance(specializations, list): # for unified models spec will be [{high_noise}, {low_noise}] + for i, spec in enumerate(specializations): + spec.update(specialization_updates[module_name][i]) + else: + specializations.update(specialization_updates[module_name]) + specializations = [specializations] + else: + specializations = [specializations] + # Compile with prepared specializations + module_obj.compile(specializations=specializations, **compile_kwargs) @dataclass(frozen=True) @@ -216,3 +282,69 @@ def __repr__(self): # List of module name that require special handling during export # when use_onnx_subfunctions is enabled ONNX_SUBFUNCTION_MODULE = ["transformer"] + + +class QEffWanUnifiedWrapper(nn.Module): + """ + A wrapper class that combines WAN high and low noise transformers into a single unified transformer. + + This wrapper dynamically selects between high and low noise transformers based on the timestep shape + in the ONNX graph during inference. This approach enables efficient deployment of both transformer + variants in a single model. 
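# Illustrative sketch (not part of the patch above; dummy stand-ins for the
# real Wan transformers): the length of `tsp` is the routing signal — a
# length-1 tensor selects the high-noise branch, any other length the
# low-noise branch, matching model_type 1/2 in wan_config.json.
import torch
import torch.nn as nn
from QEfficient.diffusers.pipelines.pipeline_utils import QEffWanUnifiedWrapper

class _DummyDiT(nn.Module):
    def __init__(self, fill):
        super().__init__()
        self.fill = fill
        self.config = {}
    def forward(self, hidden_states, **kwargs):
        # Return a constant tensor so the selected branch is visible in the output
        return (torch.full_like(hidden_states, self.fill),)

unified = QEffWanUnifiedWrapper(_DummyDiT(1.0), _DummyDiT(2.0))
x = torch.zeros(1, 4)
print(unified(x, x, x, x, x, tsp=torch.ones(1, dtype=torch.int64)))  # all 1.0 -> high noise
print(unified(x, x, x, x, x, tsp=torch.ones(2, dtype=torch.int64)))  # all 2.0 -> low noise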
+ + Attributes: + transformer_high(nn.Module): The high noise transformer component + transformer_low(nn.Module): The low noise transformer component + config: Configuration shared between both transformers (from high noise transformer) + """ + + def __init__(self, transformer_high, transformer_low): + super().__init__() + self.transformer_high = transformer_high + self.transformer_low = transformer_low + # Both high and low noise transformers share the same configuration + self.config = transformer_high.config + + def forward( + self, + hidden_states, + encoder_hidden_states, + rotary_emb, + temb, + timestep_proj, + tsp, + attention_kwargs=None, + return_dict=False, + ): + # Condition based on timestep shape + is_high_noise = tsp.shape[0] == torch.tensor(1) + + high_hs = hidden_states.detach() + ehs = encoder_hidden_states.detach() + rhs = rotary_emb.detach() + ths = temb.detach() + projhs = timestep_proj.detach() + + noise_pred_high = self.transformer_high( + hidden_states=high_hs, + encoder_hidden_states=ehs, + rotary_emb=rhs, + temb=ths, + timestep_proj=projhs, + attention_kwargs=attention_kwargs, + return_dict=return_dict, + )[0] + + noise_pred_low = self.transformer_low( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + rotary_emb=rotary_emb, + temb=temb, + timestep_proj=timestep_proj, + attention_kwargs=attention_kwargs, + return_dict=return_dict, + )[0] + + # Select based on timestep condition + noise_pred = torch.where(is_high_noise, noise_pred_high, noise_pred_low) + return noise_pred diff --git a/QEfficient/diffusers/pipelines/wan/__init__.py b/QEfficient/diffusers/pipelines/wan/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/diffusers/pipelines/wan/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/diffusers/pipelines/wan/pipeline_wan.py b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py new file mode 100644 index 000000000..edae438ae --- /dev/null +++ b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py @@ -0,0 +1,758 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- +""" +QEfficient WAN Pipeline Implementation + +This module provides an optimized implementation of the WAN pipeline +for high-performance text-to-video generation on Qualcomm AI hardware. +The pipeline supports WAN 2.2 architectures with unified transformer. + +TODO: 1. 
Update Vae, umt5 to Qaic; present running on cpu +""" + +import os +import time +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from diffusers import WanPipeline + +from QEfficient.diffusers.pipelines.pipeline_module import QEffWanUnifiedTransformer +from QEfficient.diffusers.pipelines.pipeline_utils import ( + ONNX_SUBFUNCTION_MODULE, + ModulePerf, + QEffPipelineOutput, + QEffWanUnifiedWrapper, + calculate_latent_dimensions_with_frames, + compile_modules_parallel, + compile_modules_sequential, + config_manager, + set_module_device_ids, +) +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.utils import constants +from QEfficient.utils.logging_utils import logger + + +class QEffWanPipeline: + """ + QEfficient-optimized WAN pipeline for high-performance text-to-video generation on Qualcomm AI hardware. + + This pipeline provides an optimized implementation of the WAN diffusion model + specifically designed for deployment on Qualcomm AI Cloud (QAIC) devices. It extends the original + HuggingFace WAN model with QEfficient-optimized components that can be exported to ONNX format + and compiled into Qualcomm Program Container (QPC) files for efficient video generation. + + The pipeline supports the complete WAN workflow including: + - UMT5 text encoding for rich semantic understanding + - Unified transformer architecture: Combines multiple transformer stages into a single optimized model + - VAE decoding for final video output + - Performance monitoring and hardware optimization + + Attributes: + text_encoder: UMT5 text encoder for semantic text understanding (TODO: QEfficient optimization) + unified_wrapper (QEffWanUnifiedWrapper): Wrapper combining transformer stages + transformer (QEffWanUnifiedTransformer): Optimized unified transformer for denoising + vae_decode: VAE decoder for latent-to-video conversion + modules (Dict[str, Any]): Dictionary of pipeline modules for batch operations + model (WanPipeline): Original HuggingFace WAN model reference + tokenizer: Text tokenizer for preprocessing + scheduler: Diffusion scheduler for timestep management + + Example: + >>> from QEfficient.diffusers.pipelines.wan import QEffWanPipeline + >>> pipeline = QEffWanPipeline.from_pretrained("path/to/wan/model") + >>> videos = pipeline( + ... prompt="A cat playing in a garden", + ... height=480, + ... width=832, + ... num_frames=81, + ... num_inference_steps=4 + ... ) + >>> # Save generated video + >>> videos.images[0].save("generated_video.mp4") + """ + + _hf_auto_class = WanPipeline + + def __init__(self, model, **kwargs): + """ + Initialize the QEfficient WAN pipeline. + + This pipeline provides an optimized implementation of the WAN text-to-video model + for deployment on Qualcomm AI hardware. It wraps the original HuggingFace WAN model + components with QEfficient-optimized versions that can be exported to ONNX and compiled + for QAIC devices. 
+ + Args: + model: Pre-loaded WanPipeline model with transformer and transformer_2 components + **kwargs: Additional keyword arguments including configuration parameters + """ + # Store original model and configuration + self.model = model + self.kwargs = kwargs + self.custom_config = None + + # Text encoder (TODO: Replace with QEfficient UMT5 optimization) + self.text_encoder = model.text_encoder + + # Create unified transformer wrapper combining dual-stage models(high, low noise DiTs) + self.unified_wrapper = QEffWanUnifiedWrapper(model.transformer, model.transformer_2) + self.transformer = QEffWanUnifiedTransformer(self.unified_wrapper) + + # VAE decoder for latent-to-video conversion + self.vae_decode = model.vae + + # Store all modules in a dictionary for easy iteration during export/compile + # TODO: add text encoder, vae decoder on QAIC + self.modules = {"transformer": self.transformer} + + # Copy tokenizers and scheduler from the original model + self.tokenizer = model.tokenizer + self.text_encoder.tokenizer = model.tokenizer + self.scheduler = model.scheduler + # Extract patch dimensions from transformer configuration + _, self.patch_height, self.patch_width = self.transformer.model.config.patch_size + + @property + def do_classifier_free_guidance(self): + """ + Determine if classifier-free guidance should be used. + + Returns: + bool: True if CFG should be applied based on current guidance scales + """ + return self._guidance_scale > 1.0 and (self._guidance_scale_2 is None or self._guidance_scale_2 > 1.0) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + **kwargs, + ): + """ + Load a pretrained WAN model from HuggingFace Hub or local path and wrap it with QEfficient optimizations. + + This class method provides a convenient way to instantiate a QEffWanPipeline from a pretrained + WAN model. It automatically loads the base WanPipeline model in float32 precision on CPU + and wraps all components with QEfficient-optimized versions for QAIC deployment. + + Args: + pretrained_model_name_or_path (str or os.PathLike): Either a HuggingFace model identifier + or a local path to a saved WAN model directory. Should contain transformer, transformer_2, + text_encoder, and VAE components. + **kwargs: Additional keyword arguments passed to WanPipeline.from_pretrained(). + + Returns: + QEffWanPipeline: A fully initialized pipeline instance with QEfficient-optimized components + ready for export, compilation, and inference on QAIC devices. + + Raises: + ValueError: If the model path is invalid or model cannot be loaded + OSError: If there are issues accessing the model files + RuntimeError: If model initialization fails + + Example: + >>> # Load from HuggingFace Hub + >>> pipeline = QEffWanPipeline.from_pretrained("path/to/wan/model") + >>> + >>> # Load from local path + >>> pipeline = QEffWanPipeline.from_pretrained("/local/path/to/wan") + >>> + >>> # Load with custom cache directory + >>> pipeline = QEffWanPipeline.from_pretrained( + ... "wan-model-id", + ... cache_dir="/custom/cache/dir" + ... 
) + """ + # Load the base WAN model in float32 on CPU for optimization + model = cls._hf_auto_class.from_pretrained( + pretrained_model_name_or_path, + torch_dtype=torch.float32, + device_map="cpu", + **kwargs, + ) + return cls( + model=model, + pretrained_model_name_or_path=pretrained_model_name_or_path, + **kwargs, + ) + + def export( + self, + export_dir: Optional[str] = None, + use_onnx_subfunctions: bool = False, + ) -> str: + """ + Export all pipeline modules to ONNX format for deployment preparation. + + This method systematically exports the unified transformer to ONNX format with + video-specific configurations including temporal dimensions, dynamic axes, and + optimization settings. The export process prepares the model for subsequent + compilation to QPC format for efficient inference on QAIC hardware. + + Args: + export_dir (str, optional): Target directory for saving ONNX model files. If None, + uses the default export directory structure. The directory will be created + if it doesn't exist. + use_onnx_subfunctions (bool, default=False): Whether to enable ONNX subfunction + optimization for supported modules. This can optimize the graph structure + and improve compilation efficiency for complex models like the transformer. + + Returns: + str: Absolute path to the export directory containing all ONNX model files. + + Raises: + RuntimeError: If ONNX export fails for any module + OSError: If there are issues creating the export directory or writing files + ValueError: If module configurations are invalid + + Example: + >>> pipeline = QEffWanPipeline.from_pretrained("path/to/wan/model") + >>> export_path = pipeline.export( + ... export_dir="/path/to/export", + ... use_onnx_subfunctions=True + ... ) + """ + + # Export each module with video-specific parameters + for module_name, module_obj in self.modules.items(): + # Get ONNX export configuration with video dimensions + example_inputs, dynamic_axes, output_names = module_obj.get_onnx_params() + + # Prepare export parameters + export_params = { + "inputs": example_inputs, + "output_names": output_names, + "dynamic_axes": dynamic_axes, + "export_dir": export_dir, + } + + # Enable ONNX subfunctions for supported modules if requested + if use_onnx_subfunctions and module_name in ONNX_SUBFUNCTION_MODULE: + export_params["use_onnx_subfunctions"] = True + + module_obj.export(**export_params) + + @staticmethod + def get_default_config_path(): + """ + Get the default configuration file path for WAN pipeline. + + Returns: + str: Path to the default WAN configuration JSON file. + """ + return os.path.join(os.path.dirname(__file__), "wan_config.json") + + def compile( + self, + compile_config: Optional[str] = None, + parallel: bool = False, + height: int = constants.WAN_ONNX_EXPORT_HEIGHT_180P, + width: int = constants.WAN_ONNX_EXPORT_WIDTH_180P, + num_frames: int = constants.WAN_ONNX_EXPORT_FRAMES, + use_onnx_subfunctions: bool = False, + ) -> str: + """ + Compiles the ONNX graphs of the different model components for deployment on Qualcomm AI hardware. + + This method takes the ONNX paths of the transformer and compiles them into an optimized format + for inference using JSON-based configuration. + + Args: + compile_config (str, optional): Path to a JSON configuration file containing + compilation settings, device mappings, and optimization parameters. If None, + uses the default configuration. 
+ parallel (bool, default=False): Compilation mode selection: + - True: Compile modules in parallel using ThreadPoolExecutor for faster processing + - False: Compile modules sequentially for lower resource usage + height (int, default=192): Target image height in pixels. + width (int, default=320): Target image width in pixels. + num_frames (int, deafult=81) : Target num of frames in pixel space + use_onnx_subfunctions (bool, default=False): Whether to export models with ONNX + subfunctions before compilation if not already exported. + + Raises: + RuntimeError: If compilation fails for any module or if QAIC compiler is not available + FileNotFoundError: If ONNX models haven't been exported or config file is missing + ValueError: If configuration parameters are invalid + OSError: If there are issues with file I/O during compilation + + Example: + >>> pipeline = QEffWanPipeline.from_pretrained("path/to/wan/model") + >>> # Sequential compilation with default config + >>> pipeline.compile(height=480, width=832, num_frames=81) + >>> + >>> # Parallel compilation with custom config + >>> pipeline.compile( + ... compile_config="/path/to/custom_config.json", + ... parallel=True, + ... height=480, + ... width=832, + ... num_frames=81 + ... ) + """ + # Ensure all modules are exported to ONNX before compilation + if any( + path is None + for path in [ + self.transformer.onnx_path, + ] + ): + self.export(use_onnx_subfunctions=use_onnx_subfunctions) + + # Load compilation configuration + config_manager(self, config_source=compile_config) + + # Configure pipeline dimensions and calculate compressed latent parameters + cl, latent_height, latent_width, latent_frames = calculate_latent_dimensions_with_frames( + height, + width, + num_frames, + self.model.vae.config.scale_factor_spatial, + self.model.vae.config.scale_factor_temporal, + self.patch_height, + self.patch_width, + ) + # Prepare dynamic specialization updates based on video dimensions + specialization_updates = { + "transformer": [ + # high noise + { + "cl": cl, # Compressed latent dimension + "latent_height": latent_height, # Latent space height + "latent_width": latent_width, # Latent space width + "num_frames": latent_frames, # Latent frames + }, + # low noise + { + "cl": cl, # Compressed latent dimension + "latent_height": latent_height, # Latent space height + "latent_width": latent_width, # Latent space width + "num_frames": latent_frames, # Latent frames + }, + ] + } + + # Use generic utility functions for compilation + if parallel: + compile_modules_parallel(self.modules, self.custom_config, specialization_updates) + else: + compile_modules_sequential(self.modules, self.custom_config, specialization_updates) + + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 3.0, + guidance_scale_2: Optional[float] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Union[Callable[[int, int, Dict], None]]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 
512, + custom_config_path: Optional[str] = None, + use_onnx_subfunctions: bool = False, + parallel_compile: bool = True, + ): + """ + Generate videos from text prompts using the QEfficient-optimized WAN pipeline on QAIC hardware. + + This is the main entry point for text-to-video generation. It orchestrates the complete WAN + diffusion pipeline optimized for Qualcomm AI Cloud devices. + + Args: + prompt (str or List[str]): Primary text prompt(s) describing the desired video content. + Required unless `prompt_embeds` is provided. + negative_prompt (str or List[str], optional): Negative prompt(s) describing what to avoid + in the generated video. Used with classifier-free guidance. + height (int, optional): Target video height in pixels. Must be divisible by VAE scale factor. + Default: 480. + width (int, optional): Target video width in pixels. Must be divisible by VAE scale factor. + Default: 832. + num_frames (int, optional): Number of video frames to generate. Must satisfy temporal + divisibility requirements. Default: 81. + num_inference_steps (int, optional): Number of denoising steps. More steps generally + improve quality but increase generation time. Default: 50. + guidance_scale (float, optional): Guidance scale for classifier-free guidance. Default: 3.0. + guidance_scale_2 (float, optional): Guidance scale for low-noise stage in WAN 2.2. + If None, uses guidance_scale value. + num_videos_per_prompt (int, optional): Number of videos to generate per prompt. Default: 1. + generator (torch.Generator or List[torch.Generator], optional): Random generator for + reproducible generation. + latents (torch.Tensor, optional): Pre-generated latent tensors. If None, random latents + are generated based on video dimensions. + prompt_embeds (torch.Tensor, optional): Pre-computed text embeddings from UMT5 encoder. + Shape: [batch, seq_len, hidden_dim]. + negative_prompt_embeds (torch.Tensor, optional): Pre-computed negative text embeddings. + output_type (str, optional): Output format. Options: "np" (default), "pil", or "latent". + return_dict (bool, optional): Whether to return a dictionary or tuple. Default: True. + attention_kwargs (Dict[str, Any], optional): Additional attention arguments for transformer. + callback_on_step_end (Callable, optional): Callback function executed after each denoising step. + callback_on_step_end_tensor_inputs (List[str], optional): Tensor names to pass to callback. + Default: ["latents"]. + max_sequence_length (int, optional): Maximum token sequence length for text encoder. Default: 512. + custom_config_path (str, optional): Path to custom JSON configuration file for compilation. + use_onnx_subfunctions (bool, optional): Whether to export transformer blocks as ONNX subfunctions. + Default: False. + parallel_compile (bool, optional): Whether to compile modules in parallel. Default: True. + + Returns: + QEffPipelineOutput: A dataclass containing: + - images: Generated video(s) in the format specified by `output_type` + - pipeline_module: Performance metrics for each pipeline component + + Raises: + ValueError: If input validation fails or parameters are incompatible + RuntimeError: If compilation fails or QAIC devices are unavailable + FileNotFoundError: If custom config file is specified but not found + + Example: + >>> from QEfficient.diffusers.pipelines.wan import QEffWanPipeline + >>> pipeline = QEffWanPipeline.from_pretrained("path/to/wan/model") + >>> result = pipeline( + ... prompt="A cat playing in a sunny garden", + ... height=480, + ... width=832, + ... 
num_frames=81, + ... num_inference_steps=4, + ... guidance_scale=3.0 + ... ) + >>> # Save generated video + >>> result.images[0].save("cat_garden.mp4") + """ + device = "cpu" + + # Compile models with custom configuration if needed + self.compile( + compile_config=custom_config_path, + parallel=parallel_compile, + use_onnx_subfunctions=use_onnx_subfunctions, + height=height, + width=width, + num_frames=num_frames, + ) + + # Set device IDs for all modules based on configuration + set_module_device_ids(self) + + # Step 1: Validate all inputs + self.model.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + guidance_scale_2, + ) + + # Ensure num_frames satisfies temporal divisibility requirements + if num_frames % self.model.vae.config.scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.model.vae.config.scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = ( + num_frames // self.model.vae.config.scale_factor_temporal * self.model.vae.config.scale_factor_temporal + + 1 + ) + num_frames = max(num_frames, 1) + + if self.model.config.boundary_ratio is not None and guidance_scale_2 is None: + guidance_scale_2 = guidance_scale + + # Initialize pipeline state + self._guidance_scale = guidance_scale + self._guidance_scale_2 = guidance_scale_2 if guidance_scale_2 is not None else guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # Step 2: Determine batch size from inputs + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Step 3: Encode input prompts using UMT5 text encoder + # TODO: Update UMT5 on QAIC + prompt_embeds, negative_prompt_embeds = self.model.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + # Convert embeddings to transformer dtype for compatibility + transformer_dtype = self.transformer.model.transformer_high.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # Step 4: Prepare timesteps for denoising process + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # Step 5: Prepare initial latent variables for video generation + num_channels_latents = self.transformer.model.config.in_channels + + latents = self.model.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + # Create mask for temporal processing (used in expand_timesteps mode) + mask = torch.ones(latents.shape, dtype=torch.float32, device=device) + + # Step 6: Configure dual-stage processing for WAN 2.2 + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # Calculate boundary timestep for stage switching in WAN 2.2 + if self.model.config.boundary_ratio is not None: + boundary_timestep = self.model.config.boundary_ratio * 
self.scheduler.config.num_train_timesteps + else: + boundary_timestep = None + + # Step 7: Initialize QAIC inference session for transformer + if self.transformer.qpc_session is None: + self.transformer.qpc_session = QAICInferenceSession( + str(self.transformer.qpc_path), device_ids=self.transformer.device_ids + ) + + # Calculate compressed latent dimension for transformer buffer allocation + cl, _, _, _ = calculate_latent_dimensions_with_frames( + height, + width, + num_frames, + self.model.vae.config.scale_factor_spatial, + self.model.vae.config.scale_factor_temporal, + self.patch_height, + self.patch_width, + ) + # Allocate output buffer for QAIC inference + output_buffer = { + "output": np.random.rand( + batch_size, + cl, # Compressed latent dimension + constants.WAN_DIT_OUT_CHANNELS, + ).astype(np.int32), + } + self.transformer.qpc_session.set_buffers(output_buffer) + transformer_perf = [] + + # Step 8: Denoising loop with dual-stage processing + with self.model.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self._interrupt: + continue + + self._current_timestep = t + + # Determine which model to use based on boundary timestep + if boundary_timestep is None or t >= boundary_timestep: + # High-noise stage + current_model = self.transformer.model.transformer_high + current_guidance_scale = guidance_scale + model_type = torch.ones(1, dtype=torch.int64) # High-noise model indicator + else: + # Low-noise stage + current_model = self.transformer.model.transformer_low + current_guidance_scale = guidance_scale_2 + model_type = torch.ones(2, dtype=torch.int64) # Low-noise model indicator + + # Prepare latent input with proper dtype + latent_model_input = latents.to(transformer_dtype) + + # Handle timestep expansion for temporal consistency + if self.model.config.expand_timesteps: + # Expand timesteps spatially for better temporal modeling + temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten() + timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) + else: + # Standard timestep broadcasting + timestep = t.expand(latents.shape[0]) + + # Extract dimensions for patch processing + batch_size, num_channels, num_frames, height, width = latents.shape + p_t, p_h, p_w = current_model.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + # Generate rotary position embeddings + rotary_emb = current_model.rope(latent_model_input) + rotary_emb = torch.cat(rotary_emb, dim=0) + ts_seq_len = None + timestep = timestep.flatten() + + # Generate conditioning embeddings (time + text) + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = ( + current_model.condition_embedder( + timestep, prompt_embeds, encoder_hidden_states_image=None, timestep_seq_len=ts_seq_len + ) + ) + + # Generate negative conditioning for classifier-free guidance + if self.do_classifier_free_guidance: + temb, timestep_proj, encoder_hidden_states_neg, encoder_hidden_states_image = ( + current_model.condition_embedder( + timestep, + negative_prompt_embeds, + encoder_hidden_states_image=None, + timestep_seq_len=ts_seq_len, + ) + ) + + # Reshape timestep projection for transformer input + timestep_proj = timestep_proj.unflatten(1, (6, -1)) + + # Prepare inputs for QAIC inference + inputs_aic = { + "hidden_states": latents.detach().numpy(), + "encoder_hidden_states": encoder_hidden_states.detach().numpy(), + "rotary_emb": rotary_emb.detach().numpy(), + "temb": temb.detach().numpy(), + 
"timestep_proj": timestep_proj.detach().numpy(), + "tsp": model_type.detach().numpy(), # Transformer stage pointer + } + + # Prepare negative inputs for classifier-free guidance + if self.do_classifier_free_guidance: + inputs_aic2 = { + "hidden_states": latents.detach().numpy(), + "encoder_hidden_states": encoder_hidden_states_neg.detach().numpy(), + "rotary_emb": rotary_emb.detach().numpy(), + "temb": temb.detach().numpy(), + "timestep_proj": timestep_proj.detach().numpy(), + } + + # Run conditional prediction with caching context + with current_model.cache_context("cond"): + # QAIC inference for conditional prediction + start_transformer_step_time = time.perf_counter() + outputs = self.transformer.qpc_session.run(inputs_aic) + end_transformer_step_time = time.perf_counter() + transformer_perf.append(end_transformer_step_time - start_transformer_step_time) + print(f"DIT {i} time {end_transformer_step_time - start_transformer_step_time:.2f} seconds") + + # Process transformer output + hidden_states = torch.tensor(outputs["output"]) + + # Reshape output from patches back to video format + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + + # Permute dimensions to reconstruct video tensor + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + noise_pred = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + # Run unconditional prediction for classifier-free guidance + if self.do_classifier_free_guidance: # Note: CFG is False for WAN Lightning + with current_model.cache_context("uncond"): + # QAIC inference for unconditional prediction + start_transformer_step_time = time.perf_counter() + outputs = self.transformer.qpc_session.run(inputs_aic2) + end_transformer_step_time = time.perf_counter() + transformer_perf.append(end_transformer_step_time - start_transformer_step_time) + + # Process unconditional output + hidden_states = torch.tensor(outputs["output"]) + + # Reshape unconditional output + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + noise_uncond = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + # Apply classifier-free guidance + noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) + + # Update latents using scheduler (x_t -> x_t-1) + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # Execute callback if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # Update progress bar + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + self._current_timestep = None + + # Step 9: Decode latents to video + if not output_type == "latent": + # Prepare latents for VAE decoding + latents = latents.to(self.vae_decode.dtype) + + # Apply VAE normalization (denormalization) + latents_mean = ( + torch.tensor(self.vae_decode.config.latents_mean) + .view(1, self.vae_decode.config.z_dim, 1, 1, 1) + 
.to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae_decode.config.latents_std).view( + 1, self.vae_decode.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + + # TODO: Enable VAE on QAIC + # VAE Decode latents to video using CPU (temporary) + video = self.model.vae.decode(latents, return_dict=False)[0] # CPU fallback + + # Post-process video for output + video = self.model.video_processor.postprocess_video(video.detach()) + else: + video = latents + + # Step 10: Collect performance metrics + perf_data = { + "transformer": transformer_perf, # Unified transformer (QAIC) + } + + # Build performance metrics for output + perf_metrics = [ModulePerf(module_name=name, perf=perf_data[name]) for name in perf_data.keys()] + + return QEffPipelineOutput( + pipeline_module=perf_metrics, + images=video, + ) diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 925fb8b8f..d0318ac3e 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -88,7 +88,7 @@ def get_models_dir(): SIZE_THRESHOLD_DEFAULT = 1024 -COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw"] +COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw", "-compile-only"] DEFAULT_AIC_HW_VERSION = "ai100" ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL = 100 @@ -152,6 +152,22 @@ def get_models_dir(): FLUX_ADALN_SINGLE_BLOCK_CHUNKS = 3 FLUX_ADALN_OUTPUT_DIM = 6144 # 2 * FLUX_ADALN_HIDDEN_DIM +# Wan Transformer Constants +WAN_TEXT_EMBED_DIM = 5120 +WAN_PROJECTION_DIM = 6 +WAN_ONNX_EXPORT_BATCH_SIZE = 1 +WAN_ONNX_EXPORT_FRAMES = 81 +WAN_ONNX_EXPORT_LATENT_FRAMES = 21 +WAN_ONNX_EXPORT_SEQ_LEN = 512 +WAN_ONNX_EXPORT_ROTARY_DIM = 128 +WAN_DIT_OUT_CHANNELS = 64 +# Wan dims for 180p +WAN_ONNX_EXPORT_CL_180P = 5040 +WAN_ONNX_EXPORT_LATENT_HEIGHT_180P = 24 +WAN_ONNX_EXPORT_LATENT_WIDTH_180P = 40 +WAN_ONNX_EXPORT_HEIGHT_180P = 192 +WAN_ONNX_EXPORT_WIDTH_180P = 320 + # For the purpose of automatic CCL lists generation, to limit the number of elements in CCL list, the starting point will be calculated based on context length CCL_START_MAP = { 32768: (4096, 4000), diff --git a/examples/diffusers/wan/README.md b/examples/diffusers/wan/README.md new file mode 100644 index 000000000..b90bf3908 --- /dev/null +++ b/examples/diffusers/wan/README.md @@ -0,0 +1,249 @@ +# WAN 2.2 Text-to-Video Generation Examples + +This directory contains examples demonstrating how to use the QEffWanPipeline to generate videos using the WAN 2.2 text-to-video model with Lightning LoRA optimization. + +## Overview + +WAN 2.2 is a text-to-video diffusion model that uses dual-stage processing for high-quality video generation. These examples show how to leverage Qualcomm Cloud AI 100 acceleration for efficient video generation with Lightning LoRA for fast 4-step inference. + +## Files + +- **`wan_lightning.py`** - Complete example with Lightning LoRA for fast video generation +- **`wan_config.json`** - Configuration file for transformer module compilation + +## Quick Start + +### Basic Usage + +The simplest way to generate videos with WAN 2.2 Lightning: +### 1. Load Model +```python +from QEfficient import QEffWanPipeline +import torch +from diffusers.utils import export_to_video + +# Initialize pipeline +pipeline = QEffWanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers") +``` + +### 2. 
Lightning LoRA Integration + +Load high and low noise LoRA adapters for fast 4-step generation: + +```python +from huggingface_hub import hf_hub_download +from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers +import safetensors.torch + +# Download Lightning LoRAs +high_noise_lora_path = hf_hub_download( + repo_id="lightx2v/Wan2.2-Lightning", + filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/high_noise_model.safetensors", +) +low_noise_lora_path = hf_hub_download( + repo_id="lightx2v/Wan2.2-Lightning", + filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/low_noise_model.safetensors", +) + +# Load and apply LoRAs +def load_wan_lora(path: str): + return _convert_non_diffusers_wan_lora_to_diffusers(safetensors.torch.load_file(path)) + +pipeline.transformer.model.transformer_high.load_lora_adapter( + load_wan_lora(high_noise_lora_path), adapter_name="high_noise" +) +pipeline.transformer.model.transformer_high.set_adapters(["high_noise"], weights=[1.0]) + +pipeline.transformer.model.transformer_low.load_lora_adapter( + load_wan_lora(low_noise_lora_path), adapter_name="low_noise" +) +pipeline.transformer.model.transformer_low.set_adapters(["low_noise"], weights=[1.0]) +``` + + +### 3. Compile API + +To compile the model for desired resolution: + +```python +# Compile with custom configuration +pipeline.compile( + compile_config="examples/diffusers/wan/wan_config.json", + parallel=True, + height=480, + width=832, + num_frames=81, + use_onnx_subfunctions=False, +) +``` + +### 4. Generate video +```python +output = pipeline( + prompt="A cat playing in a sunny garden", + num_frames=81, + height=480, + width=832, + guidance_scale=1.0, + num_inference_steps=4, + generator=torch.manual_seed(42), + parallel_compile=True, + use_onnx_subfunctions=False, +) + +# Export video +frames = output.images[0] +export_to_video(frames, "cat_garden.mp4", fps=16) +``` + +Run the Lightning example: +```bash +python wan_lightning.py +``` + +## Advanced Customization + + +### 1. Reduce Model Layers for Faster Inference + + +```python +# Reduce to 2 layers for faster inference +pipeline.transformer.model.transformer_high.config.num_layers = 2 +pipeline.transformer.model.transformer_low.config.num_layers = 2 + +original_blocks = pipeline.transformer.model.transformer_high.blocks +org_blocks = pipeline.transformer.model.transformer_low.blocks + +pipeline.transformer.model.transformer_high.blocks = torch.nn.ModuleList( + [original_blocks[i] for i in range(0, pipeline.transformer.model.transformer_high.config.num_layers)] +) +pipeline.transformer.model.transformer_low.blocks = torch.nn.ModuleList( + [org_blocks[i] for i in range(0, pipeline.transformer.model.transformer_low.config.num_layers)] +) +``` + +### 2. 
To Run with Blocking + +Use environment variables to enable attention blocking: + +```bash +# For 180p Generation (192x320) with HKV blocking +ATTENTION_BLOCKING_MODE=kv head_block_size=16 num_kv_blocks=3 python wan_lightning.py + +# For 480p Generation (480x832) with HQKV blocking +ATTENTION_BLOCKING_MODE=qkv head_block_size=16 num_kv_blocks=21 num_q_blocks=2 python wan_lightning.py + +# for 720P Generation (720x1280) with HQKV blocking +ATTENTION_BLOCKING_MODE=qkv head_block_size=16 num_kv_blocks=48 num_q_blocks=5 python wan_lightning.py +``` + +### Blocking Modes + +Head blocking is common in all modes + +- **`kv`**: Block key-value processing (along with Head blocking) +- **`q`**: Block query processing (along with Head blocking) +- **`qkv`**: Block query, key, and value (along with Head blocking) +- **`default`**: Head-only blocking + + +## Configuration File + +The `wan_config.json` file controls compilation settings for the transformer module: + +### Module Structure + +The configuration includes dual specializations for WAN's high and low noise models: + +```json +{ + "transformer": { + "specializations":[ + { + "batch_size":"1", + "cl":"5040", + "latent_height":"24", + "latent_width":"40", + "model_type":"1", + "num_channels":"16", + "num_frames":"21", + "sequence_length":"512", + "steps":"1" + }, + { + "batch_size":"1", + "cl":"5040", + "latent_height":"24", + "latent_width":"40", + "model_type":"2", + "num_channels":"16", + "num_frames":"21", + "sequence_length":"512", + "steps":"1" + } + ] +} +} +``` + +### Configuration Parameters + +#### Specializations +- `batch_size`: Batch size for inference +- `num_channels`: Number of latent channels (16 for WAN) +- `num_frames`: Number of latent frames (21 for 81 input frames) +- `latent_height`/`latent_width`: Latent space dimensions +- `cl`: Compressed latent dimension for transformer +- `sequence_length` : Sequence length of text encoder 512 +- `model_type`: 1 for high noise model, 2 for low noise model + +#### Compilation +- `mdp_ts_num_devices`: Number of devices for model parallelism (16 recommended) +- `mxfp6_matmul`: Enable MXFP6 quantization for matrix multiplication +- `convert_to_fp16`: Convert model to FP16 precision +- `aic_num_cores`: Number of AI cores to use (16 recommended) +- `mos`: Degree of weight splitting done across cores (1 is recommended) +- `mdts_mos`: Degree of weight splitting done across multi-device tensor slices (1 is recommended) + +## Key Parameters + +### Generation Parameters + +- **`prompt`** (str): Text description of the video to generate +- **`num_frames`** (int): Number of video frames (default: 81) +- **`height`** (int): Output video height in pixels (default: 480) +- **`width`** (int): Output video width in pixels (default: 832) +- **`guidance_scale`** (float): Guidance scale for high noise stage (1.0 for Lightning) +- **`guidance_scale_2`** (float): Guidance scale for low noise stage (1.0 for Lightning) +- **`num_inference_steps`** (int): Number of denoising steps (4 for Lightning) +- **`generator`** (torch.Generator): Random seed for reproducibility +- **`parallel_compile`** (bool): Enable parallel compilation of modules +- **`use_onnx_subfunctions`** (bool): Enable ONNX modular export + + +## Output + +The pipeline returns an output object containing: +- `images`: List of video frames as PIL Image objects +- Performance metrics (timing information) + +Example output: +```python +print(output) # Displays performance information +frames = output.images[0] # Access the generated video frames 
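+# The returned object also carries per-module timing. A minimal sketch of reading it,
+# based on the ModulePerf entries exposed via output.pipeline_module in this pipeline:
+for module_perf in output.pipeline_module:
+    step_times = module_perf.perf  # list of per-step transformer latencies in seconds
+    print(f"{module_perf.module_name}: {len(step_times)} steps, total {sum(step_times):.2f}s")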
+export_to_video(frames, "output.mp4", fps=16) # Export to MP4 +``` + +## Notes + +- WAN 2.2 Lightning is optimized for 4-step generation with `guidance_scale=1.0` +- The transformer uses dual-stage processing (high/low noise models) +- Attention blocking is essential for higher resolutions (480p+) + + +## References + +- [WAN 2.2 Model Card](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers) +- [Lightning LoRA](https://huggingface.co/lightx2v/Wan2.2-Lightning) +- [QEfficient Documentation](../../../README.md) diff --git a/examples/diffusers/wan/wan_config.json b/examples/diffusers/wan/wan_config.json new file mode 100644 index 000000000..7e752ba14 --- /dev/null +++ b/examples/diffusers/wan/wan_config.json @@ -0,0 +1,37 @@ +{ + "description": "Default configuration for Wan pipeline with unified transformer (model_type: 1 for high noise; model_type:2 for low noise)", + "model_type": "wan", + "modules": { + "transformer": { + "specializations": [ + { + "batch_size": "1", + "num_channels": "16", + "steps": "1", + "sequence_length": "512", + "model_type": 1 + }, + { + "batch_size": "1", + "num_channels": "16", + "steps": "1", + "sequence_length": "512", + "model_type": 2 + } + ], + "compilation": { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 16, + "mxfp6_matmul": true, + "convert_to_fp16": true, + "aic_num_cores": 16, + "mos": 1, + "mdts_mos": 1 + }, + "execute": { + "device_ids": null + } + } + } +} \ No newline at end of file diff --git a/examples/diffusers/wan/wan_lightning.py b/examples/diffusers/wan/wan_lightning.py new file mode 100644 index 000000000..691da651f --- /dev/null +++ b/examples/diffusers/wan/wan_lightning.py @@ -0,0 +1,62 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +import safetensors.torch +import torch +from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers +from diffusers.utils import export_to_video +from huggingface_hub import hf_hub_download + +from QEfficient import QEffWanPipeline + +# Load the pipeline +pipeline = QEffWanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers") + +# Download the LoRAs +high_noise_lora_path = hf_hub_download( + repo_id="lightx2v/Wan2.2-Lightning", + filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/high_noise_model.safetensors", +) +low_noise_lora_path = hf_hub_download( + repo_id="lightx2v/Wan2.2-Lightning", + filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/low_noise_model.safetensors", +) + + +# LoRA conversion +def load_wan_lora(path: str): + return _convert_non_diffusers_wan_lora_to_diffusers(safetensors.torch.load_file(path)) + + +# Load into the transformers +pipeline.transformer.model.transformer_high.load_lora_adapter( + load_wan_lora(high_noise_lora_path), adapter_name="high_noise" +) +pipeline.transformer.model.transformer_high.set_adapters(["high_noise"], weights=[1.0]) +pipeline.transformer.model.transformer_low.load_lora_adapter( + load_wan_lora(low_noise_lora_path), adapter_name="low_noise" +) +pipeline.transformer.model.transformer_low.set_adapters(["low_noise"], weights=[1.0]) + + +prompt = "In a warmly lit living room, an elderly man with gray hair sits in a wooden armchair adorned with a blue cushion. He wears a gray cardigan over a white shirt, engrossed in reading a book. 
As he turns the pages, he subtly adjusts his posture, ensuring his glasses stay in place. He then removes his glasses, holding them in his hand, and turns his head to the right, maintaining his grip on the book. The soft glow of a bedside lamp bathes the scene, creating a calm and serene atmosphere, with gentle shadows enhancing the intimate setting." + +output = pipeline( + prompt=prompt, + num_frames=81, + guidance_scale=1.0, + guidance_scale_2=1.0, + num_inference_steps=4, + generator=torch.manual_seed(0), + custom_config_path="examples/diffusers/wan/wan_config.json", + height=480, + width=832, + use_onnx_subfunctions=True, + parallel_compile=True, +) +frames = output.images[0] +export_to_video(frames, "output_t2v.mp4", fps=16) +print(output) diff --git a/examples/diffusers/wan/wan_lightning_custom.py b/examples/diffusers/wan/wan_lightning_custom.py new file mode 100644 index 000000000..a60d57bb6 --- /dev/null +++ b/examples/diffusers/wan/wan_lightning_custom.py @@ -0,0 +1,162 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Wan2.2-Lightning Custom Configuration Example + +This example demonstrates how to customize the Wan2.2-Lightning model with various options: +1. Custom video dimensions (height/width) and frame count +2. Custom scheduler configuration +3. Reduced model layers for faster inference +4. Custom compilation settings +5. Custom runtime configuration via JSON config file +6. LoRA adapter loading and configuration + +Use this example to learn how to tune Wan2.2-Lightning for your specific video generation needs. 
+""" + +import safetensors.torch +import torch +from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers +from diffusers.utils import export_to_video +from huggingface_hub import hf_hub_download + +from QEfficient import QEffWanPipeline + +# ============================================================================ +# PIPELINE INITIALIZATION WITH CUSTOM PARAMETERS +# ============================================================================ + +# Option 1: Basic initialization with default parameters +pipeline = QEffWanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers") + +# ============================================================================ +# LORA ADAPTER LOADING FOR LIGHTNING MODEL +# ============================================================================ +# Download and load Lightning LoRA adapters for faster inference + +# Download the LoRAs from Hugging Face Hub +high_noise_lora_path = hf_hub_download( + repo_id="lightx2v/Wan2.2-Lightning", + filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/high_noise_model.safetensors", +) +low_noise_lora_path = hf_hub_download( + repo_id="lightx2v/Wan2.2-Lightning", + filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/low_noise_model.safetensors", +) + + +# LoRA conversion utility function +def load_wan_lora(path: str): + """Convert and load WAN LoRA weights from safetensors format.""" + return _convert_non_diffusers_wan_lora_to_diffusers(safetensors.torch.load_file(path)) + + +# Load LoRA adapters into the high and low noise transformers +pipeline.transformer.model.transformer_high.load_lora_adapter( + load_wan_lora(high_noise_lora_path), adapter_name="high_noise" +) +pipeline.transformer.model.transformer_high.set_adapters(["high_noise"], weights=[1.0]) + +pipeline.transformer.model.transformer_low.load_lora_adapter( + load_wan_lora(low_noise_lora_path), adapter_name="low_noise" +) +pipeline.transformer.model.transformer_low.set_adapters(["low_noise"], weights=[1.0]) + +# ============================================================================ +# OPTIONAL: CUSTOM SCHEDULER CONFIGURATION +# ============================================================================ +# Uncomment to use a custom scheduler (e.g., different sampling methods): +# +# pipeline.scheduler = custom_scheduler.from_config(pipeline.scheduler.config) + +# ============================================================================ +# OPTIONAL: REDUCE MODEL LAYERS FOR FASTER INFERENCE +# ============================================================================ +# Reduce the number of transformer blocks to speed up video generation. 
+# +# Trade-off: Faster inference but potentially lower video quality +# Use case: Quick testing, prototyping, or when speed is critical +# +# Uncomment the following lines to use only a subset of transformer layers: +# +# # Configure for 2-layer model (faster inference) +# pipeline.transformer.model.transformer_high.config.num_layers = 1 +# pipeline.transformer.model.transformer_low.config.num_layers = 1 +# +# # Reduce high noise transformer blocks +# original_blocks = pipeline.transformer.model.transformer_high.blocks +# pipeline.transformer.model.transformer_high.blocks = torch.nn.ModuleList( +# [original_blocks[i] for i in range(0, pipeline.transformer.model.transformer_high.config.num_layers)] +# ) +# +# # Reduce low noise transformer blocks +# org_blocks = pipeline.transformer.model.transformer_low.blocks +# pipeline.transformer.model.transformer_low.blocks = torch.nn.ModuleList( +# [org_blocks[i] for i in range(0, pipeline.transformer.model.transformer_low.config.num_layers)] +# ) + +# ============================================================================ +# OPTIONAL: COMPILE WITH CUSTOM CONFIGURATION +# ============================================================================ +# Pre-compile the model for optimized performance on target hardware. +# +# When to use: +# - When you want to compile the model separately before generation +# - When you need to skip video generation and only prepare the model +# +# NOTE-1: If compile_config is not specified, the default configuration from +# QEfficient/diffusers/pipelines/wan/wan_config.json will be used +# +# NOTE-2: use_onnx_subfunctions=True enables modular ONNX export optimizations +# This feature improves export performance by breaking down the model into smaller, +# more manageable ONNX functions, which can lead to improved compile time. +# +# Uncomment to compile with a custom configuration: +# pipeline.compile( +# compile_config="examples/diffusers/wan/wan_config.json", +# parallel=True, +# height=480, +# width=832, +# num_frames=81, +# use_onnx_subfunctions=True +# ) + +# ============================================================================ +# VIDEO GENERATION WITH CUSTOM RUNTIME CONFIGURATION +# ============================================================================ +# Generate a video using the configured pipeline. +# +# Note: Use of custom_config_path provides flexibility to set device_ids for each +# module, so you can skip the separate pipeline.compile() step. 
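+#
+# Illustrative sketch (device numbers are placeholders, not a recommendation): in the
+# transformer module's "execute" section of the JSON config, the default
+#     "execute": { "device_ids": null }
+# can be replaced with an explicit list such as
+#     "execute": { "device_ids": [0, 1, 2, 3] }
+# to pin that module to specific QAIC devices at runtime.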
+ +# Custom prompt for video generation +prompt = "A cat wearing a hat walking through a magical forest with glowing mushrooms and fireflies dancing around, cinematic lighting, high quality" + +# Alternative video dimensions for different use cases, corresponding default blocking +# height=192, width=320 # ATTENTION_BLOCKING_MODE=kv head_block_size=16 num_kv_blocks=3 python3 examples/diffusers/wan/wan_lightning.py +# height=480, width=832 # ATTENTION_BLOCKING_MODE=qkv head_block_size=16 num_kv_blocks=21 num_q_blocks=2 python3 examples/diffusers/wan/wan_lightning.py +# height=720, width=1280 # ATTENTION_BLOCKING_MODE=qkv head_block_size=16 num_kv_blocks=48 num_q_blocks=5 python3 examples/diffusers/wan/wan_lightning.py + +output = pipeline( + prompt=prompt, + num_frames=81, # Number of video frames to generate + guidance_scale=1.0, # Primary guidance scale + guidance_scale_2=1.0, # Secondary guidance scale for dual guidance + num_inference_steps=4, # Lightning model uses fewer steps + generator=torch.manual_seed(42), # For reproducible results + custom_config_path="examples/diffusers/wan/wan_config.json", + height=480, + width=832, + use_onnx_subfunctions=True, # Enable ONNX optimizations + parallel_compile=False, # Set to True for parallel compilation +) + +# Extract generated frames and export to video +frames = output.images[0] +export_to_video(frames, "custom_wan_lightning_output.mp4", fps=16) +print(output) diff --git a/pyproject.toml b/pyproject.toml index fe0c42ec2..9da98f71d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,9 @@ dependencies = [ "fire", "py7zr", "torchmetrics==1.7.0", + "ftfy==6.3.1", + "imageio==2.37.2", + "imageio-ffmpeg==0.6.0", "torch==2.7.0; platform_machine=='aarch64'", # Specifying torch cpu package URL per python version, update the list once pytorch releases whl for python>3.11 "torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp38-cp38-linux_x86_64.whl ; python_version=='3.8' and platform_machine=='x86_64'", @@ -53,7 +56,6 @@ dependencies = [ test = ["pytest","pytest-mock"] docs = ["Sphinx==7.1.2","sphinx-rtd-theme==2.0.0","myst-parser==3.0.1","sphinx-multiversion"] quality = ["black", "ruff", "hf_doc_builder@git+https://github.com/huggingface/doc-builder.git"] - [build-system] requires = ["setuptools>=62.0.0"] build-backend = "setuptools.build_meta" @@ -75,3 +77,16 @@ target-version = "py310" addopts = "-W ignore -s -v" junit_logging = "all" doctest_optionflags = "NUMBER NORMALIZE_WHITESPACE ELLIPSIS" +markers = [ + "on_qaic: marks tests as requiring QAIC hardware", + "diffusion_models: marks tests for diffusion models", + "wan: marks tests for WAN model", + "flux: marks tests for Flux model", + "regular: marks regular tests", + "nightly: marks nightly tests", + "multimodal: marks multimodal tests", + "qnn: marks QNN tests", + "cli: marks CLI tests", + "finetune: marks finetune tests", + "vllm: marks vLLM tests" +] diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile index e9925dee2..3420c025b 100644 --- a/scripts/Jenkinsfile +++ b/scripts/Jenkinsfile @@ -22,7 +22,6 @@ pipeline { . 
preflight_qeff/bin/activate && pip install --upgrade pip setuptools && pip install .[test] && - pip install .[diffusers] && pip install junitparser pytest-xdist && pip install librosa==0.10.2 soundfile==0.13.1 && #packages needed to load example for whisper testing pip install --extra-index-url https://download.pytorch.org/whl/cpu timm==1.0.14 torchvision==0.22.0+cpu einops==0.8.1 && #packages to load VLMs @@ -59,7 +58,7 @@ pipeline { mkdir -p $PWD/Non_qaic && export TOKENIZERS_PARALLELISM=false && export QEFF_HOME=$PWD/Non_qaic && - pytest tests -m '(not cli) and (on_qaic) and (not nightly) and (not multimodal) and (not qnn) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log2.xml && + pytest tests -m '(not cli) and (on_qaic) and (not nightly) and (not multimodal) and (not qnn) and (not finetune) and (not diffusion_models)' --ignore tests/vllm --junitxml=tests/tests_log2.xml && junitparser merge tests/tests_log2.xml tests/tests_log.xml && deactivate" ''' @@ -68,7 +67,7 @@ pipeline { } } } - stage('QAIC MultiModal Tests') { + stage('QAIC MultiModal Tests') { steps { timeout(time: 120, unit: 'MINUTES') { sh ''' @@ -78,13 +77,31 @@ pipeline { mkdir -p $PWD/Non_cli_qaic_multimodal && export TOKENIZERS_PARALLELISM=false && export QEFF_HOME=$PWD/Non_cli_qaic_multimodal && - pytest tests -m '(not cli) and (on_qaic) and (multimodal) and (not qnn) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log6.xml && + pytest tests -m '(not cli) and (on_qaic) and (multimodal) and (not qnn) and (not finetune) and (not diffusion_models)' --ignore tests/vllm --junitxml=tests/tests_log6.xml && junitparser merge tests/tests_log6.xml tests/tests_log.xml && deactivate" ''' } } } + stage('QAIC Diffusion Models Tests') { + steps { + timeout(time: 120, unit: 'MINUTES') { + sh ''' + sudo docker exec ${BUILD_TAG} bash -c " + cd /efficient-transformers && + . 
preflight_qeff/bin/activate && + mkdir -p $PWD/Non_cli_qaic_diffusion && + export TOKENIZERS_PARALLELISM=false && + export QEFF_HOME=$PWD/Non_cli_qaic_diffusion && + export HF_HUB_CACHE=/huggingface_hub && + pytest tests -m '(not cli) and (on_qaic) and (diffusion_models) and (not qnn) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log_diffusion.xml && + junitparser merge tests/tests_log_diffusion.xml tests/tests_log.xml && + deactivate" + ''' + } + } + } stage('Inference Tests') { steps { timeout(time: 120, unit: 'MINUTES') { diff --git a/tests/diffusers/diffusers_utils.py b/tests/diffusers/diffusers_utils.py index 305116c03..4e407c5aa 100644 --- a/tests/diffusers/diffusers_utils.py +++ b/tests/diffusers/diffusers_utils.py @@ -69,8 +69,7 @@ def check_file_exists(file_path: str, file_type: str = "file") -> bool: bool: True if file exists """ exists = os.path.exists(file_path) - status = "āœ…" if exists else "āŒ" - print(f"{status} {file_type}: {file_path}") + print(f"file exist: {exists}; {file_type}: {file_path}") return exists @staticmethod @@ -161,7 +160,7 @@ def validate_module_mad( # Always report MAD value step_str = f" {step_info}" if step_info else "" - print(f"šŸ” {module_name.upper()} MAD{step_str}: {mad_value:.8f}") + print(f"{module_name.upper()} MAD{step_str}: {mad_value:.8f}") # Always validate - fail if tolerance exceeded tolerance = self.tolerances.get(module_name, 1e-2) diff --git a/tests/diffusers/flux_test_config.json b/tests/diffusers/flux_test_config.json index 7d0c17d55..9f13daca0 100644 --- a/tests/diffusers/flux_test_config.json +++ b/tests/diffusers/flux_test_config.json @@ -1,33 +1,33 @@ { "model_setup": { - "height": 256, - "width": 256, - "num_transformer_layers": 2, - "num_single_layers": 2, - "use_onnx_subfunctions": false - }, + "height": 256, + "width": 256, + "num_transformer_layers": 2, + "num_single_layers": 2, + "use_onnx_subfunctions": false + }, "mad_validation": { - "tolerances": { - "clip_text_encoder": 0.1, - "t5_text_encoder": 5.5, - "transformer": 2.0, - "vae_decoder": 1.0 - } + "tolerances": { + "clip_text_encoder": 0.1, + "t5_text_encoder": 5.5, + "transformer": 2.0, + "vae_decoder": 1.0 + } }, "pipeline_params": { - "test_prompt": "A cat holding a sign that says hello world", - "num_inference_steps": 2, - "guidance_scale": 0.0, - "max_sequence_length": 256, - "validate_gen_img": true, - "min_image_variance": 1.0, - "custom_config_path": null - }, + "test_prompt": "A cat holding a sign that says hello world", + "num_inference_steps": 2, + "guidance_scale": 0.0, + "max_sequence_length": 256, + "validate_gen_img": true, + "min_image_variance": 1.0, + "custom_config_path": null + }, "validation_checks": { - "image_generation": true, - "onnx_export": true, - "compilation": true - }, + "image_generation": true, + "onnx_export": true, + "compilation": true + }, "modules": { "text_encoder": diff --git a/tests/diffusers/test_flux.py b/tests/diffusers/test_flux.py index 6f4396a20..721850257 100644 --- a/tests/diffusers/test_flux.py +++ b/tests/diffusers/test_flux.py @@ -123,7 +123,7 @@ def flux_pipeline_call_with_mad_validation( pipeline.text_encoder_2.qpc_session.deactivate() # MAD Validation for Text Encoders - print("šŸ” Performing MAD validation for text encoders...") + print(" Performing MAD validation for text encoders...") mad_validator.validate_module_mad( clip_qaic_pooled_prompt_embeds, clip_torch_pooled_prompt_embeds, module_name="clip_text_encoder" ) @@ -405,7 +405,7 @@ def test_flux_pipeline(flux_pipeline): generated_image, 
expected_size, config["pipeline_params"]["min_image_variance"] ) - print("\nāœ… IMAGE VALIDATION PASSED") + print("\n IMAGE VALIDATION PASSED") print(f" - Size: {image_validation['size']}") print(f" - Mode: {image_validation['mode']}") print(f" - Variance: {image_validation['variance']:.2f}") @@ -421,7 +421,7 @@ def test_flux_pipeline(flux_pipeline): if config["validation_checks"]["onnx_export"]: # Check if ONNX files exist (basic check) - print("\nšŸ” ONNX Export Validation:") + print("\n ONNX Export Validation:") for module_name in ["text_encoder", "text_encoder_2", "transformer", "vae_decode"]: module_obj = getattr(pipeline, module_name, None) if module_obj and hasattr(module_obj, "onnx_path") and module_obj.onnx_path: @@ -429,7 +429,7 @@ def test_flux_pipeline(flux_pipeline): if config["validation_checks"]["compilation"]: # Check if QPC files exist (basic check) - print("\nšŸ” Compilation Validation:") + print("\n Compilation Validation:") for module_name in ["text_encoder", "text_encoder_2", "transformer", "vae_decode"]: module_obj = getattr(pipeline, module_name, None) if module_obj and hasattr(module_obj, "qpc_path") and module_obj.qpc_path: diff --git a/tests/diffusers/test_wan.py b/tests/diffusers/test_wan.py new file mode 100644 index 000000000..f11db826b --- /dev/null +++ b/tests/diffusers/test_wan.py @@ -0,0 +1,535 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +Test for wan pipeline +# TODO : 1. Add pytest for call method + 2. See if we reduce height and width + 3. Keep test for Sub fn as default once sdk supports +""" + +import time +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import pytest +import safetensors.torch +import torch +from diffusers import WanPipeline +from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers +from diffusers.utils import export_to_video +from huggingface_hub import hf_hub_download + +from QEfficient import QEffWanPipeline +from QEfficient.diffusers.pipelines.pipeline_utils import ( + ModulePerf, + QEffPipelineOutput, + calculate_latent_dimensions_with_frames, + set_module_device_ids, +) +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.utils import constants +from QEfficient.utils._utils import load_json +from tests.diffusers.diffusers_utils import DiffusersTestUtils, MADValidator + +# Test Configuration for 192x320 resolution with 1 layer +CONFIG_PATH = "tests/diffusers/wan_test_config.json" +INITIAL_TEST_CONFIG = load_json(CONFIG_PATH) + + +def wan_pipeline_call_with_mad_validation( + pipeline, + pytorch_pipeline, + height: int = 192, + width: int = 320, + num_frames: int = 81, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + num_inference_steps: int = 2, + guidance_scale: float = 1.0, + guidance_scale_2: Optional[float] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: 
Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + custom_config_path: Optional[str] = None, + use_onnx_subfunctions: bool = False, + parallel_compile: bool = True, + mad_tolerances: Dict[str, float] = None, +): + """ + Pipeline call function that replicates the exact flow of pipeline_wan.py.__call__() + while adding comprehensive MAD validation for transformer modules only. + + This function follows the EXACT same structure as QEffWanPipeline.__call__() + but adds MAD validation hooks for transformer testing. + """ + # Initialize MAD validator + mad_validator = MADValidator(tolerances=mad_tolerances) + + device = "cpu" + + # Step 1: Compile() (export and compile) + pipeline.cl, pipeline.latent_height, pipeline.latent_width, pipeline.latent_frames = ( + calculate_latent_dimensions_with_frames( + height, + width, + num_frames, + pipeline.model.vae.config.scale_factor_spatial, + pipeline.model.vae.config.scale_factor_temporal, + pipeline.patch_height, + pipeline.patch_width, + ) + ) + pipeline.compile( + compile_config=custom_config_path, + parallel=parallel_compile, + height=height, + width=width, + num_frames=num_frames, + use_onnx_subfunctions=use_onnx_subfunctions, + ) + + set_module_device_ids(pipeline) + + # Step 2: Check inputs + pipeline.model.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + guidance_scale_2, + ) + + if num_frames % pipeline.model.vae.config.scale_factor_temporal != 1: + num_frames = ( + num_frames + // pipeline.model.vae.config.scale_factor_temporal + * pipeline.model.vae.config.scale_factor_temporal + + 1 + ) + num_frames = max(num_frames, 1) + + if pipeline.model.config.boundary_ratio is not None and guidance_scale_2 is None: + guidance_scale_2 = guidance_scale + + pipeline._guidance_scale = guidance_scale + pipeline._guidance_scale_2 = guidance_scale_2 + pipeline._attention_kwargs = attention_kwargs + pipeline._current_timestep = None + pipeline._interrupt = False + + # Step 3: Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Step 4: Encode input prompt(using CPU text encoder for now) + prompt_embeds, negative_prompt_embeds = pipeline.model.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=pipeline.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + # Get PyTorch reference prompt embeddings + # For standard WAN pipeline, CFG is determined by presence of negative prompts + do_classifier_free_guidance = negative_prompt is not None + pytorch_prompt_embeds, pytorch_negative_prompt_embeds = pytorch_pipeline.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = pipeline.transformer.model.transformer_high.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + pytorch_prompt_embeds = 
pytorch_prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + pytorch_negative_prompt_embeds = pytorch_negative_prompt_embeds.to(transformer_dtype) + + # Step 5: Prepare timesteps + pipeline.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = pipeline.scheduler.timesteps + + # Step 6: Prepare latent variables + num_channels_latents = pipeline.transformer.model.config.in_channels + latents = pipeline.model.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + mask = torch.ones(latents.shape, dtype=torch.float32, device=device) + + # Step 7: Setup transformer inference session + if pipeline.transformer.qpc_session is None: + pipeline.transformer.qpc_session = QAICInferenceSession( + str(pipeline.transformer.qpc_path), device_ids=pipeline.transformer.device_ids + ) + + output_buffer = { + "output": np.random.rand( + batch_size, + pipeline.cl, + constants.WAN_DIT_OUT_CHANNELS, + ).astype(np.int32), + } + pipeline.transformer.qpc_session.set_buffers(output_buffer) + transformer_perf = [] + + # Step 8: Denoising loop with transformer MAD validation + if pipeline.model.config.boundary_ratio is not None: + boundary_timestep = pipeline.model.config.boundary_ratio * pipeline.scheduler.config.num_train_timesteps + else: + boundary_timestep = None + + num_warmup_steps = len(timesteps) - num_inference_steps * pipeline.scheduler.order + + with pipeline.model.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if pipeline._interrupt: + continue + + pipeline._current_timestep = t + + # Determine which transformer to use (high or low noise) + if boundary_timestep is None or t >= boundary_timestep: + # High-noise stage + current_model = pipeline.transformer.model.transformer_high + pytorch_current_model = pytorch_pipeline.transformer + model_type = torch.ones(1, dtype=torch.int64) + model_name = "transformer_high" + else: + # Low-noise stage + current_model = pipeline.transformer.model.transformer_low + pytorch_current_model = pytorch_pipeline.transformer_2 + model_type = torch.ones(2, dtype=torch.int64) + model_name = "transformer_low" + + latent_model_input = latents.to(transformer_dtype) + if pipeline.model.config.expand_timesteps: + temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten() + timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) + else: + timestep = t.expand(latents.shape[0]) + + batch_size, num_channels, num_frames, height, width = latents.shape + p_t, p_h, p_w = current_model.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + # Prepare transformer inputs + rotary_emb = current_model.rope(latent_model_input) + rotary_emb = torch.cat(rotary_emb, dim=0) + ts_seq_len = None + timestep = timestep.flatten() + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = current_model.condition_embedder( + timestep, prompt_embeds, encoder_hidden_states_image=None, timestep_seq_len=ts_seq_len + ) + + timestep_proj = timestep_proj.unflatten(1, (6, -1)) + + # Prepare inputs for QAIC inference + inputs_aic = { + "hidden_states": latents.detach().numpy(), + "encoder_hidden_states": encoder_hidden_states.detach().numpy(), + "rotary_emb": rotary_emb.detach().numpy(), + "temb": temb.detach().numpy(), + "timestep_proj": 
timestep_proj.detach().numpy(), + "tsp": model_type.detach().numpy(), + } + + # PyTorch reference inference (standard WAN transformer has different signature) + noise_pred_torch = pytorch_current_model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=pytorch_prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + # QAIC inference + with current_model.cache_context("cond"): + start_transformer_step_time = time.time() + outputs = pipeline.transformer.qpc_session.run(inputs_aic) + end_transformer_step_time = time.time() + transformer_perf.append(end_transformer_step_time - start_transformer_step_time) + + hidden_states = torch.tensor(outputs["output"]) + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + noise_pred = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + # Transformer MAD validation + print(f" Performing MAD validation for {model_name} at step {i}...") + mad_validator.validate_module_mad( + noise_pred_torch.detach().cpu().numpy(), + noise_pred.detach().cpu().numpy(), + model_name, + f"step {i} (t={t.item():.1f})", + ) + + # Update latents using scheduler + latents = pipeline.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # Update progress bar + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): + progress_bar.update() + + # Step 9: Decode latents to video (using CPU VAE for now) + if not output_type == "latent": + latents = latents.to(pipeline.vae_decode.dtype) + latents_mean = ( + torch.tensor(pipeline.vae_decode.config.latents_mean) + .view(1, pipeline.vae_decode.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(pipeline.vae_decode.config.latents_std).view( + 1, pipeline.vae_decode.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + + video = pipeline.model.vae.decode(latents, return_dict=False)[0] + + video = pipeline.model.video_processor.postprocess_video(video.detach()) + else: + video = latents + + # Build performance metrics + perf_metrics = [ + ModulePerf(module_name="transformer", perf=transformer_perf), + ] + + return QEffPipelineOutput( + pipeline_module=perf_metrics, + images=video, + ) + + +@pytest.fixture(scope="session") +def wan_pipeline(): + """Setup compiled WAN pipeline for testing with LoRA adapters and 2 layers total""" + config = INITIAL_TEST_CONFIG["model_setup"] + + def load_wan_lora(path: str): + return _convert_non_diffusers_wan_lora_to_diffusers(safetensors.torch.load_file(path)) + + # Download and load LoRA adapters + high_noise_lora_path = hf_hub_download( + repo_id="lightx2v/Wan2.2-Lightning", + filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/high_noise_model.safetensors", + ) + low_noise_lora_path = hf_hub_download( + repo_id="lightx2v/Wan2.2-Lightning", + filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/low_noise_model.safetensors", + ) + + # Load PyTorch reference pipeline + pytorch_pipeline = WanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers") + + # Load into the transformers + pytorch_pipeline.transformer.load_lora_adapter(load_wan_lora(high_noise_lora_path), adapter_name="high_noise") + pytorch_pipeline.transformer.set_adapters(["high_noise"], weights=[1.0]) + + 
pytorch_pipeline.transformer_2.load_lora_adapter(load_wan_lora(low_noise_lora_path), adapter_name="low_noise")
+    pytorch_pipeline.transformer_2.set_adapters(["low_noise"], weights=[1.0])
+
+    # Truncate the PyTorch reference transformers to the configured number of layers
+    pytorch_pipeline.transformer.config.num_layers = config["num_transformer_layers_high"]
+    pytorch_pipeline.transformer_2.config.num_layers = config["num_transformer_layers_low"]
+    pytorch_blocks_high = pytorch_pipeline.transformer.blocks
+    pytorch_blocks_low = pytorch_pipeline.transformer_2.blocks
+    pytorch_pipeline.transformer.blocks = torch.nn.ModuleList(
+        [pytorch_blocks_high[i] for i in range(0, pytorch_pipeline.transformer.config.num_layers)]
+    )
+    pytorch_pipeline.transformer_2.blocks = torch.nn.ModuleList(
+        [pytorch_blocks_low[i] for i in range(0, pytorch_pipeline.transformer_2.config.num_layers)]
+    )
+
+    # Load QEff WAN pipeline
+    pipeline = QEffWanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers")
+
+    # Load LoRA adapters into transformers
+    pipeline.transformer.model.transformer_high.load_lora_adapter(
+        load_wan_lora(high_noise_lora_path), adapter_name="high_noise"
+    )
+    pipeline.transformer.model.transformer_high.set_adapters(["high_noise"], weights=[1.0])
+    pipeline.transformer.model.transformer_low.load_lora_adapter(
+        load_wan_lora(low_noise_lora_path), adapter_name="low_noise"
+    )
+    pipeline.transformer.model.transformer_low.set_adapters(["low_noise"], weights=[1.0])
+
+    # Truncate the QEff transformers to the configured number of layers (1 high + 1 low) for testing
+    pipeline.transformer.model.transformer_high.config.num_layers = config["num_transformer_layers_high"]
+    pipeline.transformer.model.transformer_low.config.num_layers = config["num_transformer_layers_low"]
+
+    original_blocks_high = pipeline.transformer.model.transformer_high.blocks
+    original_blocks_low = pipeline.transformer.model.transformer_low.blocks
+
+    pipeline.transformer.model.transformer_high.blocks = torch.nn.ModuleList(
+        [original_blocks_high[i] for i in range(0, config["num_transformer_layers_high"])]
+    )
+    pipeline.transformer.model.transformer_low.blocks = torch.nn.ModuleList(
+        [original_blocks_low[i] for i in range(0, config["num_transformer_layers_low"])]
+    )
+
+    return pipeline, pytorch_pipeline
+
+
+@pytest.mark.diffusion_models
+@pytest.mark.on_qaic
+@pytest.mark.wan
+def test_wan_pipeline(wan_pipeline):
+    """
+    Comprehensive WAN pipeline test that focuses on transformer validation:
+    - 192x320 resolution
+    - 2 transformer layers total (1 high + 1 low)
+    - MAD validation for transformer modules only
+    - Functional video generation test
+    - Export/compilation checks for transformer
+    - Returns QEffPipelineOutput with performance metrics
+    """
+    pipeline, pytorch_pipeline = wan_pipeline
+    config = INITIAL_TEST_CONFIG
+
+    # Print test header
+    DiffusersTestUtils.print_test_header(
+        f"WAN PIPELINE TEST - {config['model_setup']['height']}x{config['model_setup']['width']} Resolution, {config['model_setup']['num_frames']} Frames, 2 Layers Total",
+        config,
+    )
+
+    # Test parameters
+    test_prompt = config["pipeline_params"]["test_prompt"]
+    num_inference_steps = config["pipeline_params"]["num_inference_steps"]
+    guidance_scale = config["pipeline_params"]["guidance_scale"]
+    guidance_scale_2 = config["pipeline_params"]["guidance_scale_2"]
+    max_sequence_length = config["pipeline_params"]["max_sequence_length"]
+    num_frames = config["model_setup"]["num_frames"]
+
+    # Generate with MAD validation
+    generator = torch.manual_seed(42)
+    start_time = time.time()
+
+    try:
+        # Run the pipeline with integrated MAD validation (focuses on transformer)
+        
result = wan_pipeline_call_with_mad_validation( + pipeline, + pytorch_pipeline, + height=config["model_setup"]["height"], + width=config["model_setup"]["width"], + num_frames=num_frames, + prompt=test_prompt, + guidance_scale=guidance_scale, + guidance_scale_2=guidance_scale_2, + num_inference_steps=num_inference_steps, + max_sequence_length=max_sequence_length, + custom_config_path=CONFIG_PATH, + generator=generator, + mad_tolerances=config["mad_validation"]["tolerances"], + parallel_compile=True, + return_dict=True, + ) + + execution_time = time.time() - start_time + + # Validate video generation + if config["pipeline_params"]["validate_gen_video"]: + assert result is not None, "Pipeline returned None" + assert hasattr(result, "images"), "Result missing 'images' attribute" + assert len(result.images) > 0, "No video frames generated" + + generated_video = result.images[0] + assert len(generated_video) == num_frames, f"Expected {num_frames} frames, got {len(generated_video)}" + + # Validate first frame properties + first_frame = generated_video[0] + expected_size = (config["model_setup"]["width"], config["model_setup"]["height"]) + + # Convert numpy array to PIL Image if needed for validation + if isinstance(first_frame, np.ndarray): + from PIL import Image + + if first_frame.dtype != np.uint8: + first_frame = (first_frame * 255).astype(np.uint8) + if len(first_frame.shape) == 3 and first_frame.shape[0] == 3: + first_frame = first_frame.transpose(1, 2, 0) + first_frame = Image.fromarray(first_frame) + + # Validate video frame properties + frame_validation = DiffusersTestUtils.validate_image_generation( + first_frame, expected_size, config["pipeline_params"]["min_video_variance"] + ) + + print("\n VIDEO VALIDATION PASSED") + print(f" - Frame count: {len(generated_video)}") + print(f" - Frame size: {frame_validation['size']}") + print(f" - Frame mode: {frame_validation['mode']}") + print(f" - Frame variance: {frame_validation['variance']:.2f}") + print(f" - Mean pixel value: {frame_validation['mean_pixel_value']:.2f}") + + # Save result as video + frames = result.images[0] + export_to_video(frames, "test_wan_output_t2v.mp4", fps=16) + print("\n VIDEO SAVED: test_wan_output_t2v.mp4") + print(result) + + if config["validation_checks"]["onnx_export"]: + # Check if transformer ONNX file exists + print("\n ONNX Export Validation:") + if hasattr(pipeline.transformer, "onnx_path") and pipeline.transformer.onnx_path: + DiffusersTestUtils.check_file_exists(str(pipeline.transformer.onnx_path), "transformer ONNX") + + if config["validation_checks"]["compilation"]: + # Check if transformer QPC file exists + print("\n Compilation Validation:") + if hasattr(pipeline.transformer, "qpc_path") and pipeline.transformer.qpc_path: + DiffusersTestUtils.check_file_exists(str(pipeline.transformer.qpc_path), "transformer QPC") + + # Print test summary + print(f"\nTotal execution time: {execution_time:.4f}s") + print(" WAN TRANSFORMER TEST COMPLETED SUCCESSFULLY") + + except Exception as e: + print(f"\nTEST FAILED: {e}") + raise + + +if __name__ == "__main__": + # This allows running the test file directly for debugging + pytest.main([__file__, "-v", "-s", "-m", "wan"]) +# pytest tests/diffusers/test_wan.py -m wan -v -s --tb=short diff --git a/tests/diffusers/wan_test_config.json b/tests/diffusers/wan_test_config.json new file mode 100644 index 000000000..1ed36294a --- /dev/null +++ b/tests/diffusers/wan_test_config.json @@ -0,0 +1,63 @@ +{ + "model_setup": { + "height": 192, + "width": 320, + "num_frames": 81, + 
"num_transformer_layers_high": 1, + "num_transformer_layers_low": 1, + "use_onnx_subfunctions": false + }, + "mad_validation": { + "tolerances": { + "transformer_high": 0.3, + "transformer_low": 0.2 + } + }, + "pipeline_params": { + "test_prompt": "A cat walking in a garden", + "num_inference_steps": 2, + "guidance_scale": 1.0, + "guidance_scale_2": 1.0, + "max_sequence_length": 512, + "validate_gen_video": true, + "min_video_variance": 1.0 + }, + "validation_checks": { + "video_generation": true, + "onnx_export": true, + "compilation": true + }, + "modules": { + "transformer": { + "specializations": [ + { + "batch_size": "1", + "num_channels": "16", + "steps": "1", + "sequence_length": "512", + "model_type": 1 + }, + { + "batch_size": "1", + "num_channels": "16", + "steps": "1", + "sequence_length": "512", + "model_type": 2 + } + ], + "compilation": { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": true, + "convert_to_fp16": true, + "aic_num_cores": 16, + "mos": 1, + "mdts_mos": 1 + }, + "execute": { + "device_ids": null + } + } + } +}