Skip to content

Commit aa88b63

Browse files
committed
Updated WAN and added pytest
Signed-off-by: vtirumal <[email protected]>
1 parent a036e97 commit aa88b63

File tree

15 files changed

+2662
-17
lines changed

15 files changed

+2662
-17
lines changed

QEfficient/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
)
3030
from QEfficient.compile.compile_helper import compile
3131
from QEfficient.diffusers.pipelines.flux.pipeline_flux import QEffFluxPipeline
32+
from QEfficient.diffusers.pipelines.wan.pipeline_wan import QEffWanPipeline
3233
from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
3334
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
3435
from QEfficient.peft import QEffAutoPeftModelForCausalLM
@@ -55,6 +56,7 @@
5556
"QEFFAutoModelForSpeechSeq2Seq",
5657
"QEFFCommonLoader",
5758
"QEffFluxPipeline",
59+
"QEffWanPipeline",
5860
]
5961

6062

QEfficient/base/modeling_qeff.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -526,18 +526,23 @@ class FeatureNotAvailableError(Exception):
526526
)
527527
try:
528528
subprocess.run(command, capture_output=True, check=True)
529+
# TODO: remove once compiler fix exit code (failing with error: Benchmark run failed, exit code 1)
529530
except subprocess.CalledProcessError as e:
530-
raise RuntimeError(
531-
"\n".join(
532-
[
533-
"Compilation failed!",
534-
f"Compiler command: {e.cmd}",
535-
f"Compiler exitcode: {e.returncode}",
536-
"Compiler stderr:",
537-
e.stderr.decode(),
538-
]
531+
# Check if exit code is 1 and programqpc.bin exists in qpc_path
532+
if e.returncode == 1 and qpc_path and (qpc_path / "programqpc.bin").is_file():
533+
logger.warning("Compiler exited with code 1, but programqpc.bin exists. Continuing...")
534+
else:
535+
raise RuntimeError(
536+
"\n".join(
537+
[
538+
"Compilation failed!",
539+
f"Compiler command: {e.cmd}",
540+
f"Compiler exitcode: {e.returncode}",
541+
"Compiler stderr:",
542+
e.stderr.decode(),
543+
]
544+
)
539545
)
540-
)
541546
# Dump JSON file with hashed parameters
542547
hashed_compile_params_path = compile_dir / "hashed_compile_params.json"
543548
create_json(hashed_compile_params_path, compile_hash_params)

QEfficient/diffusers/models/pytorch_transforms.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,17 @@
66
# -----------------------------------------------------------------------------
77

88
from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm
9+
from diffusers import AutoencoderKLWan
10+
911
from diffusers.models.transformers.transformer_flux import (
1012
FluxAttention,
1113
FluxAttnProcessor,
1214
FluxSingleTransformerBlock,
1315
FluxTransformer2DModel,
1416
FluxTransformerBlock,
1517
)
18+
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel, WanAttnProcessor, WanAttention
19+
1620
from torch import nn
1721

1822
from QEfficient.base.pytorch_transforms import ModuleMappingTransform
@@ -30,6 +34,12 @@
3034
QEffFluxTransformerBlock,
3135
)
3236

37+
from QEfficient.diffusers.models.transformers.transformer_wan import (
38+
QEffWanTransformer3DModel,
39+
QEffWanAttnProcessor,
40+
QEffWanAttention,
41+
)
42+
3343

3444
class CustomOpsTransform(ModuleMappingTransform):
3545
_module_mapping = {
@@ -45,6 +55,9 @@ class AttentionTransform(ModuleMappingTransform):
4555
FluxTransformer2DModel: QEffFluxTransformer2DModel,
4656
FluxAttention: QEffFluxAttention,
4757
FluxAttnProcessor: QEffFluxAttnProcessor,
58+
WanAttnProcessor: QEffWanAttnProcessor,
59+
WanAttention: QEffWanAttention,
60+
WanTransformer3DModel: QEffWanTransformer3DModel,
4861
}
4962

5063

0 commit comments

Comments
 (0)