diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py index 7f63b34ca..b507363c3 100644 --- a/QEfficient/__init__.py +++ b/QEfficient/__init__.py @@ -6,7 +6,17 @@ # ----------------------------------------------------------------------------- import os -import warnings + +# ----------------------------------------------------------------------------- # +# For faster downloads via hf_transfer +# This code is put above import statements as this needs to be executed before +# hf_transfer is imported (will happen on line 15 via leading imports) +os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" +# DO NOT ADD ANY CODE ABOVE THIS LINE +# Please contact maintainers if you must edit this file above this line. +# ----------------------------------------------------------------------------- # +# Placeholder for all non-transformer models registered in QEfficient +import warnings # noqa: I001 import QEfficient.utils.model_registery # noqa: F401 from QEfficient.base import ( @@ -18,6 +28,7 @@ QEFFCommonLoader, ) from QEfficient.compile.compile_helper import compile +from QEfficient.diffusers.pipelines.flux.pipeline_flux import QEffFluxPipeline from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv from QEfficient.peft import QEffAutoPeftModelForCausalLM @@ -25,6 +36,10 @@ from QEfficient.utils import custom_format_warning from QEfficient.utils.logging_utils import logger +# custom warning for the better logging experience +warnings.formatwarning = custom_format_warning + + # Users can use QEfficient.export for exporting models to ONNX export = qualcomm_efficient_converter __all__ = [ @@ -39,15 +54,9 @@ "QEFFAutoModelForImageTextToText", "QEFFAutoModelForSpeechSeq2Seq", "QEFFCommonLoader", + "QEffFluxPipeline", ] -# For faster downloads via hf_transfer -# This code is put above import statements as this needs to be executed before -# hf_transfer is imported (will happen on line 15 via leading imports) -os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" -# Placeholder for all non-transformer models registered in QEfficient -# custom warning for the better logging experience -warnings.formatwarning = custom_format_warning # Conditionally import QAIC-related modules if the SDK is installed __version__ = "0.0.1.dev0" diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index ef7e83adf..2c98a83f3 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -8,7 +8,6 @@ import gc import inspect import logging -import re import shutil import subprocess import warnings @@ -21,26 +20,21 @@ from QEfficient.base.onnx_transforms import ( BaseOnnxTransform, - CustomOpTransform, OnnxTransformPipeline, - RenameFunctionOutputsTransform, ) from QEfficient.base.pytorch_transforms import PytorchTransform from QEfficient.compile.qnn_compiler import compile as qnn_compile from QEfficient.generation.cloud_infer import QAICInferenceSession -from QEfficient.transformers.cache_utils import InvalidIndexProvider -from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export from QEfficient.utils import ( constants, create_json, create_model_params, dump_qconfig, - export_wrapper, generate_mdp_partition_config, hash_dict_params, load_json, ) -from QEfficient.utils.torch_patches import apply_torch_patches, undo_torch_patches +from QEfficient.utils.export_utils import export_wrapper logger = logging.getLogger(__name__) @@ -66,6 
+60,7 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None: super().__init__() self.model = model self.hash_params = create_model_params(self, **kwargs) + self.prefill_onnx_path: Optional[str] = None self.onnx_path: Optional[str] = None self.qpc_path: Optional[str] = None self.qpc_session: Optional[QAICInferenceSession] = None @@ -125,9 +120,35 @@ def _model_offloaded_check(self) -> None: logger.error(error_msg) raise RuntimeError(error_msg) + @property + def model_name(self) -> str: + """ + Get the model class name without QEff/QEFF prefix. + + This property extracts the underlying model's class name and removes + any QEff or QEFF prefix that may have been added during wrapping. + + Returns: + str: Model class name (e.g., "CLIPTextModel" instead of "QEffCLIPTextModel") + """ + mname = self.model.__class__.__name__ + if mname.startswith("QEff") or mname.startswith("QEFF"): + mname = mname[4:] + return mname + @property @abstractmethod - def model_name(self) -> str: ... + def get_model_config(self) -> Dict: + """ + Get the model configuration as a dictionary. + + This is an abstract property that must be implemented by all subclasses. + Typically returns: self.model.config.__dict__ + + Returns: + Dict: The configuration dictionary of the underlying model + """ + pass @abstractmethod def export(self, export_dir: Optional[str] = None) -> Path: @@ -184,11 +205,11 @@ def _export( example_inputs: Dict[str, torch.Tensor], output_names: List[str], dynamic_axes: Dict[str, Dict[int, str]], - export_kwargs: Optional[Dict[str, any]] = None, onnx_transform_kwargs: Optional[Dict[str, any]] = None, export_dir: Optional[str] = None, offload_pt_weights: bool = True, - use_onnx_subfunctions: bool = False, + prefill_only: Optional[bool] = False, + **export_kwargs, ) -> str: """ Export the PyTorch model to ONNX and apply ONNX transforms @@ -213,11 +234,16 @@ def _export( instance using from_pretrained() for re-export. """ + # TODO: Hack for retain_full_kv, handle this outside + export_kwargs.pop("retain_full_kv", None) onnx_path = export_dir / f"{self.model_name}.onnx" # Return early if ONNX already exists if onnx_path.is_file(): - self.onnx_path = onnx_path + if prefill_only: + self.prefill_onnx_path = onnx_path + else: + self.onnx_path = onnx_path return onnx_path # check if the model is in meta state or weights are offloaded @@ -253,19 +279,6 @@ def _export( input_names.append(param) try: - # Initialize the registry with your custom ops - export_kwargs = {} if export_kwargs is None else export_kwargs - if use_onnx_subfunctions: - warnings.warn( - "The subfunction feature is experimental. Please note that using compile consecutively with and without subfunction may produce inconsistent results." 
- ) - apply_torch_patches() - InvalidIndexProvider.SUBFUNC_ENABLED = True - output_names = [re.sub("_RetainedState", "_InternalRetainedState", s) for s in output_names] - export_kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(self.model) - self._onnx_transforms.append(RenameFunctionOutputsTransform) - self._onnx_transforms.append(CustomOpTransform) - torch.onnx.export( self.model, (example_inputs,), @@ -309,15 +322,42 @@ def _export( finally: shutil.rmtree(tmp_onnx_dir, ignore_errors=True) - if use_onnx_subfunctions: - undo_torch_patches() - InvalidIndexProvider.SUBFUNC_ENABLED = False - self._onnx_transforms.remove(CustomOpTransform) - self._onnx_transforms.remove(RenameFunctionOutputsTransform) - - self.onnx_path = onnx_path + if prefill_only: + self.prefill_onnx_path = onnx_path + else: + self.onnx_path = onnx_path return onnx_path + def get_onnx_path( + self, + prefill_only: Optional[bool] = False, + enable_chunking: Optional[bool] = False, + specializations: Optional[List[Dict[str, int]]] = None, + offload_pt_weights: Optional[bool] = True, + use_onnx_subfunctions: Optional[bool] = False, + retain_full_kv: Optional[bool] = False, + ): + kwargs = { + "offload_pt_weights": offload_pt_weights, + "use_onnx_subfunctions": use_onnx_subfunctions, + "retain_full_kv": retain_full_kv, + } + if prefill_only: + if self.prefill_onnx_path is None: + kwargs.update( + { + "prefill_only": prefill_only, + "prefill_seq_len": specializations[0].get("seq_len"), + "enable_chunking": enable_chunking, + } + ) + self.export(**kwargs) + return self.prefill_onnx_path + else: + if self.onnx_path is None: + self.export(**kwargs) + return self.onnx_path + @dump_qconfig def _compile( self, @@ -332,6 +372,10 @@ def _compile( enable_qnn: Optional[bool] = False, qnn_config: Optional[str] = None, use_onnx_subfunctions: bool = False, + prefill_only: Optional[str] = None, + offload_pt_weights: Optional[bool] = True, + enable_chunking: Optional[bool] = False, + retain_full_kv: Optional[bool] = None, **compiler_options, ) -> str: """ @@ -357,11 +401,18 @@ def _compile( For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored. """ - - if onnx_path is None and self.onnx_path is None: - self.export(use_onnx_subfunctions=use_onnx_subfunctions) - - onnx_path = Path(onnx_path or self.onnx_path) + onnx_path = Path( + onnx_path + if onnx_path + else self.get_onnx_path( + prefill_only, + enable_chunking, + specializations, + offload_pt_weights, + use_onnx_subfunctions, + retain_full_kv, + ) + ) compile_dir = Path(compile_dir or onnx_path.parent) qpc_path = compile_dir / "qpc" if not onnx_path.is_file(): @@ -423,6 +474,7 @@ def _compile( "mdp_ts_num_devices": mdp_ts_num_devices, "mdp_ts_json": mdp_ts_json, "num_speculative_tokens": num_speculative_tokens, + "prefill_only": prefill_only, } compile_hash = hash_dict_params(compile_hash_params) @@ -462,6 +514,16 @@ def _compile( command.append(f"-aic-binary-dir={qpc_path}") logger.info(f"Running compiler: {' '.join(command)}") + if use_onnx_subfunctions: + + class FeatureNotAvailableError(Exception): + pass + + exec_command = f'QAIC_COMPILER_OPTS_UNSUPPORTED="-loader-inline-all=0" {" ".join(command)}' + raise FeatureNotAvailableError( + "ONNX graph is exported with subfunctions, assert version of apps SDK should be used for compiling this model." 
+ + f"\nRun following command manually with assert compiler:\n{exec_command}" + ) try: subprocess.run(command, capture_output=True, check=True) except subprocess.CalledProcessError as e: @@ -482,5 +544,4 @@ def _compile( logger.info("Hashed parameters exported successfully.") self.qpc_path = qpc_path - return qpc_path diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py index 945850c50..16697cec9 100644 --- a/QEfficient/base/onnx_transforms.py +++ b/QEfficient/base/onnx_transforms.py @@ -19,16 +19,20 @@ from QEfficient.customop.ctx_scatter_gather import ( CtxGather, CtxGather3D, + CtxGatherBlockedKV, CtxGatherFunc, CtxGatherFunc3D, + CtxGatherFuncBlockedKV, CtxScatter, CtxScatter3D, CtxScatterFunc, CtxScatterFunc3D, ) from QEfficient.customop.ctx_scatter_gather_cb import ( + CtxGatherBlockedKVCB, CtxGatherCB, CtxGatherCB3D, + CtxGatherFuncBlockedKVCB, CtxGatherFuncCB, CtxGatherFuncCB3D, CtxScatterCB, @@ -91,10 +95,12 @@ class CustomOpTransform(BaseOnnxTransform): "CtxScatterFunc3D": (CtxScatterFunc3D, CtxScatter3D), "CtxGatherFunc": (CtxGatherFunc, CtxGather), "CtxGatherFunc3D": (CtxGatherFunc3D, CtxGather3D), - "CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB), "CtxScatterFuncCB3D": (CtxScatterFuncCB3D, CtxScatterCB3D), - "CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB), "CtxGatherFuncCB3D": (CtxGatherFuncCB3D, CtxGatherCB3D), + "CtxGatherFuncBlockedKV": (CtxGatherFuncBlockedKV, CtxGatherBlockedKV), + "CtxGatherFuncBlockedKVCB": (CtxGatherFuncBlockedKVCB, CtxGatherBlockedKVCB), + "CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB), + "CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB), } @classmethod diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py index c7dc8639a..7b15effe7 100644 --- a/QEfficient/customop/ctx_scatter_gather.py +++ b/QEfficient/customop/ctx_scatter_gather.py @@ -136,6 +136,7 @@ class CtxGatherFunc(torch.autograd.Function): def forward(data: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int): batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1) head_indices = torch.arange(data.shape[1]).view(1, -1, 1) + ctx_indices = torch.where(ctx_indices == torch.iinfo(torch.int32).max, 0, ctx_indices) return data[batch_indices, head_indices, ctx_indices] @staticmethod diff --git a/QEfficient/customop/ctx_scatter_gather_cb.py b/QEfficient/customop/ctx_scatter_gather_cb.py index 8a06bc2b1..c15b60810 100644 --- a/QEfficient/customop/ctx_scatter_gather_cb.py +++ b/QEfficient/customop/ctx_scatter_gather_cb.py @@ -126,6 +126,7 @@ class CtxGatherFuncCB(torch.autograd.Function): def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int): batch_indices = batch_index.view(-1, 1, 1) head_indices = torch.arange(data.shape[1]).view(1, -1, 1) + ctx_indices = torch.where(ctx_indices >= data.shape[2], 0, ctx_indices) return data[batch_indices, head_indices, ctx_indices] @staticmethod diff --git a/QEfficient/diffusers/README.md b/QEfficient/diffusers/README.md new file mode 100644 index 000000000..40d45e984 --- /dev/null +++ b/QEfficient/diffusers/README.md @@ -0,0 +1,95 @@ + +
+
+
+# **Diffusion Models on Qualcomm Cloud AI 100**
+
+
+
+### šŸŽØ **Experience the Future of AI Image Generation**
+
+*Optimized for Qualcomm Cloud AI 100*
+
+*(Sample output image)*
+
+**Generated with**: `black-forest-labs/FLUX.1-schnell` • `"A girl laughing"` • 4 steps • 0.0 guidance scale • ⚔
+
+
+ + + +[![Diffusers](https://img.shields.io/badge/Diffusers-0.35.1-orange.svg)](https://github.com/huggingface/diffusers) +
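+---
+
+## šŸš€ Quick Start
+
+A minimal usage sketch (the prompt and output filename are illustrative; assumes the QAIC Apps SDK and at least one Cloud AI 100 device are available — ONNX export and QPC compilation run automatically on the first call). See [`examples/diffusers/`](../../examples/diffusers/) for complete scripts:
+
+```python
+from QEfficient import QEffFluxPipeline
+
+# Load FLUX.1-schnell and wrap its components with QEfficient-optimized modules
+pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+
+# Generate a 512x512 image with the settings from the sample above (4 steps, no guidance)
+result = pipeline(
+    prompt="A girl laughing",
+    height=512,
+    width=512,
+    num_inference_steps=4,
+    guidance_scale=0.0,
+)
+result.images[0].save("girl_laughing.png")
+```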
+ +--- + +## ✨ Overview + +QEfficient Diffusers brings the power of state-of-the-art diffusion models to Qualcomm Cloud AI 100 hardware for text-to-image generation. Built on top of the popular HuggingFace Diffusers library, our optimized pipeline provides seamless inference on Qualcomm Cloud AI 100 hardware. + +## šŸ› ļø Installation + +### Prerequisites + +Ensure you have Python 3.8+ and the required dependencies: + +```bash +# Create Python virtual environment (Recommended Python 3.10) +sudo apt install python3.10-venv +python3.10 -m venv qeff_env +source qeff_env/bin/activate +pip install -U pip +``` + +### Install QEfficient + +```bash +# Install from GitHub (includes diffusers support) +pip install git+https://github.com/quic/efficient-transformers + +# Or build from source +git clone https://github.com/quic/efficient-transformers.git +cd efficient-transformers +pip install build wheel +python -m build --wheel --outdir dist +pip install dist/qefficient-0.0.1.dev0-py3-none-any.whl +``` + +--- + +## šŸŽÆ Supported Models +- āœ… [`black-forest-labs/FLUX.1-schnell`](https://huggingface.co/black-forest-labs/FLUX.1-schnell) + +--- + + +## šŸ“š Examples + +Check out our comprehensive examples in the [`examples/diffusers/`](../../examples/diffusers/) directory: + +--- + +## šŸ¤ Contributing + +We welcome contributions! Please see our [Contributing Guide](../../CONTRIBUTING.md) for details. + + + +--- + +## šŸ™ Acknowledgments + +- **HuggingFace Diffusers**: For the excellent foundation library +- **Stability AI**: For the amazing Stable Diffusion models +--- + +## šŸ“ž Support + +- šŸ“– **Documentation**: [https://quic.github.io/efficient-transformers/](https://quic.github.io/efficient-transformers/) +- šŸ› **Issues**: [GitHub Issues](https://github.com/quic/efficient-transformers/issues) + +--- + diff --git a/QEfficient/diffusers/__init__.py b/QEfficient/diffusers/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/diffusers/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/diffusers/models/__init__.py b/QEfficient/diffusers/models/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/diffusers/models/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/diffusers/models/normalization.py b/QEfficient/diffusers/models/normalization.py new file mode 100644 index 000000000..933832ed8 --- /dev/null +++ b/QEfficient/diffusers/models/normalization.py @@ -0,0 +1,40 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- +from typing import Optional, Tuple + +import torch +from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle + + +class QEffAdaLayerNormZero(AdaLayerNormZero): + def forward( + self, + x: torch.Tensor, + shift_msa: Optional[torch.Tensor] = None, + scale_msa: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x + + +class QEffAdaLayerNormZeroSingle(AdaLayerNormZeroSingle): + def forward( + self, + x: torch.Tensor, + scale_msa: Optional[torch.Tensor] = None, + shift_msa: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x + + +class QEffAdaLayerNormContinuous(AdaLayerNormContinuous): + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + emb = conditioning_embedding + scale, shift = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x diff --git a/QEfficient/diffusers/models/pytorch_transforms.py b/QEfficient/diffusers/models/pytorch_transforms.py new file mode 100644 index 000000000..d3c84ee63 --- /dev/null +++ b/QEfficient/diffusers/models/pytorch_transforms.py @@ -0,0 +1,56 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm +from diffusers.models.transformers.transformer_flux import ( + FluxAttention, + FluxAttnProcessor, + FluxSingleTransformerBlock, + FluxTransformer2DModel, + FluxTransformerBlock, +) +from torch import nn + +from QEfficient.base.pytorch_transforms import ModuleMappingTransform +from QEfficient.customop.rms_norm import CustomRMSNormAIC +from QEfficient.diffusers.models.normalization import ( + QEffAdaLayerNormContinuous, + QEffAdaLayerNormZero, + QEffAdaLayerNormZeroSingle, +) +from QEfficient.diffusers.models.transformers.transformer_flux import ( + QEffFluxAttention, + QEffFluxAttnProcessor, + QEffFluxSingleTransformerBlock, + QEffFluxTransformer2DModel, + QEffFluxTransformerBlock, +) + + +class CustomOpsTransform(ModuleMappingTransform): + _module_mapping = { + RMSNorm: CustomRMSNormAIC, + nn.RMSNorm: CustomRMSNormAIC, # for torch.nn.RMSNorm + } + + +class AttentionTransform(ModuleMappingTransform): + _module_mapping = { + FluxSingleTransformerBlock: QEffFluxSingleTransformerBlock, + FluxTransformerBlock: QEffFluxTransformerBlock, + FluxTransformer2DModel: QEffFluxTransformer2DModel, + FluxAttention: QEffFluxAttention, + FluxAttnProcessor: QEffFluxAttnProcessor, + } + + +class NormalizationTransform(ModuleMappingTransform): + _module_mapping = { + AdaLayerNormZero: QEffAdaLayerNormZero, + AdaLayerNormZeroSingle: QEffAdaLayerNormZeroSingle, + AdaLayerNormContinuous: QEffAdaLayerNormContinuous, + } diff --git a/QEfficient/diffusers/models/transformers/__init__.py b/QEfficient/diffusers/models/transformers/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ 
b/QEfficient/diffusers/models/transformers/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/diffusers/models/transformers/transformer_flux.py b/QEfficient/diffusers/models/transformers/transformer_flux.py new file mode 100644 index 000000000..5cb44af45 --- /dev/null +++ b/QEfficient/diffusers/models/transformers/transformer_flux.py @@ -0,0 +1,327 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.models.attention_dispatch import dispatch_attention_fn +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.transformers.transformer_flux import ( + FluxAttention, + FluxAttnProcessor, + FluxSingleTransformerBlock, + FluxTransformer2DModel, + FluxTransformerBlock, + _get_qkv_projections, +) + +from QEfficient.utils.logging_utils import logger + + +def qeff_apply_rotary_emb( + x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
+ """ + cos, sin = freqs_cis # [S, D] + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + cos, sin = cos.to(x.device), sin.to(x.device) + B, S, H, D = x.shape + x_real, x_imag = x.reshape(B, -1, H, D // 2, 2).unbind(-1) + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + return out + + +class QEffFluxAttnProcessor(FluxAttnProcessor): + _attention_backend = None + _parallel_config = None + + def __call__( + self, + attn: "QEffFluxAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if attn.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + query = qeff_apply_rotary_emb(query, image_rotary_emb) + key = qeff_apply_rotary_emb(key, image_rotary_emb) + + hidden_states = dispatch_attention_fn( + query, key, value, attn_mask=attention_mask, backend=self._attention_backend + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class QEffFluxAttention(FluxAttention): + def __qeff_init__(self): + processor = QEffFluxAttnProcessor() + self.processor = processor + + +class QEffFluxSingleTransformerBlock(FluxSingleTransformerBlock): + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + shift_msa, scale_msa, gate = torch.split(temb, 1) + residual = hidden_states + norm_hidden_states = self.norm(hidden_states, scale_msa, shift_msa) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) + hidden_states = gate * self.proj_out(hidden_states) + 
hidden_states = residual + hidden_states + # if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(torch.finfo(torch.float32).min, torch.finfo(torch.float32).max) + + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states + + +class QEffFluxTransformerBlock(FluxTransformerBlock): + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + temb1 = tuple(torch.split(temb[:6], 1)) + temb2 = tuple(torch.split(temb[6:], 1)) + norm_hidden_states = self.norm1(hidden_states, shift_msa=temb1[0], scale_msa=temb1[1]) + gate_msa, shift_mlp, scale_mlp, gate_mlp = temb1[-4:] + + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, shift_msa=temb2[0], scale_msa=temb2[1]) + + c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = temb2[-4:] + + joint_attention_kwargs = joint_attention_kwargs or {} + + # Attention. + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + + # Process attention outputs for the `encoder_hidden_states`. 
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + # if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class QEffFluxTransformer2DModel(FluxTransformer2DModel): + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + adaln_emb: torch.Tensor = None, + adaln_single_emb: torch.Tensor = None, + adaln_out: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples=None, + controlnet_single_block_samples=None, + return_dict: bool = True, + controlnet_blocks_repeat: bool = False, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + + hidden_states = self.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if txt_ids.ndim == 3: + logger.warning( + "Passing `txt_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + logger.warning( + "Passing `img_ids` 3d torch.Tensor is deprecated." 
+ "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) + + if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: + ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") + ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) + joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) + + for index_block, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=adaln_emb[index_block], + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # controlnet residual + if controlnet_block_samples is not None: + interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) + interval_control = int(np.ceil(interval_control)) + # For Xlabs ControlNet. + if controlnet_blocks_repeat: + hidden_states = ( + hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] + ) + else: + hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + + for index_block, block in enumerate(self.single_transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=adaln_single_emb[index_block], + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # controlnet residual + if controlnet_single_block_samples is not None: + interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) + interval_control = int(np.ceil(interval_control)) + hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control] + + hidden_states = self.norm_out(hidden_states, adaln_out) + output = self.proj_out(hidden_states) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/QEfficient/diffusers/pipelines/__init__.py b/QEfficient/diffusers/pipelines/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/diffusers/pipelines/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/diffusers/pipelines/configs/flux_config.json b/QEfficient/diffusers/pipelines/configs/flux_config.json new file mode 100644 index 000000000..73b92265f --- /dev/null +++ b/QEfficient/diffusers/pipelines/configs/flux_config.json @@ -0,0 +1,99 @@ +{ + "description": "Default configuration for Flux pipeline", + + "modules": + { + "text_encoder": + { + "specializations":{ + "batch_size": 1, + "seq_len": 77 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16, + "compile_only":true + }, + "execute": + { + "device_ids": null + } + + }, + "text_encoder_2": + { + "specializations": + { + "batch_size": 1, + "seq_len": 256 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16, + "compile_only": true + }, + "execute": + { + "device_ids": null + } + }, + "transformer": + { + "specializations": + { + "batch_size": 1, + "seq_len": 256, + "steps": 1 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 4, + "mxfp6_matmul": true, + "convert_to_fp16": true, + "aic_num_cores": 16, + "mos": 1, + "mdts-mos": 1, + "compile_only":true + }, + "execute": + { + "device_ids": null + } + }, + "vae_decoder": + { + "specializations": + { + "batch_size": 1, + "channels": 16 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16, + "aic-enable-depth-first": true, + "compile_only":true + }, + "execute": + { + "device_ids": null + } + } + } +} diff --git a/QEfficient/diffusers/pipelines/flux/__init__.py b/QEfficient/diffusers/pipelines/flux/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/diffusers/pipelines/flux/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/diffusers/pipelines/flux/pipeline_flux.py b/QEfficient/diffusers/pipelines/flux/pipeline_flux.py new file mode 100644 index 000000000..511746469 --- /dev/null +++ b/QEfficient/diffusers/pipelines/flux/pipeline_flux.py @@ -0,0 +1,854 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +# TODO: Pipeline Architecture Improvements +# 1. Introduce QEffDiffusionPipeline base class to provide unified export, compile, +# and inference APIs across all diffusion pipelines, promoting code reusability +# and consistent interface design. +# 2. Implement persistent QPC session management strategy to retain/drop compiled model +# sessions in memory across all pipeline modules. 
+ +import os +import time +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from diffusers import FluxPipeline +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps +from tqdm import tqdm + +from QEfficient.diffusers.pipelines.pipeline_module import ( + QEffFluxTransformerModel, + QEffTextEncoder, + QEffVAE, +) +from QEfficient.diffusers.pipelines.pipeline_utils import ( + ONNX_SUBFUNCTION_MODULE, + ModulePerf, + QEffPipelineOutput, + calculate_compressed_latent_dimension, + compile_modules_parallel, + compile_modules_sequential, + config_manager, + set_module_device_ids, +) +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.utils.logging_utils import logger + + +class QEffFluxPipeline: + """ + QEfficient-optimized Flux pipeline for high-performance text-to-image generation on Qualcomm AI hardware. + + This pipeline provides an optimized implementation of the Flux diffusion model specifically designed + for deployment on Qualcomm AI Cloud (QAIC) devices. It wraps the original HuggingFace Flux model + components with QEfficient-optimized versions that can be exported to ONNX format and compiled + into Qualcomm Program Container (QPC) files for efficient inference. + + The pipeline supports the complete Flux workflow including: + - Dual text encoding with CLIP and T5 encoders + - Transformer-based denoising with adaptive layer normalization + - VAE decoding for final image generation + - Performance monitoring and optimization + + Attributes: + text_encoder (QEffTextEncoder): Optimized CLIP text encoder for pooled embeddings + text_encoder_2 (QEffTextEncoder): Optimized T5 text encoder for sequence embeddings + transformer (QEffFluxTransformerModel): Optimized Flux transformer for denoising + vae_decode (QEffVAE): Optimized VAE decoder for latent-to-image conversion + modules (Dict[str, Any]): Dictionary of all pipeline modules for batch operations + model (FluxPipeline): Original HuggingFace Flux model reference + tokenizer: CLIP tokenizer for text preprocessing + scheduler: Diffusion scheduler for timestep management + + Example: + >>> from QEfficient.diffusers.pipelines.flux import QEffFluxPipeline + >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") + >>> images = pipeline( + ... prompt="A beautiful sunset over mountains", + ... height=512, + ... width=512, + ... num_inference_steps=28 + ... ) + >>> images.images[0].save("generated_image.png") + """ + + _hf_auto_class = FluxPipeline + + def __init__(self, model, *args, **kwargs): + """ + Initialize the QEfficient Flux pipeline. + + This pipeline provides an optimized implementation of the Flux text-to-image model + for deployment on Qualcomm AI hardware. It wraps the original HuggingFace Flux model + components with QEfficient-optimized versions that can be exported to ONNX and compiled + for QAIC devices. 
+ + Args: + model: Pre-loaded FluxPipeline model + **kwargs: Additional arguments including height and width + """ + + # Wrap model components with QEfficient optimized versions + self.model = model + self.text_encoder = QEffTextEncoder(model.text_encoder) + self.text_encoder_2 = QEffTextEncoder(model.text_encoder_2) + self.transformer = QEffFluxTransformerModel(model.transformer) + self.vae_decode = QEffVAE(model.vae, "decoder") + + # Store all modules in a dictionary for easy iteration during export/compile + self.modules = { + "text_encoder": self.text_encoder, + "text_encoder_2": self.text_encoder_2, + "transformer": self.transformer, + "vae_decoder": self.vae_decode, + } + + # Copy tokenizers and scheduler from the original model + self.tokenizer = model.tokenizer + self.text_encoder.tokenizer = model.tokenizer + self.text_encoder_2.tokenizer = model.tokenizer_2 + self.tokenizer_max_length = model.tokenizer_max_length + self.scheduler = model.scheduler + + # Override VAE forward method to use decode directly + self.vae_decode.model.forward = lambda latent_sample, return_dict: self.vae_decode.model.decode( + latent_sample, return_dict + ) + + # Sync max position embeddings between text encoders + self.text_encoder_2.model.config.max_position_embeddings = ( + self.text_encoder.model.config.max_position_embeddings + ) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + **kwargs, + ): + """ + Load a pretrained Flux model from HuggingFace Hub or local path and wrap it with QEfficient optimizations. + + This class method provides a convenient way to instantiate a QEffFluxPipeline from a pretrained + Flux model. It automatically loads the base FluxPipeline model in float32 precision on CPU + and wraps all components with QEfficient-optimized versions for QAIC deployment. + + Args: + pretrained_model_name_or_path (str or os.PathLike): Either a HuggingFace model identifier + (e.g., "black-forest-labs/FLUX.1-schnell") or a local path to a saved model directory. + **kwargs: Additional keyword arguments passed to FluxPipeline.from_pretrained(). + + Returns: + QEffFluxPipeline: A fully initialized pipeline instance with QEfficient-optimized components + ready for export, compilation, and inference on QAIC devices. + + Raises: + ValueError: If the model path is invalid or model cannot be loaded + OSError: If there are issues accessing the model files + RuntimeError: If model initialization fails + + Example: + >>> # Load from HuggingFace Hub + >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") + >>> + >>> # Load from local path + >>> pipeline = QEffFluxPipeline.from_pretrained("/path/to/local/flux/model") + >>> + >>> # Load with custom cache directory + >>> pipeline = QEffFluxPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-dev", + ... cache_dir="/custom/cache/dir" + ... ) + """ + # Load the base Flux model in float32 on CPU + model = cls._hf_auto_class.from_pretrained( + pretrained_model_name_or_path, + torch_dtype=torch.float32, + device_map="cpu", + **kwargs, + ) + + return cls( + model=model, + pretrained_model_name_or_path=pretrained_model_name_or_path, + **kwargs, + ) + + def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: + """ + Export all pipeline modules to ONNX format for deployment preparation. 
+ + This method systematically exports each pipeline component (CLIP text encoder, T5 text encoder, + Flux transformer, and VAE decoder) to ONNX format. Each module is exported with its specific + configuration including dynamic axes, input/output specifications, and optimization settings. + + The export process prepares the models for subsequent compilation to QPC format, enabling + efficient inference on QAIC hardware. ONNX subfunctions can be used for certain modules + to optimize memory usage and performance. + + Args: + export_dir (str, optional): Target directory for saving ONNX model files. If None, + uses the default export directory structure based on model name and configuration. + The directory will be created if it doesn't exist. + use_onnx_subfunctions (bool, default=False): Whether to enable ONNX subfunction + optimization for supported modules. This can optimize thegraph and + improve compilation efficiency for models like the transformer. + + Returns: + str: Absolute path to the export directory containing all ONNX model files. + Each module will have its own subdirectory with the exported ONNX file. + + Raises: + RuntimeError: If ONNX export fails for any module + OSError: If there are issues creating the export directory or writing files + ValueError: If module configurations are invalid + + Note: + - All models are exported in float32 precision for maximum compatibility + - Dynamic axes are configured to support variable batch sizes and sequence lengths + - The export process may take several minutes depending on model size + - Exported ONNX files can be large (several GB for complete pipeline) + + Example: + >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") + >>> export_path = pipeline.export( + ... export_dir="/path/to/export", + ... use_onnx_subfunctions=True + ... ) + >>> print(f"Models exported to: {export_path}") + """ + for module_name, module_obj in tqdm(self.modules.items(), desc="Exporting modules", unit="module"): + # Get ONNX export configuration for this module + example_inputs, dynamic_axes, output_names = module_obj.get_onnx_params() + + export_params = { + "inputs": example_inputs, + "output_names": output_names, + "dynamic_axes": dynamic_axes, + "export_dir": export_dir, + } + + if use_onnx_subfunctions and module_name in ONNX_SUBFUNCTION_MODULE: + export_params["use_onnx_subfunctions"] = True + + module_obj.export(**export_params) + + @staticmethod + def get_default_config_path() -> str: + """ + Get the absolute path to the default Flux pipeline configuration file. + + Returns: + str: Absolute path to the flux_config.json file containing default pipeline + configuration settings for compilation and device allocation. + """ + return "QEfficient/diffusers/pipelines/configs/flux_config.json" + + def compile( + self, + compile_config: Optional[str] = None, + parallel: bool = False, + height: int = 512, + width: int = 512, + use_onnx_subfunctions: bool = False, + ) -> None: + """ + Compile ONNX models into optimized QPC format for deployment on Qualcomm AI hardware. + + Args: + compile_config (str, optional): Path to a JSON configuration file containing + compilation settings, device mappings, and optimization parameters. If None, + uses the default configuration from get_default_config_path(). 
+ parallel (bool, default=False): Compilation mode selection: + - True: Compile modules in parallel using ThreadPoolExecutor for faster processing + - False: Compile modules sequentially for lower resource usage + height (int, default=512): Target image height in pixels. + width (int, default=512): Target image width in pixels. + use_onnx_subfunctions (bool, default=False): Whether to export models with ONNX + subfunctions before compilation. + + Raises: + RuntimeError: If compilation fails for any module or if QAIC compiler is not available + FileNotFoundError: If ONNX models haven't been exported or config file is missing + ValueError: If configuration parameters are invalid + OSError: If there are issues with file I/O during compilation + + Example: + >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") + >>> # Sequential compilation with default config + >>> pipeline.compile(height=1024, width=1024) + >>> + >>> # Parallel compilation with custom config + >>> pipeline.compile( + ... compile_config="/path/to/custom_config.json", + ... parallel=True, + ... height=512, + ... width=512 + ... ) + """ + # Ensure all modules are exported to ONNX before compilation + if any( + path is None + for path in [ + self.text_encoder.onnx_path, + self.text_encoder_2.onnx_path, + self.transformer.onnx_path, + self.vae_decode.onnx_path, + ] + ): + self.export(use_onnx_subfunctions=use_onnx_subfunctions) + + # Load compilation configuration + config_manager(self, config_source=compile_config) + + # Calculate compressed latent dimension using utility function + cl, latent_height, latent_width = calculate_compressed_latent_dimension( + height, width, self.model.vae_scale_factor + ) + + # Prepare dynamic specialization updates based on image dimensions + specialization_updates = { + "transformer": {"cl": cl}, + "vae_decoder": { + "latent_height": latent_height, + "latent_width": latent_width, + }, + } + + # Use generic utility functions for compilation + if parallel: + compile_modules_parallel(self.modules, self.custom_config, specialization_updates) + else: + compile_modules_sequential(self.modules, self.custom_config, specialization_updates) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device_ids: Optional[List[int]] = None, + ): + """ + Encode text prompts using the T5 text encoder for detailed semantic understanding. + + T5 provides rich sequence embeddings that capture fine-grained text details, + complementing CLIP's global representation in Flux's dual encoder setup. 
+ + Args: + prompt (str or List[str]): Input prompt(s) to encode + num_images_per_prompt (int): Number of images to generate per prompt + max_sequence_length (int): Maximum token sequence length (default: 512) + device_ids (List[int], optional): QAIC device IDs for inference + + Returns: + tuple: (prompt_embeds, inference_time) + - prompt_embeds (torch.Tensor): Encoded embeddings [batch*num_images, seq_len, 4096] + - inference_time (float): T5 encoder inference time in seconds + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + # Tokenize prompts with padding and truncation + text_inputs = self.text_encoder_2.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + # Check for truncation and warn user + untruncated_ids = self.text_encoder_2.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.text_encoder_2.tokenizer.batch_decode( + untruncated_ids[:, self.text_encoder_2.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + f"The following part of your input was truncated because `max_sequence_length` is set to " + f"{self.text_encoder_2.tokenizer.model_max_length} tokens: {removed_text}" + ) + + # Initialize QAIC inference session if not already created + if self.text_encoder_2.qpc_session is None: + self.text_encoder_2.qpc_session = QAICInferenceSession( + str(self.text_encoder_2.qpc_path), device_ids=device_ids + ) + + # Allocate output buffers for QAIC inference + text_encoder_2_output = { + "last_hidden_state": np.random.rand( + batch_size, max_sequence_length, self.text_encoder_2.model.config.d_model + ).astype(np.int32), + } + self.text_encoder_2.qpc_session.set_buffers(text_encoder_2_output) + + # Prepare input for QAIC inference + aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)} + + # Run T5 encoder inference and measure time + start_t5_time = time.perf_counter() + prompt_embeds = torch.tensor(self.text_encoder_2.qpc_session.run(aic_text_input)["last_hidden_state"]) + end_t5_time = time.perf_counter() + text_encoder_2_perf = end_t5_time - start_t5_time + + # Duplicate embeddings for multiple images per prompt + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, text_encoder_2_perf + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device_ids: Optional[List[int]] = None, + ): + """ + Encode text prompts using the CLIP text encoder for global semantic representation. + + CLIP provides pooled embeddings that capture high-level semantic meaning, + working alongside T5's detailed sequence embeddings in Flux's dual encoder setup. 
+ + Args: + prompt (str or List[str]): Input prompt(s) to encode + num_images_per_prompt (int): Number of images to generate per prompt + device_ids (List[int], optional): QAIC device IDs for inference + + Returns: + tuple: (pooled_prompt_embeds, inference_time) + - pooled_prompt_embeds (torch.Tensor): Pooled embeddings [batch*num_images, 768] + - inference_time (float): CLIP encoder inference time in seconds + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + # Tokenize prompts + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + + # Check for truncation and warn user + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + f"The following part of your input was truncated because CLIP can only handle sequences up to " + f"{self.tokenizer_max_length} tokens: {removed_text}" + ) + + # Initialize QAIC inference session if not already created + if self.text_encoder.qpc_session is None: + self.text_encoder.qpc_session = QAICInferenceSession(str(self.text_encoder.qpc_path), device_ids=device_ids) + + # Allocate output buffers for QAIC inference + text_encoder_output = { + "last_hidden_state": np.random.rand( + batch_size, self.tokenizer_max_length, self.text_encoder.model.config.hidden_size + ).astype(np.float32), + "pooler_output": np.random.rand(batch_size, self.text_encoder.model.config.hidden_size).astype(np.int32), + } + self.text_encoder.qpc_session.set_buffers(text_encoder_output) + + # Prepare input for QAIC inference + aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)} + + # Run CLIP encoder inference and measure time + start_text_encoder_time = time.perf_counter() + aic_embeddings = self.text_encoder.qpc_session.run(aic_text_input) + end_text_encoder_time = time.perf_counter() + text_encoder_perf = end_text_encoder_time - start_text_encoder_time + # Extract pooled output (used for conditioning in Flux) + prompt_embeds = torch.tensor(aic_embeddings["pooler_output"]) + + # Duplicate embeddings for multiple images per prompt + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, text_encoder_perf + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + ): + """ + Encode text prompts using Flux's dual text encoder architecture. + + Flux employs both CLIP and T5 encoders for comprehensive text understanding: + - CLIP provides pooled embeddings for global semantic conditioning + - T5 provides detailed sequence embeddings for fine-grained text control + + Args: + prompt (str or List[str]): Primary prompt(s) for both encoders + prompt_2 (str or List[str], optional): Secondary prompt(s) for T5. 
If None, uses primary prompt + num_images_per_prompt (int): Number of images to generate per prompt + prompt_embeds (torch.FloatTensor, optional): Pre-computed T5 embeddings + pooled_prompt_embeds (torch.FloatTensor, optional): Pre-computed CLIP pooled embeddings + max_sequence_length (int): Maximum sequence length for T5 tokenization + + Returns: + tuple: (prompt_embeds, pooled_prompt_embeds, text_ids, encoder_perf_times) + - prompt_embeds (torch.Tensor): T5 sequence embeddings [batch*num_images, seq_len, 4096] + - pooled_prompt_embeds (torch.Tensor): CLIP pooled embeddings [batch*num_images, 768] + - text_ids (torch.Tensor): Position IDs for text tokens [seq_len, 3] + - encoder_perf_times (List[float]): Performance times [CLIP_time, T5_time] + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + # Use primary prompt for both encoders if secondary not provided + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # Encode with CLIP (returns pooled embeddings) + pooled_prompt_embeds, text_encoder_perf = self._get_clip_prompt_embeds( + prompt=prompt, + device_ids=self.text_encoder.device_ids, + num_images_per_prompt=num_images_per_prompt, + ) + + # Encode with T5 (returns sequence embeddings) + prompt_embeds, text_encoder_2_perf = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device_ids=self.text_encoder_2.device_ids, + ) + + # Create text position IDs (required by Flux transformer) + text_ids = torch.zeros(prompt_embeds.shape[1], 3) + + return prompt_embeds, pooled_prompt_embeds, text_ids, [text_encoder_perf, text_encoder_2_perf] + + def __call__( + self, + height: int = 512, + width: int = 512, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + true_cfg_scale: float = 1.0, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + custom_config_path: Optional[str] = None, + parallel_compile: bool = False, + use_onnx_subfunctions: bool = False, + ): + """ + Generate images from text prompts using the QEfficient-optimized Flux pipeline on QAIC hardware. + + This is the main entry point for text-to-image generation. It orchestrates the complete Flux + diffusion pipeline optimized for Qualcomm AI Cloud devices. + + Args: + height (int, optional): Target image height in pixels. Must be divisible by 8. Default: 512. + width (int, optional): Target image width in pixels. Must be divisible by 8. Default: 512. + prompt (str or List[str]): Primary text prompt(s) describing the desired image(s). + Required unless `prompt_embeds` is provided. + prompt_2 (str or List[str], optional): Secondary prompt for T5 encoder. 
If None, uses `prompt`. + negative_prompt (str or List[str], optional): Negative prompt(s) describing what to avoid. + Only used when `true_cfg_scale > 1.0`. + negative_prompt_2 (str or List[str], optional): Secondary negative prompt for T5. If None, uses `negative_prompt`. + true_cfg_scale (float, optional): True classifier-free guidance scale. Values > 1.0 enable + negative prompting. Default: 1.0 (disabled). + num_inference_steps (int, optional): Number of denoising steps. Default: 28. + timesteps (List[int], optional): Custom timestep schedule. If provided, overrides `num_inference_steps`. + guidance_scale (float, optional): Guidance scale for classifier-free guidance. Default: 3.5. + num_images_per_prompt (int, optional): Number of images to generate per prompt. Default: 1. + generator (torch.Generator or List[torch.Generator], optional): Random generator for reproducibility. + latents (torch.FloatTensor, optional): Pre-generated latent tensors. If None, random latents are generated. + prompt_embeds (torch.FloatTensor, optional): Pre-computed T5 text embeddings. Shape: [batch, seq_len, 4096]. + pooled_prompt_embeds (torch.FloatTensor, optional): Pre-computed CLIP pooled embeddings. Shape: [batch, 768]. + negative_prompt_embeds (torch.FloatTensor, optional): Pre-computed negative T5 embeddings. + negative_pooled_prompt_embeds (torch.FloatTensor, optional): Pre-computed negative CLIP embeddings. + output_type (str, optional): Output format. Options: "pil" (default), "np", or "latent". + callback_on_step_end (Callable, optional): Callback function executed after each denoising step. + callback_on_step_end_tensor_inputs (List[str], optional): Tensor names to pass to callback. Default: ["latents"]. + max_sequence_length (int, optional): Maximum token sequence length for T5 encoder. Default: 512. + custom_config_path (str, optional): Path to custom JSON configuration file for compilation settings. + parallel_compile (bool, optional): Whether to compile modules in parallel. Default: False. + use_onnx_subfunctions (bool, optional): Whether to export transformer blocks as ONNX subfunctions. Default: False. + + Returns: + QEffPipelineOutput: A dataclass containing: + - images: Generated image(s) in the format specified by `output_type` + - pipeline_module: Performance metrics for each pipeline component (text encoders, transformer, VAE) + + Raises: + ValueError: If input validation fails or parameters are incompatible. + RuntimeError: If compilation fails or QAIC devices are unavailable. + FileNotFoundError: If custom config file is specified but not found. + + Example: + >>> from QEfficient.diffusers.pipelines.flux import QEffFluxPipeline + >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") + >>> result = pipeline( + ... prompt="A serene mountain landscape at sunset", + ... height=1024, + ... width=1024, + ... num_inference_steps=28, + ... guidance_scale=7.5 + ... ) + >>> result.images[0].save("mountain_sunset.png") + >>> print(f"Transformer inference time: {sum(result.pipeline_module[2].perf):.2f}s") + """ + device = self.model._execution_device + + if height is None or width is None: + logger.warning("Height or width is None. 
Setting default values of 512 for both dimensions.") + + self.compile( + compile_config=custom_config_path, + parallel=parallel_compile, + height=height, + width=width, + use_onnx_subfunctions=use_onnx_subfunctions, + ) + + # Set device IDs for all modules based on configuration + set_module_device_ids(self) + + # Validate all inputs + self.model.check_inputs( + prompt, + prompt_2, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._interrupt = False + + # Step 2: Determine batch size from inputs + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Step 3: Encode prompts with both text encoders + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + + (prompt_embeds, pooled_prompt_embeds, text_ids, text_encoder_perf) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # Encode negative prompts if using true classifier-free guidance + if do_true_cfg: + ( + negative_prompt_embeds, + negative_pooled_prompt_embeds, + negative_text_ids, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # Step 4: Prepare timesteps for denoising + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # Step 5: Prepare initial latents + num_channels_latents = self.transformer.model.config.in_channels // 4 + latents, latent_image_ids = self.model.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # Step 6: Calculate compressed latent dimension for transformer buffer allocation + cl, _, _ = calculate_compressed_latent_dimension(height, width, self.model.vae_scale_factor) + + # Initialize transformer inference session + if self.transformer.qpc_session is None: + self.transformer.qpc_session = QAICInferenceSession( + str(self.transformer.qpc_path), device_ids=self.transformer.device_ids + ) + + # Allocate output buffer for transformer + output_buffer = { + "output": np.random.rand(batch_size, cl, self.transformer.model.config.in_channels).astype(np.float32), + } + self.transformer.qpc_session.set_buffers(output_buffer) + + transformer_perf = [] + self.scheduler.set_begin_index(0) + + # Step 7: Denoising loop + with self.model.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Prepare timestep embedding + 
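A small sketch of the transformer output-buffer sizing used above, assuming a 512x512 image, a VAE scale factor of 8, and an illustrative in_channels of 64; the real values come from the model config.

import numpy as np

height, width, vae_scale_factor = 512, 512, 8
cl = (height // vae_scale_factor) * (width // vae_scale_factor) // 4   # 64 * 64 // 4 = 1024

batch_size, in_channels = 1, 64          # in_channels is model-dependent; 64 is an assumption
# Pre-allocated host buffer that the QAIC session writes the predicted noise into.
output_buffer = {"output": np.zeros((batch_size, cl, in_channels), dtype=np.float32)}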
timestep = t.expand(latents.shape[0]).to(latents.dtype) + temb = self.transformer.model.time_text_embed(timestep, pooled_prompt_embeds) + + # Compute AdaLN (Adaptive Layer Normalization) embeddings for dual transformer blocks + adaln_emb = [] + for block_idx in range(len(self.transformer.model.transformer_blocks)): + block = self.transformer.model.transformer_blocks[block_idx] + # Process through norm1 and norm1_context + f1 = block.norm1.linear(block.norm1.silu(temb)).chunk(6, dim=1) + f2 = block.norm1_context.linear(block.norm1_context.silu(temb)).chunk(6, dim=1) + adaln_emb.append(torch.cat(list(f1) + list(f2))) + adaln_dual_emb = torch.stack(adaln_emb) + + # Compute AdaLN embeddings for single transformer blocks + adaln_emb = [] + for block_idx in range(len(self.transformer.model.single_transformer_blocks)): + block = self.transformer.model.single_transformer_blocks[block_idx] + f1 = block.norm.linear(block.norm.silu(temb)).chunk(3, dim=1) + adaln_emb.append(torch.cat(list(f1))) + adaln_single_emb = torch.stack(adaln_emb) + + # Compute output AdaLN embedding + temp = self.transformer.model.norm_out + adaln_out = temp.linear(temp.silu(temb)) + + # Normalize timestep to [0, 1] range + timestep = timestep / 1000 + + # Prepare all inputs for transformer inference + inputs_aic = { + "hidden_states": latents.detach().numpy(), + "encoder_hidden_states": prompt_embeds.detach().numpy(), + "pooled_projections": pooled_prompt_embeds.detach().numpy(), + "timestep": timestep.detach().numpy(), + "img_ids": latent_image_ids.detach().numpy(), + "txt_ids": text_ids.detach().numpy(), + "adaln_emb": adaln_dual_emb.detach().numpy(), + "adaln_single_emb": adaln_single_emb.detach().numpy(), + "adaln_out": adaln_out.detach().numpy(), + } + + # Run transformer inference and measure time + start_transformer_step_time = time.perf_counter() + outputs = self.transformer.qpc_session.run(inputs_aic) + end_transformer_step_time = time.perf_counter() + transformer_perf.append(end_transformer_step_time - start_transformer_step_time) + + noise_pred = torch.from_numpy(outputs["output"]) + + # Update latents using scheduler (x_t -> x_t-1) + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # Handle dtype mismatch (workaround for MPS backend bug) + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + + # Execute callback if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # Update progress bar + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # Step 8: Decode latents to images (unless output_type is "latent") + if output_type == "latent": + image = latents + else: + # Unpack and denormalize latents + latents = self.model._unpack_latents(latents, height, width, self.model.vae_scale_factor) + latents = (latents / self.vae_decode.model.scaling_factor) + self.vae_decode.model.shift_factor + + # Initialize VAE decoder inference session + if self.vae_decode.qpc_session is None: + self.vae_decode.qpc_session = QAICInferenceSession( + str(self.vae_decode.qpc_path), device_ids=self.vae_decode.device_ids + ) + + # Allocate output 
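A toy sketch of the AdaLN precompute pattern above, linear(silu(temb)).chunk(...), using a plain nn.Linear/nn.SiLU pair as a stand-in for a block's norm1 projection; the hidden width is illustrative.

import torch
import torch.nn as nn

hidden = 64                                    # illustrative; the real model has its own inner width
silu, linear = nn.SiLU(), nn.Linear(hidden, 6 * hidden)   # stand-in for block.norm1.silu / .linear

temb = torch.randn(1, hidden)                  # combined time/text conditioning embedding
chunks = linear(silu(temb)).chunk(6, dim=1)    # six modulation tensors, each [1, hidden]
adaln_block = torch.cat(list(chunks))          # [6, hidden] for batch 1

# Stacking one such tensor per block yields `adaln_emb`, so the compiled
# transformer never has to run the AdaLN linear layers on device.
print(adaln_block.shape)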
buffer for VAE decoder + output_buffer = {"sample": np.random.rand(batch_size, 3, height, width).astype(np.int32)} + self.vae_decode.qpc_session.set_buffers(output_buffer) + + # Run VAE decoder inference and measure time + inputs = {"latent_sample": latents.numpy()} + start_decode_time = time.perf_counter() + image = self.vae_decode.qpc_session.run(inputs) + end_decode_time = time.perf_counter() + vae_decode_perf = end_decode_time - start_decode_time + + # Post-process image + image_tensor = torch.from_numpy(image["sample"]) + image = self.model.image_processor.postprocess(image_tensor, output_type=output_type) + + # Build performance metrics + perf_metrics = [ + ModulePerf(module_name="text_encoder", perf=text_encoder_perf[0]), + ModulePerf(module_name="text_encoder_2", perf=text_encoder_perf[1]), + ModulePerf(module_name="transformer", perf=transformer_perf), + ModulePerf(module_name="vae_decoder", perf=vae_decode_perf), + ] + + return QEffPipelineOutput( + pipeline_module=perf_metrics, + images=image, + ) diff --git a/QEfficient/diffusers/pipelines/pipeline_module.py b/QEfficient/diffusers/pipelines/pipeline_module.py new file mode 100644 index 000000000..6d9243fdc --- /dev/null +++ b/QEfficient/diffusers/pipelines/pipeline_module.py @@ -0,0 +1,481 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +from typing import Dict, List, Tuple + +import torch +import torch.nn as nn + +from QEfficient.base.modeling_qeff import QEFFBaseModel +from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform +from QEfficient.diffusers.models.pytorch_transforms import ( + AttentionTransform, + CustomOpsTransform, + NormalizationTransform, +) +from QEfficient.diffusers.models.transformers.transformer_flux import ( + QEffFluxSingleTransformerBlock, + QEffFluxTransformerBlock, +) +from QEfficient.transformers.models.pytorch_transforms import ( + T5ModelTransform, +) +from QEfficient.utils import constants + + +class QEffTextEncoder(QEFFBaseModel): + """ + Wrapper for text encoder models with ONNX export and QAIC compilation capabilities. + + This class handles text encoder models (CLIP, T5) with specific transformations and + optimizations for efficient inference on Qualcomm AI hardware. It applies custom + PyTorch and ONNX transformations to prepare models for deployment. + + Attributes: + model (nn.Module): The wrapped text encoder model (deep copy of original) + _pytorch_transforms (List): PyTorch transformations applied before ONNX export + _onnx_transforms (List): ONNX transformations applied after export + """ + + _pytorch_transforms = [CustomOpsTransform, T5ModelTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + + @property + def get_model_config(self) -> Dict: + """ + Get the model configuration as a dictionary. + + Returns: + Dict: The configuration dictionary of the underlying text encoder model + """ + return self.model.config.__dict__ + + def __init__(self, model: nn.Module) -> None: + """ + Initialize the text encoder wrapper. + + Args: + model (nn.Module): The text encoder model to wrap (CLIP or T5) + """ + super().__init__(model) + self.model = model + + def get_onnx_params(self) -> Tuple[Dict, Dict, List[str]]: + """ + Generate ONNX export configuration for the text encoder. 
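A small sketch of the latent denormalisation performed just before the VAE decoder runs; the scaling/shift values below are illustrative stand-ins for what the pipeline reads from the VAE config.

import torch

scaling_factor, shift_factor = 0.3611, 0.1159               # illustrative config values

latents = torch.randn(1, 16, 64, 64)                        # unpacked latents for a 512x512 image
latents = latents / scaling_factor + shift_factor           # undo encoder-side normalisation
# `latents` is now the "latent_sample" input expected by the compiled VAE decoder.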
+ + Creates example inputs, dynamic axes specifications, and output names + tailored to the specific text encoder type (CLIP vs T5). + + Returns: + Tuple containing: + - example_inputs (Dict): Sample inputs for ONNX export + - dynamic_axes (Dict): Specification of dynamic dimensions + - output_names (List[str]): Names of model outputs + """ + bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + + # Create example input with max sequence length + example_inputs = { + "input_ids": torch.zeros((bs, self.model.config.max_position_embeddings), dtype=torch.int64), + } + + # Define which dimensions can vary at runtime + dynamic_axes = {"input_ids": {0: "batch_size", 1: "seq_len"}} + + # T5 only outputs hidden states, CLIP outputs both hidden states and pooled output + if self.model.__class__.__name__ == "T5EncoderModel": + output_names = ["last_hidden_state"] + else: + output_names = ["last_hidden_state", "pooler_output"] + example_inputs["output_hidden_states"] = False + + return example_inputs, dynamic_axes, output_names + + def export( + self, + inputs: Dict, + output_names: List[str], + dynamic_axes: Dict, + export_dir: str = None, + export_kwargs: Dict = {}, + ) -> str: + """ + Export the text encoder model to ONNX format. + + Args: + inputs (Dict): Example inputs for ONNX export + output_names (List[str]): Names of model outputs + dynamic_axes (Dict): Specification of dynamic dimensions + export_dir (str, optional): Directory to save ONNX model + export_kwargs (Dict, optional): Additional export arguments + + Returns: + str: Path to the exported ONNX model + """ + return self._export( + example_inputs=inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + export_dir=export_dir, + **export_kwargs, + ) + + def compile(self, specializations: List[Dict], **compiler_options) -> None: + """ + Compile the ONNX model for Qualcomm AI hardware. + + Args: + specializations (List[Dict]): Model specialization configurations + **compiler_options: Additional compiler options (e.g., num_cores, aic_num_of_activations) + """ + self._compile(specializations=specializations, **compiler_options) + + +class QEffUNet(QEFFBaseModel): + """ + Wrapper for UNet models with ONNX export and QAIC compilation capabilities. + + This class handles UNet models with specific transformations and optimizations + for efficient inference on Qualcomm AI hardware. UNet is commonly used in + diffusion models for image generation tasks. + + Attributes: + model (nn.Module): The wrapped UNet model + _pytorch_transforms (List): PyTorch transformations applied before ONNX export + _onnx_transforms (List): ONNX transformations applied after export + """ + + _pytorch_transforms = [CustomOpsTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + + @property + def get_model_config(self) -> Dict: + """ + Get the model configuration as a dictionary. + + Returns: + Dict: The configuration dictionary of the underlying UNet model + """ + return self.model.config.__dict__ + + def __init__(self, model: nn.Module) -> None: + """ + Initialize the UNet wrapper. + + Args: + model (nn.Module): The pipeline model containing the UNet + """ + super().__init__(model.unet) + self.model = model.unet + + def export( + self, + inputs: Dict, + output_names: List[str], + dynamic_axes: Dict, + export_dir: str = None, + export_kwargs: Dict = {}, + ) -> str: + """ + Export the UNet model to ONNX format. 
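A self-contained sketch of how the example_inputs/dynamic_axes/output_names triple returned by get_onnx_params maps onto a plain torch.onnx.export call, using a toy embedding model rather than a real text encoder.

import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    def __init__(self, vocab=1000, hidden=64):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)

    def forward(self, input_ids):
        return self.embed(input_ids)

example_inputs = {"input_ids": torch.zeros((1, 77), dtype=torch.int64)}
dynamic_axes = {"input_ids": {0: "batch_size", 1: "seq_len"}}
output_names = ["last_hidden_state"]

torch.onnx.export(
    ToyEncoder(),
    (example_inputs["input_ids"],),
    "toy_encoder.onnx",
    input_names=["input_ids"],
    output_names=output_names,
    dynamic_axes=dynamic_axes,            # these dims stay symbolic in the exported graph
)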
+ + Args: + inputs (Dict): Example inputs for ONNX export + output_names (List[str]): Names of model outputs + dynamic_axes (Dict): Specification of dynamic dimensions + export_dir (str, optional): Directory to save ONNX model + export_kwargs (Dict, optional): Additional export arguments + + Returns: + str: Path to the exported ONNX model + """ + return self._export( + example_inputs=inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + export_dir=export_dir, + **export_kwargs, + ) + + def compile(self, specializations: List[Dict], **compiler_options) -> None: + """ + Compile the ONNX model for Qualcomm AI hardware. + + Args: + specializations (List[Dict]): Model specialization configurations + **compiler_options: Additional compiler options + """ + self._compile(specializations=specializations, **compiler_options) + + +class QEffVAE(QEFFBaseModel): + """ + Wrapper for Variational Autoencoder (VAE) models with ONNX export and QAIC compilation. + + This class handles VAE models with specific transformations and optimizations + for efficient inference on Qualcomm AI hardware. VAE models are used in diffusion + pipelines for encoding images to latent space and decoding latents back to images. + + Attributes: + model (nn.Module): The wrapped VAE model (deep copy of original) + type (str): VAE operation type ("encoder" or "decoder") + _pytorch_transforms (List): PyTorch transformations applied before ONNX export + _onnx_transforms (List): ONNX transformations applied after export + """ + + _pytorch_transforms = [CustomOpsTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + + @property + def get_model_config(self) -> Dict: + """ + Get the model configuration as a dictionary. + + Returns: + Dict: The configuration dictionary of the underlying VAE model + """ + return self.model.config.__dict__ + + def __init__(self, model: nn.Module, type: str) -> None: + """ + Initialize the VAE wrapper. + + Args: + model (nn.Module): The pipeline model containing the VAE + type (str): VAE operation type ("encoder" or "decoder") + """ + super().__init__(model) + self.model = model + + # To have different hashing for encoder/decoder + self.model.config["type"] = type + + def get_onnx_params(self, latent_height: int = 32, latent_width: int = 32) -> Tuple[Dict, Dict, List[str]]: + """ + Generate ONNX export configuration for the VAE decoder. + + Args: + latent_height (int): Height of latent representation (default: 32) + latent_width (int): Width of latent representation (default: 32) + + Returns: + Tuple containing: + - example_inputs (Dict): Sample inputs for ONNX export + - dynamic_axes (Dict): Specification of dynamic dimensions + - output_names (List[str]): Names of model outputs + """ + bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + + # VAE decoder takes latent representation as input + example_inputs = { + "latent_sample": torch.randn(bs, 16, latent_height, latent_width), + "return_dict": False, + } + + output_names = ["sample"] + + # All dimensions except channels can be dynamic + dynamic_axes = { + "latent_sample": {0: "batch_size", 1: "channels", 2: "latent_height", 3: "latent_width"}, + } + + return example_inputs, dynamic_axes, output_names + + def export( + self, + inputs: Dict, + output_names: List[str], + dynamic_axes: Dict, + export_dir: str = None, + export_kwargs: Dict = {}, + ) -> str: + """ + Export the VAE model to ONNX format. 
+ + Args: + inputs (Dict): Example inputs for ONNX export + output_names (List[str]): Names of model outputs + dynamic_axes (Dict): Specification of dynamic dimensions + export_dir (str, optional): Directory to save ONNX model + export_kwargs (Dict, optional): Additional export arguments + + Returns: + str: Path to the exported ONNX model + """ + return self._export( + example_inputs=inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + export_dir=export_dir, + **export_kwargs, + ) + + def compile(self, specializations: List[Dict], **compiler_options) -> None: + """ + Compile the ONNX model for Qualcomm AI hardware. + + Args: + specializations (List[Dict]): Model specialization configurations + **compiler_options: Additional compiler options + """ + self._compile(specializations=specializations, **compiler_options) + + +class QEffFluxTransformerModel(QEFFBaseModel): + """ + Wrapper for Flux Transformer2D models with ONNX export and QAIC compilation capabilities. + + This class handles Flux Transformer2D models with specific transformations and optimizations + for efficient inference on Qualcomm AI hardware. Flux uses a transformer-based diffusion + architecture instead of traditional UNet, with dual transformer blocks and adaptive layer + normalization (AdaLN) for conditioning. + + Attributes: + model (nn.Module): The wrapped Flux transformer model + _pytorch_transforms (List): PyTorch transformations applied before ONNX export + _onnx_transforms (List): ONNX transformations applied after export + """ + + _pytorch_transforms = [AttentionTransform, NormalizationTransform, CustomOpsTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + + @property + def get_model_config(self) -> Dict: + """ + Get the model configuration as a dictionary. + + Returns: + Dict: The configuration dictionary of the underlying Flux transformer model + """ + return self.model.config.__dict__ + + def __init__(self, model: nn.Module) -> None: + """ + Initialize the Flux transformer wrapper. + + Args: + model (nn.Module): The Flux transformer model to wrap + use_onnx_subfunctions (bool): Whether to export transformer blocks as ONNX functions + for better modularity and potential optimization + """ + super().__init__(model) + + def get_onnx_params( + self, + batch_size: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + seq_length: int = constants.FLUX_ONNX_EXPORT_SEQ_LENGTH, + cl: int = constants.FLUX_ONNX_EXPORT_COMPRESSED_LATENT_DIM, + ) -> Tuple[Dict, Dict, List[str]]: + """ + Generate ONNX export configuration for the Flux transformer. + + Creates example inputs for all Flux-specific inputs including hidden states, + text embeddings, timestep conditioning, and AdaLN embeddings. 
+ + Args: + batch_size (int): Batch size for example inputs (default: FLUX_ONNX_EXPORT_BATCH_SIZE) + seq_length (int): Text sequence length (default: FLUX_ONNX_EXPORT_SEQ_LENGTH) + cl (int): Compressed latent dimension (default: FLUX_ONNX_EXPORT_COMPRESSED_LATENT_DIM) + + Returns: + Tuple containing: + - example_inputs (Dict): Sample inputs for ONNX export + - dynamic_axes (Dict): Specification of dynamic dimensions + - output_names (List[str]): Names of model outputs + """ + example_inputs = { + # Latent representation of the image + "hidden_states": torch.randn(batch_size, cl, self.model.config.in_channels, dtype=torch.float32), + "encoder_hidden_states": torch.randn( + batch_size, seq_length, self.model.config.joint_attention_dim, dtype=torch.float32 + ), + "pooled_projections": torch.randn(batch_size, self.model.config.pooled_projection_dim, dtype=torch.float32), + "timestep": torch.tensor([1.0], dtype=torch.float32), + "img_ids": torch.randn(cl, 3, dtype=torch.float32), + "txt_ids": torch.randn(seq_length, 3, dtype=torch.float32), + # AdaLN embeddings for dual transformer blocks + # Shape: [num_layers, FLUX_ADALN_DUAL_BLOCK_CHUNKS, FLUX_ADALN_HIDDEN_DIM] + "adaln_emb": torch.randn( + self.model.config["num_layers"], + constants.FLUX_ADALN_DUAL_BLOCK_CHUNKS, + constants.FLUX_ADALN_HIDDEN_DIM, + dtype=torch.float32, + ), + # AdaLN embeddings for single transformer blocks + # Shape: [num_single_layers, FLUX_ADALN_SINGLE_BLOCK_CHUNKS, FLUX_ADALN_HIDDEN_DIM] + "adaln_single_emb": torch.randn( + self.model.config["num_single_layers"], + constants.FLUX_ADALN_SINGLE_BLOCK_CHUNKS, + constants.FLUX_ADALN_HIDDEN_DIM, + dtype=torch.float32, + ), + # Output AdaLN embedding + # Shape: [batch_size, FLUX_ADALN_OUTPUT_DIM] for final projection + "adaln_out": torch.randn(batch_size, constants.FLUX_ADALN_OUTPUT_DIM, dtype=torch.float32), + } + + output_names = ["output"] + + # Define dynamic dimensions for runtime flexibility + dynamic_axes = { + "hidden_states": {0: "batch_size", 1: "cl"}, + "encoder_hidden_states": {0: "batch_size", 1: "seq_len"}, + "pooled_projections": {0: "batch_size"}, + "timestep": {0: "steps"}, + "img_ids": {0: "cl"}, + } + + return example_inputs, dynamic_axes, output_names + + def export( + self, + inputs: Dict, + output_names: List[str], + dynamic_axes: Dict, + export_dir: str = None, + export_kwargs: Dict = {}, + use_onnx_subfunctions: bool = False, + ) -> str: + """ + Export the Flux transformer model to ONNX format. + + Args: + inputs (Dict): Example inputs for ONNX export + output_names (List[str]): Names of model outputs + dynamic_axes (Dict): Specification of dynamic dimensions + export_dir (str, optional): Directory to save ONNX model + export_kwargs (Dict, optional): Additional export arguments (e.g., export_modules_as_functions) + + Returns: + str: Path to the exported ONNX model + """ + + if use_onnx_subfunctions: + export_kwargs = {"export_modules_as_functions": {QEffFluxTransformerBlock, QEffFluxSingleTransformerBlock}} + + # Sort _use_default_values in config to ensure consistent hash generation during export + self.model.config["_use_default_values"].sort() + + return self._export( + example_inputs=inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + export_dir=export_dir, + offload_pt_weights=False, # As weights are needed with AdaLN changes + **export_kwargs, + ) + + def compile(self, specializations: List[Dict], **compiler_options) -> None: + """ + Compile the ONNX model for Qualcomm AI hardware. 
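A minimal, standalone illustration of the export_modules_as_functions mechanism used above for ONNX subfunctions, with a toy block class standing in for the Flux transformer blocks; opset and shapes are arbitrary.

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self, dim=32):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.proj(x))

class ToyModel(nn.Module):
    def __init__(self, dim=32, n_blocks=2):
        super().__init__()
        self.blocks = nn.ModuleList(ToyBlock(dim) for _ in range(n_blocks))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

torch.onnx.export(
    ToyModel(),
    (torch.randn(1, 32),),
    "toy_blocks.onnx",
    input_names=["x"],
    output_names=["y"],
    opset_version=17,                          # module-as-function export needs opset >= 15
    export_modules_as_functions={ToyBlock},    # each ToyBlock becomes one ONNX function
)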
+ + Args: + specializations (List[Dict]): Model specialization configurations + **compiler_options: Additional compiler options (e.g., num_cores, aic_num_of_activations) + """ + self._compile(specializations=specializations, **compiler_options) diff --git a/QEfficient/diffusers/pipelines/pipeline_utils.py b/QEfficient/diffusers/pipelines/pipeline_utils.py new file mode 100644 index 000000000..24eb36f53 --- /dev/null +++ b/QEfficient/diffusers/pipelines/pipeline_utils.py @@ -0,0 +1,218 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import os +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +from tqdm import tqdm + +from QEfficient.utils._utils import load_json +from QEfficient.utils.logging_utils import logger + + +def calculate_compressed_latent_dimension(height: int, width: int, vae_scale_factor: int) -> int: + """ + Calculate the compressed latent dimension. + Args: + height (int): Target image height in pixels + width (int): Target image width in pixels + vae_scale_factor (int): VAE downsampling factor (typically 8 for Flux) + + Returns: + int: Compressed latent dimension (cl) for transformer input buffer allocation + """ + latent_height = height // vae_scale_factor + latent_width = width // vae_scale_factor + # cl = compressed latent dimension (divided by 4 for Flux's 2x2 packing) + cl = (latent_height * latent_width) // 4 + return cl, latent_height, latent_width + + +def config_manager(cls, config_source: Optional[str] = None): + """ + JSON-based compilation configuration manager for diffusion pipelines. + + Supports loading configuration from JSON files only. Automatically detects + model type and handles model-specific requirements. + Initialize the configuration manager. + + Args: + config_source: Path to JSON configuration file. If None, uses default config. + """ + if config_source is None: + config_source = cls.get_default_config_path() + + if not isinstance(config_source, str): + raise ValueError("config_source must be a path to JSON configuration file") + + # Direct use of load_json utility - no wrapper needed + if not os.path.exists(config_source): + raise FileNotFoundError(f"Configuration file not found: {config_source}") + + cls.custom_config = load_json(config_source) + + +def set_module_device_ids(cls): + """ + Set device IDs for each module based on the custom configuration. + + Iterates through all modules in the pipeline and assigns device IDs + from the configuration file to each module's device_ids attribute. + """ + config_modules = cls.custom_config["modules"] + for module_name, module_obj in cls.modules.items(): + module_obj.device_ids = config_modules[module_name]["execute"]["device_ids"] + + +def compile_modules_parallel( + modules: Dict[str, Any], + config: Dict[str, Any], + specialization_updates: Dict[str, Dict[str, Any]] = None, +) -> None: + """ + Compile multiple pipeline modules in parallel using ThreadPoolExecutor. 
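A hypothetical sketch of the compile-config layout, inferred only from how set_module_device_ids indexes it; module names and device ids are placeholders, not a documented schema.

custom_config = {
    "modules": {
        "text_encoder":   {"execute": {"device_ids": [0]}},
        "text_encoder_2": {"execute": {"device_ids": [1]}},
        "transformer":    {"execute": {"device_ids": [2, 3]}},
        "vae_decoder":    {"execute": {"device_ids": [0]}},
    }
}
for name, cfg in custom_config["modules"].items():
    print(name, "->", cfg["execute"]["device_ids"])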
+ + Args: + modules: Dictionary of module_name -> module_object pairs to compile + config: Configuration dictionary containing module-specific compilation settings + specialization_updates: Optional dictionary of module_name -> specialization_updates + to apply dynamic values (e.g., image dimensions) + """ + + def _prepare_and_compile(module_name: str, module_obj: Any) -> None: + """Prepare specializations and compile a single module.""" + specializations = config["modules"][module_name]["specializations"].copy() + compile_kwargs = config["modules"][module_name]["compilation"] + + if specialization_updates and module_name in specialization_updates: + specializations.update(specialization_updates[module_name]) + + module_obj.compile(specializations=[specializations], **compile_kwargs) + + # Execute compilations in parallel + with ThreadPoolExecutor(max_workers=len(modules)) as executor: + futures = {executor.submit(_prepare_and_compile, name, obj): name for name, obj in modules.items()} + + with tqdm(total=len(futures), desc="Compiling modules", unit="module") as pbar: + for future in as_completed(futures): + try: + future.result() + except Exception as e: + logger.error(f"Compilation failed for {futures[future]}: {e}") + raise + pbar.update(1) + + +def compile_modules_sequential( + modules: Dict[str, Any], + config: Dict[str, Any], + specialization_updates: Dict[str, Dict[str, Any]] = None, +) -> None: + """ + Compile multiple pipeline modules sequentially. + + This function provides a generic way to compile diffusion pipeline modules + sequentially, which is the default behavior for backward compatibility. + + Args: + modules: Dictionary of module_name -> module_object pairs to compile + config: Configuration dictionary containing module-specific compilation settings + specialization_updates: Optional dictionary of module_name -> specialization_updates + to apply dynamic values (e.g., image dimensions) + + """ + for module_name, module_obj in tqdm(modules.items(), desc="Compiling modules", unit="module"): + module_config = config["modules"] + specializations = module_config[module_name]["specializations"].copy() + compile_kwargs = module_config[module_name]["compilation"] + + # Apply dynamic specialization updates if provided + if specialization_updates and module_name in specialization_updates: + specializations.update(specialization_updates[module_name]) + + # Compile the module to QPC format + module_obj.compile(specializations=[specializations], **compile_kwargs) + + +@dataclass(frozen=True) +class ModulePerf: + """ + Data class to store performance metrics for a pipeline module. + + Attributes: + module_name: Name of the pipeline module (e.g., 'text_encoder', 'transformer', 'vae_decoder') + perf: Performance metric in seconds. Can be a single float for modules that run once, + or a list of floats for modules that run multiple times (e.g., transformer steps) + """ + + module_name: str + perf: int + + +@dataclass(frozen=True) +class QEffPipelineOutput: + """ + Data class to store the output of a QEfficient diffusion pipeline. 
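A toy sketch of the sequential compile flow with a dummy module object; the "specializations"/"compilation" keys mirror the helper above, while the concrete values and the num_cores option are invented for illustration.

class DummyModule:
    def compile(self, specializations, **compiler_options):
        print("compiling with", specializations, compiler_options)

config = {
    "modules": {
        "vae_decoder": {
            "specializations": {"batch_size": 1, "latent_height": 64, "latent_width": 64},
            "compilation": {"num_cores": 16},
        }
    }
}
modules = {"vae_decoder": DummyModule()}

for name, obj in modules.items():
    spec = dict(config["modules"][name]["specializations"])
    spec.update({"latent_height": 128, "latent_width": 128})   # dynamic override, e.g. for 1024x1024
    obj.compile(specializations=[spec], **config["modules"][name]["compilation"])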
+ + Attributes: + pipeline_module: List of ModulePerf objects containing performance metrics for each module + images: Generated images as either a list of PIL Images or numpy array + """ + + pipeline_module: list[ModulePerf] + images: Union[List[PIL.Image.Image], np.ndarray] + + def __repr__(self): + output_str = "=" * 60 + "\n" + output_str += "QEfficient Diffusers Pipeline Inference Report\n" + output_str += "=" * 60 + "\n\n" + + # Module-wise inference times + output_str += "Module-wise Inference Times:\n" + output_str += "-" * 60 + "\n" + + # Calculate E2E time while iterating + e2e_time = 0 + for module_perf in self.pipeline_module: + module_name = module_perf.module_name + inference_time = module_perf.perf + + # Add to E2E time + e2e_time += sum(inference_time) if isinstance(inference_time, list) else inference_time + + # Format module name for display + display_name = module_name.replace("_", " ").title() + + # Handle transformer specially as it has a list of times + if isinstance(inference_time, list) and len(inference_time) > 0: + total_time = sum(inference_time) + avg_time = total_time / len(inference_time) + output_str += f" {display_name:25s} {total_time:.4f} s\n" + output_str += f" - Total steps: {len(inference_time)}\n" + output_str += f" - Average per step: {avg_time:.4f} s\n" + output_str += f" - Min step time: {min(inference_time):.4f} s\n" + output_str += f" - Max step time: {max(inference_time):.4f} s\n" + else: + # Single inference time value + output_str += f" {display_name:25s} {inference_time:.4f} s\n" + + output_str += "-" * 60 + "\n\n" + + # Print E2E time after all modules + output_str += f"End-to-End Inference Time: {e2e_time:.4f} s\n\n" + output_str += "=" * 60 + "\n" + + return output_str + + +# List of module name that require special handling during export +# when use_onnx_subfunctions is enabled +ONNX_SUBFUNCTION_MODULE = ["transformer"] diff --git a/QEfficient/peft/auto.py b/QEfficient/peft/auto.py index e69aebb2b..6c7173072 100644 --- a/QEfficient/peft/auto.py +++ b/QEfficient/peft/auto.py @@ -253,7 +253,7 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs): obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs) return obj - def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: + def export(self, export_dir: Optional[str] = None, **kwargs) -> str: """ Export the model with the active adapter to ONNX format. @@ -291,10 +291,10 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = example_inputs, output_names, dynamic_axes, - export_kwargs={"do_constant_folding": False}, # To avoid merging adapter weights with base weights + do_constant_folding=False, # To avoid merging adapter weights with base weights onnx_transform_kwargs={"adapter_name": self.model.active_adapter}, export_dir=export_dir, - use_onnx_subfunctions=use_onnx_subfunctions, + **kwargs, ) def compile( diff --git a/QEfficient/peft/lora/auto.py b/QEfficient/peft/lora/auto.py index 64fa3f61c..8ff8335f5 100644 --- a/QEfficient/peft/lora/auto.py +++ b/QEfficient/peft/lora/auto.py @@ -327,7 +327,7 @@ def _init_adapter_model(self): # load_weight to model self._load_adapter_weights_to_model() - def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: + def export(self, export_dir: Optional[str] = None, **kwargs) -> str: """ Export the model with all loaded adapters to ONNX format using ``torch.onnx.export``. 
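A small sketch of the end-to-end time accumulation used in __repr__ above, with made-up timings; per-step lists are summed, scalar timings are added directly.

pipeline_module = [
    ("text_encoder", 0.012),
    ("text_encoder_2", 0.085),
    ("transformer", [0.21, 0.20, 0.20, 0.19]),   # one entry per denoising step
    ("vae_decoder", 0.034),
]
e2e = sum(sum(p) if isinstance(p, list) else p for _, p in pipeline_module)
print(f"End-to-End Inference Time: {e2e:.4f} s")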
@@ -387,7 +387,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = output_names, dynamic_axes, export_dir=export_dir, - use_onnx_subfunctions=use_onnx_subfunctions, + **kwargs, ) def generate( diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 62cc71a4c..faadaba6b 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -46,6 +46,7 @@ def _get_invalid_idx_value(cls): """ if torch.onnx.is_in_onnx_export(): if cls.SUBFUNC_ENABLED: + # TODO: should not return 0 remove this if condition, it can hurt perf return 0 else: return torch.iinfo(torch.int32).max @@ -681,6 +682,37 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) return legacy_cache + def write_only( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + k_out, v_out = key_states, value_states + else: + position_ids = cache_kwargs.get("position_ids") + is_sliding_layer = cache_kwargs.get("is_sliding") + _, _, ctx_len, _ = self.key_cache[layer_idx].shape + if is_sliding_layer: + kv_position_ids = torch.arange(ctx_len, dtype=torch.int64).reshape(1, -1) + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply( + self.value_cache[layer_idx], kv_position_ids, value_states + ) + else: + kv_position_ids = position_ids + + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply( + self.value_cache[layer_idx], kv_position_ids, value_states + ) + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + return k_out, v_out + def update( self, key_states: torch.Tensor, @@ -747,3 +779,92 @@ def update( v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out + + def full_cache_update_chunked( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index") + invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value() + + # Scatter + if batch_index is not None: + if torch.onnx.is_in_onnx_export(): + scatter_position_ids = torch.where(position_ids < 0, torch.iinfo(torch.int32).max, position_ids) + self.key_cache[layer_idx] = CtxScatterFuncCB.apply( + self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states + ) + self.value_cache[layer_idx] = CtxScatterFuncCB.apply( + self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states + ) + else: + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states) + + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + + # Gather + ctx_len = cache_kwargs.get("CCL", k_out.shape[2]) + ctx_indices = torch.arange(ctx_len)[None, None, ...] 
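A simplified, pure-PyTorch stand-in for the position-id scatter that CtxScatterFunc performs above; it ignores batch_index and the ONNX-export special cases and only shows the write pattern.

import torch

def toy_ctx_scatter(cache, position_ids, new_states):
    # cache:        [B, H, CTX, D]  preallocated KV cache
    # position_ids: [B, S]          absolute positions of the incoming chunk
    # new_states:   [B, H, S, D]    keys or values for the current chunk
    for b in range(cache.shape[0]):
        cache[b, :, position_ids[b], :] = new_states[b]
    return cache

cache = torch.zeros(1, 2, 8, 4)                 # ctx_len = 8
chunk = torch.randn(1, 2, 3, 4)                 # 3 new tokens
pos = torch.tensor([[4, 5, 6]])                 # chunk starts at position 4
cache = toy_ctx_scatter(cache, pos, chunk)
assert torch.equal(cache[0, :, 4:7], chunk[0])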
+ gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + if batch_index is not None: + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len) + else: + k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len) + v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len) + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + + return k_out, v_out + + def sliding_window_update_chunked( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index") + invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value() + + if batch_index is not None: + if torch.onnx.is_in_onnx_export(): + scatter_position_ids = torch.where(position_ids < 0, torch.iinfo(torch.int32).max, position_ids) + self.key_cache[layer_idx] = CtxScatterFuncCB.apply( + self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states + ) + self.value_cache[layer_idx] = CtxScatterFuncCB.apply( + self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states + ) + else: + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states) + + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + sliding_window_len = cache_kwargs.get("sliding_window") + + # Gather + ctx_len = position_ids.shape[1] + sliding_window_len + ctx_indices = torch.arange(ctx_len)[None, None, ...] + first_pos_idx = position_ids[0][0] + add_idx = torch.where(first_pos_idx >= sliding_window_len, first_pos_idx - sliding_window_len, 0) + ctx_indices += add_idx + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + if batch_index is not None: + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len) + else: + k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len) + v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len) + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + + return k_out, v_out diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index 5337b44f5..47059d8dc 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -188,6 +188,9 @@ # This is for supporting different seq_len for different layers for Sliding window attn, chunked attn etc. DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"} +# This is for supporting different modelling classes specially written for prefill-only model +SPECIALIZED_PREFILL_ONLY_MODEL_ARCH = {"gpt_oss"} + # Define a transformers layers to QEff layers dictionary # While onboarding new models make sure to add the new layer maps to this dictionary. 
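A worked example of the sliding-window gather-index arithmetic above, with illustrative chunk and window sizes: the gathered window trails the current prefill chunk by at most sliding_window positions.

import torch

sliding_window, chunk_len = 128, 64
first_pos = torch.tensor(256)                  # first position_id of this prefill chunk

ctx_len = chunk_len + sliding_window           # 192 cache rows are gathered per chunk
ctx_indices = torch.arange(ctx_len)[None, None, :]
add_idx = torch.where(first_pos >= sliding_window, first_pos - sliding_window, torch.tensor(0))
ctx_indices = ctx_indices + add_idx            # window now covers positions 128..319

position_ids = torch.arange(256, 256 + chunk_len)[None, :]
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid = ctx_indices > gather_limit           # positions beyond the chunk are masked out
print(ctx_indices.min().item(), ctx_indices.max().item(), invalid.sum().item())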
TransformersToQEffModulesDict: Dict[Type[nn.Module], Type[nn.Module]] = { diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 84552aff4..3efe890b8 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -4,6 +4,8 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- +import math +import os from typing import Callable, Optional, Union import torch @@ -30,8 +32,8 @@ from QEfficient.transformers.cache_utils import QEffHybridCacheForGPTOSS from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask -from QEfficient.utils import constants from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE +from QEfficient.utils.logging_utils import logger class QEffGptOssExperts(GptOssExperts): @@ -42,8 +44,8 @@ def __qeff_init__(self): self.up_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) -class QEffGptOssMLP(GptOssMLP): - def alt_forward(self, hidden: torch.Tensor): +class QEffPrefillOnlyChunkedGptOssMLP(GptOssMLP): + def forward(self, hidden: torch.Tensor): B, S, H = hidden.shape T = B * S hidden = hidden.view(T, H) @@ -78,7 +80,62 @@ def alt_forward(self, hidden: torch.Tensor): up = (hidden @ W_u) + b_u # [T, I] # Apply GptOss activation with clamping - gate = gate.clamp(min=None, max=self.experts.limit) + gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + + # GLU activation + glu = gate * torch.sigmoid(gate * self.experts.alpha) + intermediate = (up + 1) * glu # [T, I] + + # Down projection + down_out = (intermediate @ W_d) + b_d # [T, H] + + # Apply routing weights and accumulate + expert_out += down_out * routing_weight + + # original shape [B, S, H] + return expert_out.view(B, S, H), router_logits + + +class QEffPrefillOnlyGptOssMLP(GptOssMLP): + def forward(self, hidden: torch.Tensor): + if os.environ.get("NUM_FFN_BLOCKS", None) is not None: + return self.blocked_ffn_forward(hidden) + B, S, H = hidden.shape + T = B * S + hidden = hidden.view(T, H) + + # Router computation + router_logits = F.linear(hidden, self.router.weight, self.router.bias) + + # Top-k selection + top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K] + top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) + + masked_logits = torch.zeros_like(router_logits) + masked_logits.scatter_(1, top_i, top_w) + + # Routing weights for each expert [T, E] + routing_weights = masked_logits + + # ────────────────── allocate the output tensor ───── + expert_out = hidden.new_zeros((T, H)) # accumulation buffer + + # ───────────────────────── Expert computation loop ───────────────────────────── + for e in range(self.experts.num_experts): + routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] + + W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I] + b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I] + W_d = self.experts.down_proj[e] # [I, H] + b_d = self.experts.down_proj_bias[e] # [H] + + # Gate and Up projections + gate = (hidden @ W_g) + b_g # [T, I] + up = (hidden @ W_u) + b_u # [T, I] + + # Apply GptOss activation with clamping + gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit) up = up.clamp(min=-self.experts.limit, 
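A standalone sketch of the top-k router masking used by the prefill-only MLP variants above; token and expert counts are illustrative.

import torch
import torch.nn.functional as F

T, E, K = 4, 8, 2                               # tokens, experts, top-k
router_logits = torch.randn(T, E)

top_w, top_i = torch.topk(router_logits, K, dim=-1)
top_w = F.softmax(top_w, dim=1, dtype=top_w.dtype)   # normalise only over the selected experts

routing_weights = torch.zeros_like(router_logits)
routing_weights.scatter_(1, top_i, top_w)       # dense [T, E] map, zero for unselected experts

assert torch.allclose(routing_weights.sum(dim=1), torch.ones(T))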
max=self.experts.limit) # GLU activation @@ -88,6 +145,165 @@ def alt_forward(self, hidden: torch.Tensor): # Down projection down_out = (intermediate @ W_d) + b_d # [T, H] + # Apply routing weights and accumulate + expert_out += down_out * routing_weight + + # original shape [B, S, H] + return expert_out.view(B, S, H), router_logits + + def blocked_ffn_forward(self, hidden: torch.Tensor): + B, S, H = hidden.shape + T = B * S + hidden = hidden.view(T, H) + + # Router computation + router_logits = F.linear(hidden, self.router.weight, self.router.bias) + + # Top-k selection + top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K] + top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) + + masked_logits = torch.zeros_like(router_logits) + masked_logits.scatter_(1, top_i, top_w) + + # Routing weights for each expert [T, E] + routing_weights = masked_logits + + # ────────────────── allocate the output tensor ───── + expert_out = hidden.new_zeros((T, H)) # accumulation buffer + target_blocks = int(os.environ.get("NUM_FFN_BLOCKS", 1)) + block_positions = [] + for j in range(target_blocks): + block_positions.append(j * (T // target_blocks)) + # ───────────────────────── Expert computation loop ───────────────────────────── + for e in range(self.experts.num_experts): + routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] + + W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I] + b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I] + W_d = self.experts.down_proj[e] # [I, H] + b_d = self.experts.down_proj_bias[e] # [H] + + block_count = 0 + outs = [] + for block_idx in range(target_blocks): + block_count += 1 + qi = block_positions[block_idx] + + # Calculate block size (last block should be handled with remainder) + if block_idx == target_blocks - 1: + real_q_len = T - qi + else: + real_q_len = block_positions[block_idx + 1] - qi + + tgb = hidden[qi : qi + real_q_len, :] + # Gate and Up projections + # Gate and Up projections + gate = (tgb @ W_g) + b_g # [T, I] + up = (tgb @ W_u) + b_u # [T, I] + + # Apply GptOss activation with clamping + gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + + # GLU activation + glu = gate * torch.sigmoid(gate * self.experts.alpha) + intermediate = (up + 1) * glu # [T, I] + + # Down projection + down_out_block = (intermediate @ W_d) + b_d # [T, H] + + outs.append(down_out_block) + + down_out = torch.cat(outs, dim=0) + + # Apply routing weights and accumulate + expert_out += down_out * routing_weight + + # original shape [B, S, H] + return expert_out.view(B, S, H), router_logits + + def blocked_ffn_forward_block_weights(self, hidden: torch.Tensor): + B, S, H = hidden.shape + T = B * S + hidden = hidden.view(T, H) + + # Router computation + router_logits = F.linear(hidden, self.router.weight, self.router.bias) + + # Top-k selection + top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K] + top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) + + masked_logits = torch.zeros_like(router_logits) + masked_logits.scatter_(1, top_i, top_w) + + # Routing weights for each expert [T, E] + routing_weights = masked_logits + + # ────────────────── allocate the output tensor ───── + expert_out = hidden.new_zeros((T, H)) # accumulation buffer + target_blocks = int(os.environ.get("NUM_BLOCKS", 1)) + block_positions = [] + for j in 
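A tiny worked example of the token-block layout used by blocked_ffn_forward above, showing how any remainder lands in the last block.

T, target_blocks = 10, 3
block_positions = [j * (T // target_blocks) for j in range(target_blocks)]   # [0, 3, 6]

sizes = []
for block_idx in range(target_blocks):
    qi = block_positions[block_idx]
    if block_idx == target_blocks - 1:
        real_q_len = T - qi                                    # last block absorbs the remainder
    else:
        real_q_len = block_positions[block_idx + 1] - qi
    sizes.append(real_q_len)

assert sizes == [3, 3, 4] and sum(sizes) == T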
range(target_blocks): + block_positions.append(j * (T // target_blocks)) + # ───────────────────────── Expert computation loop ───────────────────────────── + for e in range(self.experts.num_experts): + routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] + + W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I] + b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I] + W_d = self.experts.down_proj[e] # [I, H] + b_d = self.experts.down_proj_bias[e] # [H] + + block_count = 0 + outs = [] + for block_idx in range(target_blocks): + block_count += 1 + qi = block_positions[block_idx] + + # Calculate block size (last block should be handled with remainder) + if block_idx == target_blocks - 1: + real_q_len = T - qi + else: + real_q_len = block_positions[block_idx + 1] - qi + + tgb = hidden[qi : qi + real_q_len, :] + # Gate and Up projections + + wg_col_shape = W_g.shape[1] + wg_num_blocks = math.ceil(wg_col_shape / 128) + last_block_size = wg_col_shape % 128 if wg_col_shape % 128 != 0 else 128 + + intermediates = [] + for i in range(wg_num_blocks): + if i == wg_num_blocks - 1: + cur_gate = (tgb @ W_g[:, -last_block_size:]) + b_g[-last_block_size:] + cur_up = (tgb @ W_u[:, -last_block_size:]) + b_u[-last_block_size:] + else: + cur_gate = (tgb @ W_g[:, i * 128 : (i + 1) * 128]) + b_g[i * 128 : (i + 1) * 128] + cur_up = (tgb @ W_u[:, i * 128 : (i + 1) * 128]) + b_u[i * 128 : (i + 1) * 128] + + cur_gate = cur_gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit) + cur_up = cur_up.clamp(min=-self.experts.limit, max=self.experts.limit) + cur_glu = cur_gate * torch.sigmoid(cur_gate * self.experts.alpha) + cur_intermediate = (cur_up + 1) * cur_glu + intermediates.append(cur_intermediate) + + intermediate = torch.cat(intermediates, dim=-1) + + downs = [] + for i in range(wg_num_blocks): + if i == wg_num_blocks - 1: + downs.append((intermediate @ W_d[:, -last_block_size:]) + b_d[-last_block_size:]) + else: + downs.append((intermediate @ W_d[:, i * 128 : (i + 1) * 128]) + b_d[i * 128 : (i + 1) * 128]) + + down_out_block = torch.cat(downs, dim=1) + outs.append(down_out_block) + + down_out = torch.cat(outs, dim=0) + # Apply routing weights and accumulate masked_down = torch.where(routing_weight > 0, down_out * routing_weight, torch.zeros_like(expert_out)) expert_out += masked_down @@ -95,6 +311,8 @@ def alt_forward(self, hidden: torch.Tensor): # original shape [B, S, H] return expert_out.view(B, S, H), router_logits + +class QEffGptOssMLP(GptOssMLP): # ------------------- Gather based, weights as activation approach --------------- def forward_weights_as_activation(self, hidden_states): bs, seq_len, _ = hidden_states.shape @@ -142,7 +360,6 @@ def forward_weights_as_activation(self, hidden_states): # ------------------- Gather based, weights as activation approach, With Seperate Gate, up Projections --------------- def forward(self, hidden_states): - # print("Seperate Split, Up, Gate Projections") bs, seq_len, _ = hidden_states.shape hidden_states = hidden_states.view(bs * seq_len, self.experts.hidden_size) @@ -172,7 +389,7 @@ def forward(self, hidden_states): up = torch.bmm(expert_in, up_proj) + up_proj_bias.unsqueeze(1) # Apply activation with clamping - gate = gate.clamp(min=None, max=self.experts.limit) + gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit) up = up.clamp(min=-self.experts.limit, max=self.experts.limit) # GLU activation @@ -404,6 +621,283 @@ def eager_attention_forward( return attn_output, 
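A quick equivalence check for the 128-column weight blocking above: column-sliced matmuls concatenated along the feature dimension match the full projection.

import torch

torch.manual_seed(0)
X = torch.randn(5, 64)
W = torch.randn(64, 300)            # 300 columns -> blocks of 128, 128, 44
b = torch.randn(300)

block = 128
outs = []
for i in range(0, W.shape[1], block):
    outs.append(X @ W[:, i : i + block] + b[i : i + block])
blocked = torch.cat(outs, dim=-1)

assert torch.allclose(blocked, X @ W + b, atol=1e-5)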
attn_weights +def eager_attention_forward_blocked( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + BS, NH, CL, DH = query.shape + target_blocks = int(os.environ.get("NUM_Q_BLOCKS", 1)) + block_positions = [] + for j in range(target_blocks): + block_positions.append(j * (CL // target_blocks)) + block_count = 0 + + outs = [] + for block_idx in range(target_blocks): + block_count += 1 + qi = block_positions[block_idx] + + # Calculate block size (last block should be handled with remainder) + if block_idx == target_blocks - 1: + real_q_len = CL - qi + else: + real_q_len = block_positions[block_idx + 1] - qi + + q_block = query[:, :, qi : qi + real_q_len, :] + scores = torch.matmul(q_block, key_states.transpose(2, 3)) * scaling + attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, :] + curr_attn_weights = torch.where( + attn_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), scores + ) + sinks = module.sinks.reshape(1, -1, 1, 1).expand( + curr_attn_weights.shape[0], -1, curr_attn_weights.shape[-2], -1 + ) + combined_logits = torch.cat([curr_attn_weights, sinks], dim=-1) + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + curr_attn_weights = nn.functional.softmax(combined_logits, dim=-1, dtype=torch.float32) + curr_attn_weights = curr_attn_weights[..., :-1] + out_block = torch.matmul(curr_attn_weights, value_states) + outs.append(out_block) + output = torch.cat(outs, dim=2) + + output = output.view(BS, NH, CL, DH).transpose(1, 2).contiguous() + return output, output + + +def opt_eager_attention_forward_blocked( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + BS, NH, CL, DH = query.shape + target_blocks = int(os.environ.get("NUM_Q_BLOCKS", 1)) + block_positions = [] + for j in range(target_blocks): + block_positions.append(j * (CL // target_blocks)) + block_count = 0 + outs = [] + for block_idx in range(target_blocks): + block_count += 1 + qi = block_positions[block_idx] + # Calculate block size (last block should be handled with remainder) + + if block_idx == target_blocks - 1: + real_q_len = CL - qi + else: + real_q_len = block_positions[block_idx + 1] - qi + + if block_idx == 0: + kv_start_idx = 0 + else: + kv_start_idx = qi - 128 + + q_block = query[:, :, qi : qi + real_q_len, :] + if kwargs.get("sliding_window"): + k_block = key_states[:, :, kv_start_idx : qi + real_q_len, :] + v_block = value_states[:, :, kv_start_idx : qi + real_q_len, :] + attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, kv_start_idx : qi + real_q_len] + else: + k_block = key_states + v_block = value_states + attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, :] + + scores = torch.matmul(q_block, k_block.transpose(2, 3)) * scaling + curr_attn_weights = torch.where( + attn_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), scores + ) + sinks = module.sinks.reshape(1, -1, 1, 1).expand( + curr_attn_weights.shape[0], -1, curr_attn_weights.shape[-2], -1 + ) + combined_logits = torch.cat([curr_attn_weights, sinks], dim=-1) + 
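A minimal illustration of the sink-logit trick used in the blocked attention paths above: a sink column joins the softmax and is then dropped, so the attention weights over real keys can sum to less than one. Shapes and zero-valued sinks are illustrative.

import torch
import torch.nn.functional as F

scores = torch.randn(1, 2, 4, 16)            # [batch, heads, q_block, kv_len]
sinks = torch.zeros(1, 2, 4, 1)              # sink logits broadcast per head/query

combined = torch.cat([scores, sinks], dim=-1)
combined = combined - combined.max(dim=-1, keepdim=True).values    # numerical stability
weights = F.softmax(combined, dim=-1, dtype=torch.float32)[..., :-1]

print(weights.sum(dim=-1).mean())            # below 1.0: some mass went to the sink column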
combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + curr_attn_weights = nn.functional.softmax(combined_logits, dim=-1, dtype=torch.float32) + curr_attn_weights = curr_attn_weights[..., :-1] + out_block = torch.matmul(curr_attn_weights, v_block) + outs.append(out_block) + output = torch.cat(outs, dim=2) + + output = output.view(BS, NH, CL, DH).transpose(1, 2).contiguous() + return output, output + + +class QEffPrefillOnlyChunkedGptOssAttention(GptOssAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __qeff_init__(self): + self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + sliding_mask=None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + hidden_shape = (*input_shape, -1, self.head_dim) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): + max_seq_len_cached = 32 * 1024 + cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "config": self.config, + "is_sliding": self.sliding_window is not None, + "sliding_window": self.sliding_window, + } + if self.sliding_window is not None: + key_states, value_states = past_key_value.sliding_window_update_chunked( + key_states, value_states, self.layer_idx, cache_kwargs + ) + else: + key_states, value_states = past_key_value.full_cache_update_chunked( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + if self.sliding_window is not None: + attention_mask = sliding_mask + # positive_pos_ids = torch.where(position_ids<0, 0, position_ids) + ctx_len = position_ids.shape[1] + self.sliding_window + ctx_indices = torch.arange(ctx_len) + first_pos_idx = position_ids[0][0] + add_idx = torch.where(first_pos_idx >= self.sliding_window, first_pos_idx - self.sliding_window, 0) + # start_idx = torch.where(first_pos_idx>=self.sliding_window, first_pos_idx-self.sliding_window, 0) + # end_idx = torch.where(first_pos_idx >= self.sliding_window, first_pos_idx+position_ids.shape[1], position_ids.shape[1]+self.sliding_window) + ctx_indices += add_idx + attention_mask = attention_mask[:, :, :, ctx_indices] + else: + attention_mask = attention_mask + + attention_interface: Callable = eager_attention_forward + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + s_aux=self.sinks, # diff with Llama + **kwargs, + ) 
+ + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value + + +class QEffPrefillOnlyGptOssAttention(GptOssAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __qeff_init__(self): + self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + sliding_mask=None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + hidden_shape = (*input_shape, -1, self.head_dim) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): + max_seq_len_cached = 32 * 1024 + cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "config": self.config, + "is_sliding": self.sliding_window is not None, + "sliding_window": past_key_value.sliding_window_len, + } + if self.sliding_window is not None: + sliding_window_len = past_key_value.sliding_window_len + short_read_idx = torch.arange(past_key_value.key_cache[self.layer_idx].shape[2]) + read_idx = short_read_idx + torch.where( + position_ids.max() > sliding_window_len - 1, position_ids.max() - sliding_window_len + 1, 0 + ) + # This is a trick to export with seq_len position_ids.max(), 0, read_idx) + k_cache = key_states[:, :, read_idx, :] + v_cache = value_states[:, :, read_idx, :] + else: + k_cache, v_cache = key_states, value_states + _, _ = past_key_value.write_only(k_cache, v_cache, self.layer_idx, cache_kwargs) + + if self.sliding_window is not None: + attention_mask = sliding_mask + else: + attention_mask = attention_mask + + if os.environ.get("ENABLE_OPT_SWA", "0") == "1": + attention_interface: Callable = opt_eager_attention_forward_blocked + else: + attention_interface: Callable = eager_attention_forward_blocked + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + s_aux=self.sinks, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value + + class QEffGptOssAttention(GptOssAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -429,8 +923,9 @@ def forward( query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = 
self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) + if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): + max_seq_len_cached = 32 * 1024 + cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -511,7 +1006,6 @@ def forward( residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores - # alth, _ = self.mlp.alt_forward(hidden_states) hidden_states = hidden_states.reshape(residual.shape) hidden_states = residual + hidden_states outputs = (hidden_states,) @@ -525,6 +1019,97 @@ def forward( return outputs +class QEffPrefillOnlyGptOssModel(GptOssModel): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = QEffHybridCacheForGPTOSS.from_legacy_cache(self.config, past_key_values) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_key_values.max_cache_len) + sliding_mask = _create_causal_mask( + position_ids=position_ids, + target_length=past_key_values.max_cache_len, + sliding_window=self.config.sliding_window, + ) + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + 
batch_index=batch_index, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + sliding_mask=sliding_mask, + **kwargs, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + class QEffGptOssModel(GptOssModel): def forward( self, @@ -578,7 +1163,6 @@ def forward( ) hidden_states = inputs_embeds - # position_embeddings = self.rotary_emb(hidden_states, position_ids) # decoder layers all_hidden_states = () if output_hidden_states else None @@ -708,15 +1292,15 @@ def forward( router_logits=outputs.router_logits, ) - def get_pkv_dynamic_axes( - self, - ): + def get_pkv_dynamic_axes(self, retain_full_kv: Optional[bool] = False, continuous_batching: Optional[bool] = False): pkv_dynamic_axes = [] for layer_type in self.config.layer_types: - if layer_type == "sliding_attention": - pkv_dynamic_axes.append({0: "batch_size", 2: "sliding_window"}) - elif layer_type == "full_attention": - pkv_dynamic_axes.append({0: "batch_size", 2: "ctx_len"}) + if layer_type == "sliding_attention" and not retain_full_kv: + pkv_dynamic_axes.append( + {0: "full_batch_size" if continuous_batching else "batch_size", 2: "sliding_window"} + ) + else: + pkv_dynamic_axes.append({0: "full_batch_size" if continuous_batching else "batch_size", 2: "ctx_len"}) return pkv_dynamic_axes def get_specializations( @@ -724,10 +1308,14 @@ def get_specializations( batch_size: int, prefill_seq_len: int, ctx_len: int, + **kwargs, ): batch_size = batch_size if batch_size else 1 - prefill_seq_len = prefill_seq_len if prefill_seq_len else constants.PROMPT_LEN - ctx_len = ctx_len if ctx_len else constants.CTX_LEN + if kwargs.get("prefill_only") and not kwargs.get("enable_chunking") and ctx_len != prefill_seq_len: + ctx_len = prefill_seq_len + logger.warning( + f"overriding ctx_len={prefill_seq_len}, currently we don't support ctx_len different than prefill_seq_len for prefill_only model" + ) specializations = [ { diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 8edc1f3f0..008147c03 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -5,6 +5,7 @@ # # ---------------------------------------------------------------------------- +import os import warnings from pathlib import Path from time import perf_counter @@ -37,13 +38,20 @@ get_compilation_dims, ) from QEfficient.generation.vlm_generation import VisionLanguageGeneration -from QEfficient.transformers.modeling_utils import DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH +from QEfficient.transformers.modeling_utils import ( + DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH, + SPECIALIZED_PREFILL_ONLY_MODEL_ARCH, +) from QEfficient.transformers.models.pytorch_transforms import ( BlockedKVAttentionTransform, CustomOpsTransform, KVCacheExternalModuleMapperTransform, KVCacheTransform, PoolingTransform, + PrefillOnlyChunkedTransform, + PrefillOnlyTransform, + RevertPrefillKeepAttentionTransform, + RevertPrefillOnlyTransform, SamplerTransform, SpDTransform, VlmKVOffloadTransform, @@ -124,21 +132,6 @@ def from_pretrained(cls, 
pretrained_model_name_or_path: str, *args, **kwargs): model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path) - @property - def model_name(self) -> str: - """ - Get the name of the underlying HuggingFace model. - - Returns - ------- - str - The model's class name, with "QEff" or "QEFF" prefix removed if present. - """ - mname = self.model.__class__.__name__ - if mname.startswith("QEff") or mname.startswith("QEFF"): - mname = mname[4:] - return mname - class MultimodalUtilityMixin: """ @@ -316,7 +309,7 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: + def export(self, export_dir: Optional[str] = None, **kwargs) -> str: """ Export the model to ONNX format using ``torch.onnx.export``. @@ -353,7 +346,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = output_names, dynamic_axes, export_dir=export_dir, - use_onnx_subfunctions=use_onnx_subfunctions, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), ) def compile( @@ -603,15 +596,7 @@ def __init__(self, model: nn.modules, **kwargs): self.model = model.get_qeff_vision_encoder() self.hash_params["qeff_auto_class"] = self.__class__.__name__ - def export( - self, - inputs, - output_names, - dynamic_axes, - export_dir=None, - offload_pt_weights=True, - use_onnx_subfunctions: bool = False, - ): + def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True, **kwargs): """ Exports the vision encoder component to ONNX format. @@ -641,7 +626,7 @@ def export( dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights, - use_onnx_subfunctions=use_onnx_subfunctions, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), ) def compile( @@ -701,21 +686,6 @@ def compile( **compiler_options, ) - @property - def model_name(self) -> str: - """ - Get the name of the underlying vision encoder model. - - Returns - ------- - str - The model's class name, with "QEff" or "QEFF" prefix removed if present. - """ - mname = self.model.__class__.__name__ - if mname.startswith("QEff") or mname.startswith("QEFF"): - mname = mname[4:] - return mname - @property def get_model_config(self) -> dict: """ @@ -771,15 +741,7 @@ def __init__(self, model, qaic_config, **kwargs): if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None: BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks")) - def export( - self, - inputs, - output_names, - dynamic_axes, - export_dir=None, - offload_pt_weights=True, - use_onnx_subfunctions: bool = False, - ): + def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True, **kwargs): """ Exports the language decoder component to ONNX format. @@ -809,7 +771,7 @@ def export( dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights, - use_onnx_subfunctions=use_onnx_subfunctions, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), ) def compile( @@ -869,21 +831,6 @@ def compile( **compiler_options, ) - @property - def model_name(self) -> str: - """ - Get the name of the underlying language decoder model. - - Returns - ------- - str - The model's class name, with "QEff" or "QEFF" prefix removed if present. 
- """ - mname = self.model.__class__.__name__ - if mname.startswith("QEff") or mname.startswith("QEFF"): - mname = mname[4:] - return mname - @property def get_model_config(self) -> dict: """ @@ -946,21 +893,6 @@ def __init__( self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None self.input_shapes, self.output_names = None, None - @property - def model_name(self) -> str: - """ - Get the name of the underlying multimodal model. - - Returns - ------- - str - The model's class name, with "QEff" or "QEFF" prefix removed if present. - """ - mname = self.model.__class__.__name__ - if mname.startswith("QEff") or mname.startswith("QEFF"): - mname = mname[4:] - return mname - @classmethod def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Optional[dict] = None, **kwargs): """ @@ -2131,21 +2063,6 @@ def cloud_ai_100_generate( ), ) - @property - def model_name(self) -> str: - """ - Get the name of the underlying multimodal model. - - Returns - ------- - str - The model's class name, with "QEff" or "QEFF" prefix removed if present. - """ - mname = self.model.__class__.__name__ - if mname.startswith("QEff") or mname.startswith("QEFF"): - mname = mname[4:] - return mname - @property def get_model_config(self) -> dict: """ @@ -2359,11 +2276,30 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + def prefill( + self, + enable: Optional[bool] = True, + enable_chunking: Optional[bool] = False, + retain_full_kv: Optional[bool] = False, + ): + if enable: + if enable_chunking: + self.model, tf = PrefillOnlyChunkedTransform.apply(self.model) + else: + self.model, tf = PrefillOnlyTransform.apply(self.model) + + else: + if retain_full_kv: + self.model, tf = RevertPrefillKeepAttentionTransform.apply(self.model) + else: + self.model, tf = RevertPrefillOnlyTransform.apply(self.model) + def __init__( self, model: nn.Module, continuous_batching: bool = False, qaic_config: Optional[dict] = None, + max_seq_len_cached: Optional[int] = None, **kwargs, ): """ @@ -2411,6 +2347,7 @@ def __init__( # Set use_cache=True to get KV values as output during ONNX export model.config.use_cache = True + setattr(model.config, "max_seq_len_cached", max_seq_len_cached) super().__init__(model, qaic_config=qaic_config, **kwargs) self.num_layers = model.config.num_hidden_layers self.continuous_batching = continuous_batching @@ -2423,6 +2360,7 @@ def __init__( if qaic_config: self.ccl_enabled = qaic_config.get("ccl_enabled", False) self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None + self.hash_params["max_seq_len_cached"] = max_seq_len_cached # ---Sampling--- # Note: SamplerTransform should be applied after all other transforms @@ -2437,21 +2375,6 @@ def __init__( if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None: BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks")) - @property - def model_name(self) -> str: - """ - Get the name of the underlying Causal Language Model. - - Returns - ------- - str - The model's class name, with "QEff" or "QEFF" prefix removed if present. 
- """ - mname = self.model.__class__.__name__ - if mname.startswith("QEff") or mname.startswith("QEFF"): - mname = mname[4:] - return mname - def __repr__(self) -> str: return self.__class__.__name__ + "\n" + self.model.__repr__() @@ -2462,6 +2385,7 @@ def from_pretrained( pretrained_model_name_or_path, continuous_batching: bool = False, qaic_config: Optional[dict] = None, + max_seq_len_cached: Optional[int] = None, *args, **kwargs, ): @@ -2525,7 +2449,6 @@ def from_pretrained( qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path # This is support models that should be classified to in a different auto class but transformers load them via this class - if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP: return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__]( model, @@ -2540,6 +2463,7 @@ def from_pretrained( continuous_batching=continuous_batching, qaic_config=qaic_config, pretrained_model_name_or_path=pretrained_model_name_or_path, + max_seq_len_cached=max_seq_len_cached, **kwargs, ) @@ -2555,7 +2479,56 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False, **kwargs) -> str: + def get_seq_len_and_handle_specialized_prefill_model( + self, prefill_seq_len: Optional[int] = None, enable_chunking=False + ) -> int: + self.hash_params["prefill_only"] = True + if enable_chunking: + self.hash_params["chunking"] = True + return constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + + num_q_blocks = os.environ.get("NUM_Q_BLOCKS", None) + if num_q_blocks is None: + block_size = 128 + if prefill_seq_len is None or prefill_seq_len % block_size != 0 or prefill_seq_len < 128: + raise ValueError( + f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={block_size}. " + f"Or set `NUM_Q_BLOCKS` ENV variable" + f"Received: prefill_seq_len={prefill_seq_len}" + ) + + num_q_blocks = prefill_seq_len // block_size + logger.warning( + f"Setting NUM_Q_BLOCKS={num_q_blocks} used in attention Q-blocking for prefill_only model, please set ENV variable `NUM_Q_BLOCKS` to override" + ) + os.environ["NUM_Q_BLOCKS"] = str(num_q_blocks) + num_q_blocks = int(num_q_blocks) + + num_ffn_blocks = os.environ.get("NUM_FFN_BLOCKS", None) + num_ffn_blocks = int(num_ffn_blocks) if num_ffn_blocks else num_ffn_blocks + min_seq_len = max(num_q_blocks, num_ffn_blocks) if num_ffn_blocks else num_q_blocks + if (num_ffn_blocks and min_seq_len % num_ffn_blocks != 0) or min_seq_len % num_q_blocks != 0: + raise ValueError( + f"Got NUM_FFN_BLOCKS={num_ffn_blocks} and NUM_Q_BLOCKS={num_q_blocks}, tried to set seq_len={min_seq_len} for export but," + "seq_len is not divisible by either num_ffn_blocks or num_q_blocks, try chaning the values." + ) + + self.hash_params["NUM_Q_BLOCKS"] = num_q_blocks + self.hash_params["NUM_FFN_BLOCKS"] = num_ffn_blocks + self.hash_params["ENABLE_OPT_SWA"] = os.environ.get("ENABLE_OPT_SWA", "0") + return ( + min_seq_len + if min_seq_len > constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + else constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + ) + + def export( + self, + export_dir: Optional[str] = None, + prefill_only: Optional[bool] = False, + prefill_seq_len: Optional[int] = None, + **kwargs, + ) -> str: """ Export the model to ONNX format using ``torch.onnx.export``. 
@@ -2581,6 +2554,33 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = kv_cache_shape = get_padding_shape_from_config( self.model.config, fbs if self.continuous_batching else bs, seq_len ) + enable_chunking = kwargs.get("enable_chunking", False) + if prefill_only: + if not enable_chunking and self.continuous_batching: + raise NotImplementedError( + "Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!" + ) + self.prefill(enable=True, enable_chunking=enable_chunking) + self.hash_params.pop("retain_full_kv", None) + seq_len = ( + self.get_seq_len_and_handle_specialized_prefill_model( + prefill_seq_len=prefill_seq_len, enable_chunking=enable_chunking + ) + if self.model.config.model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH + else seq_len + ) + kv_cache_shape[2] = seq_len + self.model.config.sliding_window if enable_chunking else seq_len + else: + self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False)) + self.hash_params.pop("prefill_only", None) + self.hash_params.pop("NUM_Q_BLOCKS", None) + self.hash_params.pop("NUM_FFN_BLOCKS", None) + self.hash_params.pop("ENABLE_OPT_SWA", None) + self.hash_params.pop("chunking", None) + if kwargs.get("retain_full_kv", False): + kv_cache_shape[2] = seq_len + self.model.config.sliding_window + self.hash_params["retain_full_kv"] = True + example_inputs = { "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), "position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1), @@ -2629,7 +2629,13 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = else: # HACK: create common function for this including above if condition code pkv_dynamic_axes = ( - self.model.get_pkv_dynamic_axes() if hasattr(self.model, "get_pkv_dynamic_axes") else pkv_dynamic_axes + self.model.get_pkv_dynamic_axes( + retain_full_kv=kwargs.get("retain_full_kv", False) + or (prefill_only and kwargs.get("enable_chunking", False)), + continuous_batching=self.continuous_batching, + ) + if hasattr(self.model, "get_pkv_dynamic_axes") + else pkv_dynamic_axes ) pkv_dynamic_axes = ( [pkv_dynamic_axes] * self.model.config.num_hidden_layers @@ -2638,7 +2644,6 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = ) for i in range(self.num_layers): - pkv_dynamic_axes[i][0] = "full_batch_size" if self.continuous_batching else "batch_size" for kv in ["key", "value"]: example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i] @@ -2659,14 +2664,14 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = output_names=output_names, dynamic_axes=dynamic_axes, ) - return self._export( example_inputs, output_names, dynamic_axes, export_dir=export_dir, - use_onnx_subfunctions=use_onnx_subfunctions, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), offload_pt_weights=kwargs.get("offload_pt_weights", True), + prefill_only=prefill_only, ) def get_sampling_inputs_and_outputs( @@ -2756,6 +2761,7 @@ def build_prefill_specialization( batch_size: int = 1, kv_cache_batch_size: Optional[int] = None, full_batch_size: Optional[int] = None, + **kwargs, ): """ Builds a dictionary representing a compilation specialization for the prefill phase. @@ -2778,11 +2784,17 @@ def build_prefill_specialization( Dict[str, Union[int, str]] A dictionary defining the prefill specialization. 
""" + if prefill_seq_len == 1 and self.continuous_batching: + exec_batch_size = full_batch_size + else: + exec_batch_size = 1 if self.continuous_batching else batch_size + if hasattr(self.model, "get_specializations"): spec = self.model.get_specializations( - batch_size=1 if self.continuous_batching else batch_size, + batch_size=exec_batch_size, prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, + **kwargs, )[0] else: spec = { @@ -2810,6 +2822,7 @@ def build_decode_specialization( kv_cache_batch_size: Optional[int] = None, full_batch_size: Optional[int] = None, num_speculative_tokens: Optional[int] = None, + **kwargs, ): """ Builds a dictionary representing a compilation specialization for the decode phase. @@ -2880,6 +2893,9 @@ def compile( num_speculative_tokens: Optional[int] = None, prefill_only: Optional[bool] = None, use_onnx_subfunctions: bool = False, + offload_pt_weights: Optional[bool] = True, + enable_chunking: Optional[bool] = False, + retain_full_kv: Optional[bool] = None, **compiler_options, ) -> str: """ @@ -2960,6 +2976,20 @@ def compile( If `prefill_seq_len` is less than `num_speculative_tokens + 1` for TLM models. """ + if prefill_only is None or not prefill_only: + if self.continuous_batching and full_batch_size is None: + raise TypeError("`full_batch_size` is required when `continuous_batching=True`.") + if kv_cache_batch_size and not full_batch_size: + raise ValueError( + "KV caching requires continuous batching. Please set `full_batch_size` and " + "enable `continuous_batching=True` in `from_pretrained`." + ) + else: + if self.continuous_batching: + if not isinstance(kv_cache_batch_size, int): + raise ValueError( + "Please pass valid integer for kv_cache_batch_size as continuous_batching is enabled for prefill-only model" + ) # if ccl_enabled is True read Compute-Context-Length lists if self.ccl_enabled: @@ -2997,15 +3027,6 @@ def compile( if self.is_tlm: num_speculative_tokens = self.check_and_get_num_speculative_tokens(num_speculative_tokens, prefill_seq_len) - if self.continuous_batching and full_batch_size is None: - raise TypeError("`full_batch_size` is required when `continuous_batching=True`.") - - if kv_cache_batch_size and not full_batch_size: - raise ValueError( - "KV caching requires continuous batching. Please set `full_batch_size` and " - "enable `continuous_batching=True` in `from_pretrained`." 
- ) - if ( self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False) @@ -3014,15 +3035,23 @@ def compile( ): raise ValueError("Currently, sampler does not support `num_speculative_tokens` > 0.") + if kv_cache_batch_size and prefill_only is not None and prefill_only: + logger.warning( + "kv_cache_batch_size will be ignored as prefill_only is set to True unless this is GPTOSS model" + ) + # Infer kv_cache_batch_size if not provided kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size # --- Specializations --- specializations = [] if prefill_only is None or prefill_only or prefill_seq_len == 1: + # TODO: we are handling decode-only case inside prefill call which is utterly mis-leading if self.comp_ctx_lengths_prefill is not None: # Adding elements from self.comp_ctx_lengths_prefill to prefill_specialization for i in range(0, len(self.comp_ctx_lengths_prefill)): + if prefill_only or enable_chunking: + raise NotImplementedError("prefill_only or enable_chunking is not supported with CCL") specializations.append( self.build_prefill_specialization( prefill_seq_len=prefill_seq_len, @@ -3042,6 +3071,8 @@ def compile( batch_size=batch_size, kv_cache_batch_size=kv_cache_batch_size, full_batch_size=full_batch_size, + prefill_only=prefill_only, + enable_chunking=enable_chunking, ) ) @@ -3069,6 +3100,7 @@ def compile( kv_cache_batch_size=kv_cache_batch_size, full_batch_size=full_batch_size, num_speculative_tokens=num_speculative_tokens, + prefill_only=prefill_only, ) if decode_spec: specializations.append(decode_spec) @@ -3081,7 +3113,6 @@ def compile( for i in range(self.num_layers): for kv in ["key", "value"]: custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype - qpc_path = self._compile( onnx_path=onnx_path, compile_dir=compile_dir, @@ -3096,6 +3127,10 @@ def compile( aic_num_cores=num_cores, mxint8_kv_cache=mxint8_kv_cache, use_onnx_subfunctions=use_onnx_subfunctions, + prefill_only=prefill_only, + offload_pt_weights=offload_pt_weights, + enable_chunking=enable_chunking, + retain_full_kv=retain_full_kv, **compiler_options, ) @@ -3287,7 +3322,7 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: + def export(self, export_dir: Optional[str] = None, **kwargs) -> str: """ Export the model to ONNX format using ``torch.onnx.export``. @@ -3315,7 +3350,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = output_names, dynamic_axes, export_dir=export_dir, - use_onnx_subfunctions=use_onnx_subfunctions, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), ) def compile( @@ -3663,7 +3698,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, pooling=None, *args, **k def get_model_config(self) -> dict: return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: + def export(self, export_dir: Optional[str] = None, **kwargs) -> str: """ Exports the model to ``ONNX`` format using ``torch.onnx.export``. 
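For orientation, a hedged usage sketch of the prefill-only compile path wired up above. The checkpoint name and shape values are placeholders; the keyword arguments (`prefill_only`, `enable_chunking`, `offload_pt_weights`, `prefill_seq_len`, `ctx_len`, `num_cores`) are the ones visible in this diff, not a guaranteed stable API:

```python
from QEfficient import QEFFAutoModelForCausalLM

# Placeholder checkpoint and shapes, for illustration only.
model = QEFFAutoModelForCausalLM.from_pretrained(
    "org/some-gpt-oss-checkpoint",
    continuous_batching=False,
)

# Prefill-only compilation. For the specialized GPT-OSS path without chunking,
# get_specializations overrides ctx_len to prefill_seq_len, so the two are kept equal here.
qpc_path = model.compile(
    prefill_seq_len=2048,
    ctx_len=2048,
    num_cores=16,
    prefill_only=True,
    enable_chunking=False,
    offload_pt_weights=True,
)
```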
@@ -3691,7 +3726,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = output_names, dynamic_axes, export_dir=export_dir, - use_onnx_subfunctions=use_onnx_subfunctions, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), ) def compile( diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 21a867eb5..4ba6641cf 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -197,6 +197,10 @@ Starcoder2ForCausalLM, Starcoder2Model, ) +from transformers.models.t5.modeling_t5 import ( + T5Attention, + T5LayerNorm, +) from transformers.models.whisper.modeling_whisper import ( WhisperAttention, WhisperDecoder, @@ -261,6 +265,11 @@ QEffGptOssForCausalLM, QEffGptOssMLP, QEffGptOssModel, + QEffPrefillOnlyChunkedGptOssAttention, + QEffPrefillOnlyChunkedGptOssMLP, + QEffPrefillOnlyGptOssAttention, + QEffPrefillOnlyGptOssMLP, + QEffPrefillOnlyGptOssModel, ) from QEfficient.transformers.models.gptj.modeling_gptj import ( QEffGPTJAttention, @@ -417,6 +426,10 @@ QEffStarcoder2ForCausalLM, QEffStarcoder2Model, ) +from QEfficient.transformers.models.t5.modeling_t5 import ( + QEffT5Attention, + QEffT5LayerNorm, +) from QEfficient.transformers.models.whisper.modeling_whisper import ( QEffWhisperAttention, QEffWhisperDecoder, @@ -634,6 +647,39 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: return model, transformed +class PrefillOnlyTransform(ModuleMappingTransform): + _module_mapping = { + QEffGptOssModel: QEffPrefillOnlyGptOssModel, + QEffGptOssAttention: QEffPrefillOnlyGptOssAttention, + QEffGptOssMLP: QEffPrefillOnlyGptOssMLP, + } + + +class PrefillOnlyChunkedTransform(ModuleMappingTransform): + _module_mapping = { + QEffGptOssModel: QEffPrefillOnlyGptOssModel, + QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention, + QEffGptOssMLP: QEffPrefillOnlyChunkedGptOssMLP, + } + + +class RevertPrefillKeepAttentionTransform(ModuleMappingTransform): + _module_mapping = { + QEffGptOssModel: QEffPrefillOnlyGptOssModel, + QEffPrefillOnlyGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention, + QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention, + QEffPrefillOnlyGptOssMLP: QEffGptOssMLP, + QEffPrefillOnlyChunkedGptOssMLP: QEffGptOssMLP, + } + + +class RevertPrefillOnlyTransform(ModuleMappingTransform): + _module_mapping = { + **{v: k for k, v in PrefillOnlyTransform._module_mapping.items()}, + **{v: k for k, v in PrefillOnlyChunkedTransform._module_mapping.items()}, + } + + class SpDTransform: """ Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits during decode phase and extract last predicted token during prefill. @@ -808,6 +854,14 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): _match_class_replace_method = {} +class T5ModelTransform(ModuleMappingTransform): + # supported architectures + _module_mapping = { + T5Attention: QEffT5Attention, + T5LayerNorm: QEffT5LayerNorm, + } + + class PoolingTransform: """ Apply a pooling transformation to the model. This transformation appends a pooling layer to the model, allowing for the reduction of spatial dimensions in the output. 
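The revert transforms above are built by inverting the forward `_module_mapping` dictionaries. As a standalone toy sketch (stand-in classes rather than the real QEff modules, and only one plausible way such a class mapping can be applied), the idea looks like this:

```python
from typing import Dict, Tuple, Type

import torch.nn as nn


class ToyMLP(nn.Linear):
    pass


class ToyPrefillOnlyMLP(nn.Linear):
    pass


FORWARD_MAPPING: Dict[Type[nn.Module], Type[nn.Module]] = {ToyMLP: ToyPrefillOnlyMLP}
# The revert mapping is just the inverse dict, mirroring RevertPrefillOnlyTransform.
REVERT_MAPPING = {v: k for k, v in FORWARD_MAPPING.items()}


def swap_module_classes(
    model: nn.Module, mapping: Dict[Type[nn.Module], Type[nn.Module]]
) -> Tuple[nn.Module, bool]:
    """Rebind the class of every submodule whose type appears in the mapping."""
    transformed = False
    for module in model.modules():
        replacement = mapping.get(type(module))
        if replacement is not None:
            module.__class__ = replacement
            transformed = True
    return model, transformed


model = nn.Sequential(ToyMLP(8, 8), nn.ReLU())
model, _ = swap_module_classes(model, FORWARD_MAPPING)  # ToyMLP -> ToyPrefillOnlyMLP
model, _ = swap_module_classes(model, REVERT_MAPPING)   # back to ToyMLP
```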
diff --git a/QEfficient/transformers/models/t5/__init__.py b/QEfficient/transformers/models/t5/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/transformers/models/t5/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/transformers/models/t5/modeling_t5.py b/QEfficient/transformers/models/t5/modeling_t5.py new file mode 100644 index 000000000..f54201465 --- /dev/null +++ b/QEfficient/transformers/models/t5/modeling_t5.py @@ -0,0 +1,145 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch +import torch.nn as nn +from transformers import EncoderDecoderCache +from transformers.models.t5.modeling_t5 import ( + T5Attention, + T5LayerNorm, +) + + +class QEffT5LayerNorm(T5LayerNorm): + def forward(self, hidden_states): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + div_first = hidden_states * torch.rsqrt(torch.tensor(hidden_states.shape[-1], dtype=torch.float32)) + variance = div_first.pow(2).sum(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class QEffT5Attention(T5Attention): + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) + batch_size, seq_length = hidden_states.shape[:2] + + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None + + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + # Check is encoder-decoder model is being used. 
Otherwise we'll get `DynamicCache` + if past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) + + if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + if past_key_value is not None: # This block is where the patch applies + position_bias = position_bias[:, :, -1:, :] # Added by patch + + if mask is not None: + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) + attn_output = self.o(attn_output) + + outputs = (attn_output, position_bias) + + if output_attentions: + outputs = 
outputs + (attn_weights,) + return outputs diff --git a/QEfficient/transformers/quantizers/__init__.py b/QEfficient/transformers/quantizers/__init__.py index dfadc00ef..dc2308e99 100644 --- a/QEfficient/transformers/quantizers/__init__.py +++ b/QEfficient/transformers/quantizers/__init__.py @@ -5,6 +5,6 @@ # # ----------------------------------------------------------------------------- -from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers +from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers, undo_transformers_quantizers -__all__ = ["replace_transformers_quantizers"] +__all__ = ["replace_transformers_quantizers", "undo_transformers_quantizers"] diff --git a/QEfficient/utils/__init__.py b/QEfficient/utils/__init__.py index 49f0ad30b..3d6583f85 100755 --- a/QEfficient/utils/__init__.py +++ b/QEfficient/utils/__init__.py @@ -16,7 +16,6 @@ create_model_params, custom_format_warning, dump_qconfig, - export_wrapper, generate_mdp_partition_config, get_num_layers_from_config, get_num_layers_vlm, diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index 131a7fc26..26bae7a34 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -12,7 +12,6 @@ import subprocess import xml.etree.ElementTree as ET from dataclasses import dataclass -from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union import requests @@ -27,9 +26,8 @@ PreTrainedTokenizerFast, ) -from QEfficient.utils.cache import QEFF_HOME from QEfficient.utils.constants import KWARGS_INCLUSION_LIST, QEFF_MODELS_DIR, Constants, QnnConstants -from QEfficient.utils.hash_utils import create_export_hash, json_serializable +from QEfficient.utils.hash_utils import json_serializable from QEfficient.utils.logging_utils import logger @@ -532,61 +530,11 @@ def create_model_params(qeff_model, **kwargs) -> Dict: """ model_params = copy.deepcopy(kwargs) model_params = {k: v for k, v in model_params.items() if k in KWARGS_INCLUSION_LIST} - model_params["config"] = qeff_model.model.config.to_diff_dict() model_params["peft_config"] = getattr(qeff_model.model, "active_peft_config", None) model_params["applied_transform_names"] = qeff_model._transform_names() return model_params -def export_wrapper(func): - def wrapper(self, *args, **kwargs): - export_dir = kwargs.get("export_dir", None) - parent_dir = self.model_architecture or self.model_name - export_dir = Path(export_dir or (QEFF_HOME / parent_dir / self.model_name)) - - # PREPROCESSING OF PARAMETERS - - # Get the original signature - original_sig = inspect.signature(func) - - # Remove 'self' from parameters - params = list(original_sig.parameters.values())[1:] # skip 'self' - new_sig = inspect.Signature(params) - - # Bind args and kwargs to the new signature - bound_args = new_sig.bind(*args, **kwargs) - bound_args.apply_defaults() - - # Get arguments as a dictionary - all_args = bound_args.arguments - - export_hash, filtered_hash_params = create_export_hash( - model_params=self.hash_params, - output_names=all_args.get("output_names"), - dynamic_axes=all_args.get("dynamic_axes"), - export_kwargs=all_args.get("export_kwargs", None), - onnx_transform_kwargs=all_args.get("onnx_transform_kwargs", None), - use_onnx_subfunctions=all_args.get("use_onnx_subfunctions", False), - ) - - export_dir = export_dir.with_name(export_dir.name + "-" + export_hash) - kwargs["export_dir"] = export_dir - self.export_hash = export_hash - - # _EXPORT CALL - onnx_path = func(self, *args, **kwargs) - - # 
POST-PROCESSING - # Dump JSON file with hashed parameters - hashed_params_export_path = export_dir / "hashed_export_params.json" - create_json(hashed_params_export_path, filtered_hash_params) - logger.info("Hashed parameters exported successfully.") - - return onnx_path - - return wrapper - - def execute_command(process: str, command: str, output_file_path: Optional[str] = None): """ Executes the give command using subprocess. diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index e0b003422..613d7049a 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -144,6 +144,13 @@ def get_models_dir(): # Molmo Constants MOLMO_IMAGE_HEIGHT = 536 MOLMO_IMAGE_WIDTH = 354 +# Flux Transformer Constants +FLUX_ONNX_EXPORT_SEQ_LENGTH = 256 +FLUX_ONNX_EXPORT_COMPRESSED_LATENT_DIM = 4096 +FLUX_ADALN_HIDDEN_DIM = 3072 +FLUX_ADALN_DUAL_BLOCK_CHUNKS = 12 # 6 chunks for norm1 + 6 chunks for norm1_context +FLUX_ADALN_SINGLE_BLOCK_CHUNKS = 3 +FLUX_ADALN_OUTPUT_DIM = 6144 # 2 * FLUX_ADALN_HIDDEN_DIM class Constants: diff --git a/QEfficient/utils/export_utils.py b/QEfficient/utils/export_utils.py new file mode 100644 index 000000000..638f55921 --- /dev/null +++ b/QEfficient/utils/export_utils.py @@ -0,0 +1,219 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import copy +import inspect +import re +import warnings +from pathlib import Path +from typing import Dict + +from QEfficient.base.onnx_transforms import CustomOpTransform, RenameFunctionOutputsTransform +from QEfficient.transformers.cache_utils import InvalidIndexProvider +from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export +from QEfficient.utils.cache import QEFF_HOME +from QEfficient.utils.hash_utils import create_export_hash +from QEfficient.utils.logging_utils import logger +from QEfficient.utils.torch_patches import apply_torch_patches, undo_torch_patches + + +def export_wrapper(func): + """ + Decorator for export methods that orchestrates the complete export lifecycle. + + Responsibilities: + 1. Prepare export directory structure + 2. Generate reproducible hash for export configuration + 3. Setup ONNX subfunction environment (if enabled) + 4. Execute the wrapped export function + 5. Cleanup subfunction environment (if enabled) + 6. Save export metadata + + Args: + func: The export method to wrap (typically _export) + + Returns: + Wrapped function with complete export lifecycle management + """ + + def wrapper(self, *args, **kwargs): + # 1. Setup ONNX subfunctions if requested + if use_onnx_subfunctions := kwargs.pop("use_onnx_subfunctions", False): + args, kwargs = _setup_onnx_subfunctions(self, args, kwargs) + + # 2. Prepare export directory + export_dir = _prepare_export_directory(self, kwargs) + + # 3. Generate hash and finalize export directory path + export_hash, filtered_hash_params = _generate_export_hash(self, args, kwargs, func) + export_dir = export_dir.with_name(export_dir.name + "-" + export_hash) + kwargs["export_dir"] = export_dir + self.export_hash = export_hash + + # 4. Execute the actual export + onnx_path = func(self, *args, **kwargs) + + # 5. Save export metadata + _save_export_metadata(export_dir, filtered_hash_params) + + # 6. 
Always cleanup subfunctions if they were setup + if use_onnx_subfunctions: + _cleanup_onnx_subfunctions(self) + + return onnx_path + + return wrapper + + +def _prepare_export_directory(qeff_model, kwargs) -> Path: + """ + Prepare and return the base export directory path. + + Args: + qeff_model: The QEff model instance + kwargs: Keyword arguments containing optional export_dir + + Returns: + Path object for the base export directory + """ + export_dir = kwargs.get("export_dir", None) + parent_dir = qeff_model.model_architecture or qeff_model.model_name + return Path(export_dir or (QEFF_HOME / parent_dir / qeff_model.model_name)) + + +def _generate_export_hash(qeff_model, args, kwargs, func): + """ + Generate export hash from model parameters and export arguments. + + The hash ensures reproducibility and prevents conflicts between + different export configurations. + + Args: + qeff_model: The QEff model instance + args: Positional arguments to the export function + kwargs: Keyword arguments to the export function + func: The export function being wrapped + + Returns: + Tuple of (export_hash: str, filtered_hash_params: dict) + """ + # Extract function signature + original_sig = inspect.signature(func) + params = list(original_sig.parameters.values())[1:] # Skip 'self' + new_sig = inspect.Signature(params) + # Bind all arguments + bound_args = new_sig.bind(*args, **kwargs) + bound_args.apply_defaults() + all_args = bound_args.arguments + + # Use the model's current configuration for hashing to ensure any post-load modifications are captured + # TODO: Replace with get_model_config property of modeling classes and remove the if-else + # Determine the config dict to use, preferring .to_diff_dict() if available + if hasattr(qeff_model.model, "config") and hasattr(qeff_model.model.config, "to_diff_dict"): + config_val = qeff_model.model.config.to_diff_dict() + elif hasattr(qeff_model.model, "model") and hasattr(qeff_model.model.model.config, "to_diff_dict"): + config_val = qeff_model.model.model.config.to_diff_dict() + else: + config_val = qeff_model.model.config + + copy_of_hash_params = copy.deepcopy(qeff_model.hash_params) + copy_of_hash_params.update( + { + "config": config_val, + } + ) + # Generate hash from relevant parameters + export_hash, filtered_hash_params = create_export_hash( + model_params=copy_of_hash_params, + output_names=all_args.get("output_names"), + dynamic_axes=all_args.get("dynamic_axes"), + export_kwargs=all_args.get("export_kwargs", None), + onnx_transform_kwargs=all_args.get("onnx_transform_kwargs", None), + ) + + return export_hash, filtered_hash_params + + +def _setup_onnx_subfunctions(qeff_model, args, kwargs): + """ + Setup ONNX subfunction export environment. + + This function prepares the model and environment for exporting with + ONNX subfunctions enabled. It: + - Applies necessary torch patches + - Modifies output names for subfunction compatibility + - Adds subfunction-specific ONNX transforms + - Updates export kwargs with module classes + + Args: + qeff_model: The QEff model instance + kwargs: Export keyword arguments (modified in-place). + """ + warnings.warn( + "The subfunction feature is experimental. Please note that using compile " + "consecutively with and without subfunction may produce inconsistent results." 
+ ) + + # Apply torch patches for subfunction support + apply_torch_patches() + InvalidIndexProvider.SUBFUNC_ENABLED = True + # Transform output names for subfunction compatibility + if "output_names" in kwargs: + kwargs["output_names"] = [ + re.sub("_RetainedState", "_InternalRetainedState", name) for name in kwargs["output_names"] + ] + else: + args = list(args) + args[1] = [re.sub("_RetainedState", "_InternalRetainedState", name) for name in args[1]] + args = tuple(args) + # Add subfunction-specific ONNX transforms + qeff_model._onnx_transforms.append(RenameFunctionOutputsTransform) + qeff_model._onnx_transforms.append(CustomOpTransform) + + # TODO: Handle this in the modelling class QEFFTransformersBase,remove from here. Refer diffusers implementation + kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(qeff_model.model) + return args, kwargs + + +def _cleanup_onnx_subfunctions(qeff_model): + """ + Cleanup ONNX subfunction export environment. + + Restores the model and environment to pre-subfunction state by: + - Undoing torch patches + - Resetting InvalidIndexProvider flag + - Restoring original ONNX transforms list + + Args: + qeff_model: The QEff model instance + + Note: + This function is called in a finally block to ensure cleanup + even if export fails. Errors during cleanup are logged but + not re-raised to avoid masking the original exception. + """ + # Undo torch patches + undo_torch_patches() + InvalidIndexProvider.SUBFUNC_ENABLED = False + qeff_model._onnx_transforms.remove(RenameFunctionOutputsTransform) + qeff_model._onnx_transforms.remove(CustomOpTransform) + + +def _save_export_metadata(export_dir: Path, filtered_hash_params: Dict): + """ + Save export metadata to JSON file for reproducibility. + + Args: + export_dir: Directory where the export was saved + filtered_hash_params: Dictionary of parameters used for hashing + """ + # Import here to avoid circular dependency + from QEfficient.utils._utils import create_json + + hashed_params_path = export_dir / "hashed_export_params.json" + create_json(hashed_params_path, filtered_hash_params) + logger.info("Hashed parameters exported successfully.") diff --git a/QEfficient/utils/hash_utils.py b/QEfficient/utils/hash_utils.py index 948b72e6a..10e6686d0 100644 --- a/QEfficient/utils/hash_utils.py +++ b/QEfficient/utils/hash_utils.py @@ -14,7 +14,8 @@ def json_serializable(obj): if isinstance(obj, set): - return sorted(obj) + # Convert set to a sorted list of strings for consistent hashing + return sorted([cls.__name__ if isinstance(cls, type) else str(cls) for cls in obj]) raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable") @@ -55,8 +56,6 @@ def create_export_hash(**kwargs): export_params = {} export_params["output_names"] = kwargs.get("output_names") export_params["dynamic_axes"] = kwargs.get("dynamic_axes") - if kwargs.get("use_onnx_subfunctions"): - export_params["use_onnx_subfunctions"] = True export_hash_params["export_params"] = export_params export_kwargs = kwargs.get("export_kwargs") @@ -68,5 +67,4 @@ def create_export_hash(**kwargs): export_hash_params.update(onnx_transform_kwargs) if export_hash_params.get("peft_config") is not None and not isinstance(export_hash_params["peft_config"], dict): export_hash_params["peft_config"] = export_hash_params["peft_config"].to_dict() - return hash_dict_params(export_hash_params), export_hash_params diff --git a/docs/image/girl_laughing.png b/docs/image/girl_laughing.png new file mode 100644 index 000000000..9e58da61d 
Binary files /dev/null and b/docs/image/girl_laughing.png differ diff --git a/examples/diffusers/flux/README.md b/examples/diffusers/flux/README.md new file mode 100644 index 000000000..2a3c1605f --- /dev/null +++ b/examples/diffusers/flux/README.md @@ -0,0 +1,243 @@ +# FLUX.1-schnell Image Generation Examples + +This directory contains examples demonstrating how to use the QEffFluxPipeline to generate images using the FLUX.1-schnell model from Black Forest Labs. + +## Overview + +FLUX.1-schnell is a fast, distilled version of the FLUX.1 text-to-image model optimized for speed with minimal quality loss. These examples show how to leverage Qualcomm Cloud AI 100 acceleration for efficient image generation. + +## Files + +- **`flux_1_schnell.py`** - Basic example showing simple image generation +- **`flux_1_shnell_custom.py`** - Advanced example with customization options +- **`flux_config.json`** - Configuration file for pipeline modules + +## Quick Start + +### Basic Usage + +The simplest way to generate images with FLUX.1-schnell: + +```python +from QEfficient import QEffFluxPipeline +import torch + +# Initialize pipeline +pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") + +# Generate image +output = pipeline( + prompt="A laughing girl", + height=1024, + width=1024, + guidance_scale=0.0, + num_inference_steps=4, + max_sequence_length=256, + generator=torch.manual_seed(42), + parallel_compile=True, + use_onnx_subfunctions=False, +) + +# Save image +output.images[0].save("girl_laughing.png") +``` + +Run the basic example: +```bash +python flux_1_schnell.py +``` + +## Advanced Customization + +The `flux_1_shnell_custom.py` example demonstrates several advanced features: + +### 1. Custom Model Components + +You can provide custom text encoders, transformers, and tokenizers: + +```python +pipeline = QEffFluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", + text_encoder=custom_text_encoder, + transformer=custom_transformer, + tokenizer=custom_tokenizer, +) +``` + +### 2. Custom Scheduler + +Replace the default scheduler with your own: + +```python +pipeline.scheduler = custom_scheduler.from_config(pipeline.scheduler.config) +``` + +### 3. Reduce Model Layers for Faster Inference + +Trade quality for speed by reducing transformer blocks: + +```python +original_blocks = pipeline.transformer.model.transformer_blocks +org_single_blocks = pipeline.transformer.model.single_transformer_blocks +pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList([original_blocks[0]]) +pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList([org_single_blocks[0]]) +pipeline.transformer.model.config['num_layers'] = 1 +pipeline.transformer.model.config['num_single_layers'] = 1 +``` + +### 4. Pre-compile with Custom Configuration + +Compile the model separately before generation: + +```python +pipeline.compile( + compile_config="examples/diffusers/flux/flux_config.json", + height=512, + width=512, + use_onnx_subfunctions=False +) +``` + +### 5. 
Runtime Configuration + +Use custom configuration during generation: + +```python +output = pipeline( + prompt="A girl laughing", + custom_config_path="examples/diffusers/flux/flux_config.json", + height=1024, + width=1024, + guidance_scale=0.0, + num_inference_steps=4, + max_sequence_length=256, + generator=torch.manual_seed(42), + parallel_compile=True, + use_onnx_subfunctions=False, +) +``` + +Run the advanced example: +```bash +python flux_1_shnell_custom.py +``` + +## Configuration File + +The `flux_config.json` file controls compilation and execution settings for each pipeline module: + +### Module Structure + +The configuration includes four main modules: + +1. **text_encoder** (CLIP) - Encodes text prompts (77 token sequence) +2. **text_encoder_2** (T5) - Secondary text encoder (256 token sequence) +3. **transformer** - Core diffusion transformer model +4. **vae_decoder** - Decodes latents to images + +### Configuration Parameters + +Each module has three sections: + +#### Specializations +- `batch_size`: Batch size for inference +- `seq_len`: Sequence length for text encoders +- `steps`: Number of inference steps (transformer only) +- `channels`: Number of channels (VAE decoder only) + +#### Compilation +- `onnx_path`: Path to pre-exported ONNX model (null for auto-export) +- `compile_dir`: Directory for compiled artifacts (null for auto-generation) +- `mdp_ts_num_devices`: Number of devices for model data parallelism +- `mxfp6_matmul`: Enable MXFP6 quantization for matrix multiplication +- `convert_to_fp16`: Convert model to FP16 precision +- `aic_num_cores`: Number of AI cores to use +- `mos`: Multi-output streaming (transformer only) +- `mdts-mos`: Multi-device tensor slicing with MOS (transformer only) +- `aic-enable-depth-first`: Enable depth-first compilation (VAE only) + +#### Execute +- `device_ids`: List of device IDs to use (null for auto-selection) + +### Example Configuration Snippet + +```json +{ + "transformer": { + "specializations": { + "batch_size": 1, + "seq_len": 256, + "steps": 1 + }, + "compilation": { + "mdp_ts_num_devices": 4, + "mxfp6_matmul": true, + "convert_to_fp16": true, + "aic_num_cores": 16 + }, + "execute": { + "device_ids": null + } + } +} +``` + +## Key Parameters + +### Generation Parameters + +- **`prompt`** (str): Text description of the image to generate +- **`height`** (int): Output image height in pixels (default: 1024) +- **`width`** (int): Output image width in pixels (default: 1024) +- **`guidance_scale`** (float): Classifier-free guidance scale (0.0 for schnell) +- **`num_inference_steps`** (int): Number of denoising steps (4 recommended for schnell) +- **`max_sequence_length`** (int): Maximum text sequence length (256 recommended) +- **`generator`** (torch.Generator): Random seed for reproducibility +- **`parallel_compile`** (bool): Enable parallel compilation of modules +- **`use_onnx_subfunctions`** (bool): Enable ONNX modular export (experimental) + +### Performance Tuning + +- **Faster inference**: Reduce `num_inference_steps` or model layers +- **Better quality**: Increase `num_inference_steps` or use full model +- **Memory optimization**: Adjust `mdp_ts_num_devices` in config +- **Precision trade-offs**: Toggle `mxfp6_matmul` and `convert_to_fp16` + +## Output + +The pipeline returns an output object containing: +- `images`: List of generated PIL Image objects +- Performance metrics (timing information) + +Example output: +```python +print(output) # Displays performance information +image = output.images[0] # Access the generated 
image +image.save("output.png") # Save to disk +``` + +## Hardware Requirements + +- Qualcomm Cloud AI 100 accelerator +- Sufficient memory for model compilation and execution +- Multiple devices recommended for optimal transformer performance (see `mdp_ts_num_devices`) + +## Notes + +- FLUX.1-schnell is optimized for 4-step generation with `guidance_scale=0.0` +- The transformer module benefits most from multi-device parallelism +- ONNX subfunctions (`use_onnx_subfunctions=True`) is experimental and may improve compile time but is not recommended for production use +- Custom configurations allow fine-tuning for specific hardware setups + +## Troubleshooting + +- **Out of memory**: Reduce image dimensions or increase `mdp_ts_num_devices` +- **Slow compilation**: Enable `parallel_compile=True` +- **Quality issues**: Ensure using recommended parameters (4 steps, guidance_scale=0.0) +- **Device errors**: Check `device_ids` in config or set to `null` for auto-selection + +## References + +- [FLUX.1 Model Card](https://huggingface.co/black-forest-labs/FLUX.1-schnell) +- [QEfficient Documentation](../../../README.md) +- [Diffusers Pipeline Guide](../../README.md) diff --git a/examples/diffusers/flux/flux_1_schnell.py b/examples/diffusers/flux/flux_1_schnell.py new file mode 100644 index 000000000..46f26bb6b --- /dev/null +++ b/examples/diffusers/flux/flux_1_schnell.py @@ -0,0 +1,45 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +FLUX.1-schnell Image Generation Example + +This example demonstrates how to use the QEffFluxPipeline to generate images +using the FLUX.1-schnell model from Black Forest Labs. FLUX.1-schnell is a +fast, distilled version of the FLUX.1 text-to-image model optimized for +speed with minimal quality loss. +""" + +import torch + +from QEfficient import QEffFluxPipeline + +# Initialize the FLUX.1-schnell pipeline from pretrained weights +pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") + +# Generate an image from a text prompt +# use_onnx_subfunctions=True enables ONNX-based optimizations for faster compilation +output = pipeline( + prompt="A laughing girl", + height=1024, + width=1024, + guidance_scale=0.0, + num_inference_steps=4, + max_sequence_length=256, + generator=torch.manual_seed(42), + parallel_compile=True, + use_onnx_subfunctions=False, +) + +# Extract the generated image from the output +image = output.images[0] + +# Save the generated image to disk +image.save("girl_laughing.png") + +# Print the output object (contains perf info) +print(output) diff --git a/examples/diffusers/flux/flux_1_shnell_custom.py b/examples/diffusers/flux/flux_1_shnell_custom.py new file mode 100644 index 000000000..201ebe659 --- /dev/null +++ b/examples/diffusers/flux/flux_1_shnell_custom.py @@ -0,0 +1,113 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +FLUX.1 Schnell Custom Configuration Example + +This example demonstrates how to customize the FLUX.1 model with various options: +1. Custom image dimensions (height/width) +2. Custom transformer model and text encoder +3. 
Custom scheduler configuration +4. Reduced model layers for faster inference +5. Custom compilation settings +6. Custom runtime configuration via JSON config file + +Use this example to learn how to fine-tune FLUX.1 for your specific needs. +""" + +import torch + +from QEfficient import QEffFluxPipeline + +# ============================================================================ +# PIPELINE INITIALIZATION WITH CUSTOM PARAMETERS +# ============================================================================ + +# Option 1: Basic initialization with default parameters +pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") +# Option 2: Advanced initialization with custom modules +# Uncomment and modify to use your own custom components: +# +# pipeline = QEffFluxPipeline.from_pretrained( +# "black-forest-labs/FLUX.1-schnell", +# text_encoder=custom_text_encoder, # Your custom CLIP text encoder +# transformer=custom_transformer, # Your custom transformer model +# tokenizer=custom_tokenizer, # Your custom tokenizer +# ) + +# ============================================================================ +# OPTIONAL: CUSTOM SCHEDULER CONFIGURATION +# ============================================================================ +# Uncomment to use a custom scheduler (e.g., different sampling methods): +# +# pipeline.scheduler = custom_scheduler.from_config(pipeline.scheduler.config) + +# ============================================================================ +# OPTIONAL: REDUCE MODEL LAYERS FOR FASTER INFERENCE +# ============================================================================ +# Reduce the number of transformer blocks to speed up image generation. +# +# Trade-off: Faster inference but potentially lower image quality +# Use case: Quick testing, prototyping, or when speed is critical +# +# Uncomment the following lines to use only the first transformer block: +# +# original_blocks = pipeline.transformer.model.transformer_blocks +# org_single_blocks = pipeline.transformer.model.single_transformer_blocks +# pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList([original_blocks[0]]) +# pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList([org_single_blocks[0]]) +# pipeline.transformer.model.config['num_layers'] = 1 +# pipeline.transformer.model.config['num_single_layers'] = 1 + +# ============================================================================ +# OPTIONAL: COMPILE WITH CUSTOM CONFIGURATION +# ============================================================================ +# Pre-compile the model for optimized performance on target hardware. +# +# When to use: +# - When you want to compile the model separately before generation +# - When you need to skip image generation and only prepare the model +# +# NOTE-1: If compile_config is not specified, the default configuration from +# QEfficient/diffusers/pipelines/flux/flux_config.json will be used +# +# NOTE-2: use_onnx_subfunctions=True enables modular ONNX export optimizations (Experimental so not recommended) +# This feature improves export performance by breaking down the model into smaller, +# more manageable ONNX functions, which can lead to improve compile time. 
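+# NOTE-3: QEfficient warns that compiling the same model consecutively with and
+# without subfunctions may produce inconsistent results, so pick one mode and use
+# it consistently.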
+# Uncomment to compile with a custom configuration: +# pipeline.compile( +# compile_config="examples/diffusers/flux/flux_config.json", +# height=512, +# width=512, +# use_onnx_subfunctions=False +# ) + +# ============================================================================ +# IMAGE GENERATION WITH CUSTOM RUNTIME CONFIGURATION +# ============================================================================ +# Generate an image using the configured pipeline. +# +# Note: Use of custom_config_path provides flexibility to set device_ids for each +# module, so you can skip the separate pipeline.compile() step. + +output = pipeline( + prompt="A laughing girl", + custom_config_path="examples/diffusers/flux/flux_config.json", + height=1024, + width=1024, + guidance_scale=0.0, + num_inference_steps=4, + max_sequence_length=256, + generator=torch.manual_seed(42), + parallel_compile=True, + use_onnx_subfunctions=False, +) + +image = output.images[0] +# Save the generated image to disk +image.save("laughing_girl.png") +print(output) diff --git a/examples/diffusers/flux/flux_config.json b/examples/diffusers/flux/flux_config.json new file mode 100644 index 000000000..73b92265f --- /dev/null +++ b/examples/diffusers/flux/flux_config.json @@ -0,0 +1,99 @@ +{ + "description": "Default configuration for Flux pipeline", + + "modules": + { + "text_encoder": + { + "specializations":{ + "batch_size": 1, + "seq_len": 77 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16, + "compile_only":true + }, + "execute": + { + "device_ids": null + } + + }, + "text_encoder_2": + { + "specializations": + { + "batch_size": 1, + "seq_len": 256 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16, + "compile_only": true + }, + "execute": + { + "device_ids": null + } + }, + "transformer": + { + "specializations": + { + "batch_size": 1, + "seq_len": 256, + "steps": 1 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 4, + "mxfp6_matmul": true, + "convert_to_fp16": true, + "aic_num_cores": 16, + "mos": 1, + "mdts-mos": 1, + "compile_only":true + }, + "execute": + { + "device_ids": null + } + }, + "vae_decoder": + { + "specializations": + { + "batch_size": 1, + "channels": 16 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16, + "aic-enable-depth-first": true, + "compile_only":true + }, + "execute": + { + "device_ids": null + } + } + } +} diff --git a/examples/disagg_serving/README.md b/examples/disagg_serving/README.md new file mode 100644 index 000000000..fcf665357 --- /dev/null +++ b/examples/disagg_serving/README.md @@ -0,0 +1,31 @@ +# We should be using disaggragate serving for GPTOSS model for best performance + - GPT-OSS model has 128/4 for 120b and 32/4 ratio of total_experts/experts_per_tok + - We use read all experts only once always strategy in prefill-only model + - And we treat weights activtions meaning read only chosen experts for decode-only model + +# Prefill-only model +## Blocking default behviour when `prefill_only=True` in compile API + - NUM_Q_BLOCKS= set number of Q blocks in attention + - NUM_FFN_BLOCKS= set number of blocks in FFN + - ENABLE_OPT_SWA="0" or "1" to enable/disable optimized SWA. 
when enabled, only the valid KVs for the given block are read in attention, reducing MACs
+ - prefix_caching is not supported with this mode
+
+## Chunking: pass `enable_chunking=True` and `prefill_only=True` in compile API
+ - Optimized SWA, i.e. reading only the valid KV as per the diagonal attention mask, is enabled by default for this version
+ - This model can be used for prefix_caching by passing `kv_cache_batch_size=` in the compile API
+
+# Decode-only model
+## Retain sliding window length of KV for sliding window layers, default behaviour when `prefill_seq_len=1` in compile API
+ - This reduces the amount of DDR used by the model
+ - CB is enabled for this version: pass `continuous_batching=True` in the `from_pretrained` call, strictly pass `full_batch_size=`, and optionally `kv_cache_batch_size=` if needed
+## Full KV for sliding window layers: pass `retain_full_kv=True` along with `prefill_seq_len=1` in compile API
+ - This uses more DDR, as ctx_len KV is retained even for sliding window layers, but only the sliding window length of KV is read in attention
+ - CB is enabled for this version: pass `continuous_batching=True` in the `from_pretrained` call, strictly pass `full_batch_size=`, and optionally `kv_cache_batch_size=` if needed
+ - This is intended for the multi-turn chat use case, where we run prefill -> decode and then reuse the combined prefill and decode cache to run prefill again, so the full KV must be retained for sliding window layers
+
+
+NOTE:
+* The decode-only model currently fails compilation with `use_onnx_subfunctions=True`, so avoid using it
+* The 120B model needs NPI; two versions of the NPI (with and without subfunction) are uploaded here, pass it as `node_precision_info=`
+* It is advised to use `use_onnx_subfunctions=True` with the prefill-only model, otherwise the compilation times are too high. With this option the model is expected to export and then fail during compile as it needs the assert SDK, so the user should run the compilation manually by pasting the command printed in the error
+
diff --git a/examples/disagg_serving/gpt_oss_disagg_mode.py b/examples/disagg_serving/gpt_oss_disagg_mode.py
new file mode 100644
index 000000000..fd0d5b045
--- /dev/null
+++ b/examples/disagg_serving/gpt_oss_disagg_mode.py
@@ -0,0 +1,137 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import time
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer
+
+from QEfficient import QEFFAutoModelForCausalLM
+from QEfficient.generation.cloud_infer import QAICInferenceSession
+
+model_id = "openai/gpt-oss-20b"  # weights do not need to be converted to fp32
+
+prompt = """
+Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures.
+
+As Alex flipped through the pages, he discovered a map that led to a hidden treasure. Excited by the prospect of a real-life treasure hunt, Alex decided to embark on a thrilling journey. He packed his backpack with snacks, a flashlight, and a compass, and set off into the unknown.
+ +The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location. +""" +all_outputs = [] +# Run prefill +tokenizer = AutoTokenizer.from_pretrained(model_id) +PREFILL_SEQ_LEN = 256 +CTX_LEN = 256 +inputs = tokenizer(prompt, return_tensors="np", padding=True) +position_ids = inputs["attention_mask"].sum(1, keepdims=True) +padded_len = inputs["input_ids"].shape[1] +num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float +padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len + +# Initialize variables specific to request +# Calculate the max generation length. +max_gen_len = CTX_LEN - position_ids.max() +generation_len = max_gen_len + + +qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id) +config = qeff_model.model.config +inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) +inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) +inputs.pop("token_type_ids", None) +inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} +past_key_values = [] +for i in range(config.num_hidden_layers): + cache_len = config.sliding_window if i % 2 == 0 else PREFILL_SEQ_LEN + pad_shape = (1, 8, cache_len, 64) + past_key = torch.zeros((pad_shape), dtype=torch.float32) + past_value = torch.zeros((pad_shape), dtype=torch.float32) + pkv = (past_key, past_value) + past_key_values.append(pkv) +inputs["past_key_values"] = past_key_values + + +decode_qpc_path = qeff_model.compile( + prefill_seq_len=1, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + offload_pt_weights=False, +) +prefill_qpc_path = qeff_model.compile( + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + prefill_only=True, + use_onnx_subfunctions=True, +) + +prefill_session = QAICInferenceSession(prefill_qpc_path) + +logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32) +prefill_session.set_buffers({"logits": logits_out_placeholder}) +inputs.pop("past_key_values") +inputs = {k: v.detach().numpy() for k, v in inputs.items()} +st = time.time() +qpc_out = prefill_session.run(inputs) +print(f"time for prefill_run={time.time() - st} sec\n") + +decode_session = QAICInferenceSession(decode_qpc_path) +decode_session.set_buffers({"logits": logits_out_placeholder}) + +decode_inputs = { + "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1), + "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, +} +print("pos_id for decodee", decode_inputs["position_ids"]) + +all_outputs.append(decode_inputs["input_ids"][0][0]) +for i in range(config.num_hidden_layers): + if i % 2 == 0 and decode_inputs["position_ids"] >= config.sliding_window: + k = qpc_out[f"past_key.{i}_RetainedState"] + v = qpc_out[f"past_value.{i}_RetainedState"] + mod_pos_id = config.sliding_window - decode_inputs["position_ids"][0][0] % config.sliding_window + decode_inputs[f"past_key.{i}"] = np.concatenate((k[:, :, mod_pos_id:, :], k[:, :, :mod_pos_id, :]), axis=-2) + decode_inputs[f"past_value.{i}"] = np.concatenate((v[:, :, mod_pos_id:, :], v[:, :, :mod_pos_id, :]), axis=-2) + else: + decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] + 
decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] + +st = time.time() +decode_out = decode_session.run(decode_inputs) +print(f"time for first run of decode with KV as input = {time.time() - st} sec\n") +decode_session.skip_buffers( + [x for x in decode_session.input_names + decode_session.output_names if x.startswith("past_")] +) +pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1 +st = time.time() +for i in range(generation_len - 2): + loop_decode_inputs = { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, + } + all_outputs.append(loop_decode_inputs["input_ids"][0][0]) + decode_out = decode_session.run(loop_decode_inputs) + pos_id += 1 + + +print(f"time for decode generation = {(time.time() - st) / (generation_len - 2)}") +print(all_outputs) +print(tokenizer.decode(all_outputs)) diff --git a/examples/disagg_serving/subfunction_120b_npi.yaml b/examples/disagg_serving/subfunction_120b_npi.yaml new file mode 100644 index 000000000..762703d58 --- /dev/null +++ b/examples/disagg_serving/subfunction_120b_npi.yaml @@ -0,0 +1,27 @@ +FP32NodeInstanceNames: + - CustomRMSNorm_58 + - onnx::Shape_1033777 + - CustomRMSNorm_349 + - hidden.127 + - CustomRMSNorm_27448 + - onnx::Shape_1066066 + - CustomRMSNorm_27709 + - hidden.131 + - CustomRMSNorm_54808 + - onnx::Shape_878 + - CustomRMSNorm_55105 + - hidden + - hidden_states.259 + - Add_348 + - Add_347 + - onnx::Add_1034099 + - hidden_states.267 + - Add_27708 + - onnx::Add_1066358 + - Add_27707 + - hidden_states.3 + - Add_55104 + - onnx::Add_1209 + - Add_55103 + - /model/norm/CustomRMSNorm + - /model/norm/CustomRMSNorm_output_0 \ No newline at end of file diff --git a/examples/disagg_serving/without_subfunc_npi_120b.yaml b/examples/disagg_serving/without_subfunc_npi_120b.yaml new file mode 100644 index 000000000..ec6cf034f --- /dev/null +++ b/examples/disagg_serving/without_subfunc_npi_120b.yaml @@ -0,0 +1,148 @@ +FP32NodeInstanceNames: + - /model/layers.0/Add_1_output_0 + - /model/layers.0/Add_output_0 + - /model/layers.0/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.0/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.1/Add_1_output_0 + - /model/layers.1/Add_output_0 + - /model/layers.1/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.1/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.10/Add_1_output_0 + - /model/layers.10/Add_output_0 + - /model/layers.10/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.10/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.11/Add_1_output_0 + - /model/layers.11/Add_output_0 + - /model/layers.11/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.11/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.12/Add_1_output_0 + - /model/layers.12/Add_output_0 + - /model/layers.12/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.12/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.13/Add_1_output_0 + - /model/layers.13/Add_output_0 + - /model/layers.13/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.13/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.14/Add_1_output_0 + - /model/layers.14/Add_output_0 + - /model/layers.14/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.14/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.15/Add_1_output_0 + - /model/layers.15/Add_output_0 + - /model/layers.15/input_layernorm/CustomRMSNorm_output_0 + - 
/model/layers.15/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.16/Add_1_output_0 + - /model/layers.16/Add_output_0 + - /model/layers.16/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.16/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.17/Add_1_output_0 + - /model/layers.17/Add_output_0 + - /model/layers.17/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.17/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.18/Add_1_output_0 + - /model/layers.18/Add_output_0 + - /model/layers.18/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.18/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.19/Add_1_output_0 + - /model/layers.19/Add_output_0 + - /model/layers.19/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.19/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.2/Add_1_output_0 + - /model/layers.2/Add_output_0 + - /model/layers.2/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.2/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.20/Add_1_output_0 + - /model/layers.20/Add_output_0 + - /model/layers.20/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.20/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.21/Add_1_output_0 + - /model/layers.21/Add_output_0 + - /model/layers.21/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.21/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.22/Add_1_output_0 + - /model/layers.22/Add_output_0 + - /model/layers.22/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.22/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.23/Add_1_output_0 + - /model/layers.23/Add_output_0 + - /model/layers.23/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.23/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.24/Add_1_output_0 + - /model/layers.24/Add_output_0 + - /model/layers.24/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.24/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.25/Add_1_output_0 + - /model/layers.25/Add_output_0 + - /model/layers.25/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.25/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.26/Add_1_output_0 + - /model/layers.26/Add_output_0 + - /model/layers.26/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.26/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.27/Add_1_output_0 + - /model/layers.27/Add_output_0 + - /model/layers.27/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.27/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.28/Add_1_output_0 + - /model/layers.28/Add_output_0 + - /model/layers.28/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.28/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.29/Add_1_output_0 + - /model/layers.29/Add_output_0 + - /model/layers.29/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.29/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.3/Add_1_output_0 + - /model/layers.3/Add_output_0 + - /model/layers.3/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.3/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.30/Add_1_output_0 + - /model/layers.30/Add_output_0 + - /model/layers.30/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.30/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.31/Add_1_output_0 + - /model/layers.31/Add_output_0 + - 
/model/layers.31/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.31/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.32/Add_1_output_0 + - /model/layers.32/Add_output_0 + - /model/layers.32/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.32/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.33/Add_1_output_0 + - /model/layers.33/Add_output_0 + - /model/layers.33/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.33/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.34/Add_1_output_0 + - /model/layers.34/Add_output_0 + - /model/layers.34/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.34/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.35/Add_1_output_0 + - /model/layers.35/Add_output_0 + - /model/norm/Add_output_0 + - /model/layers.35/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.35/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.4/Add_1_output_0 + - /model/layers.4/Add_output_0 + - /model/layers.4/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.4/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.5/Add_1_output_0 + - /model/layers.5/Add_output_0 + - /model/layers.5/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.5/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.6/Add_1_output_0 + - /model/layers.6/Add_output_0 + - /model/layers.6/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.6/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.7/Add_1_output_0 + - /model/layers.7/Add_output_0 + - /model/layers.7/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.7/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.8/Add_1_output_0 + - /model/layers.8/Add_output_0 + - /model/layers.8/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.8/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.9/Add_1_output_0 + - /model/layers.9/Add_output_0 + - /model/layers.9/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.9/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/norm/CustomRMSNorm_output_0 + \ No newline at end of file diff --git a/examples/gpt_oss_disagg_mode_with_chunking.py b/examples/gpt_oss_disagg_mode_with_chunking.py new file mode 100644 index 000000000..363e2806c --- /dev/null +++ b/examples/gpt_oss_disagg_mode_with_chunking.py @@ -0,0 +1,137 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import time + +import numpy as np +import torch +from transformers import AutoConfig, AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.generation.cloud_infer import QAICInferenceSession + +model_id = "openai/gpt-oss-20b" # weights are not required to convert to fp32 + +prompt = """ +Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures. + +As Alex flipped through the pages, he discovered a map that led to a hidden treasure. 
Excited by the prospect of a real-life treasure hunt, Alex decided to embark on a thrilling journey. He packed his backpack with snacks, a flashlight, and a compass, and set off into the unknown. + +The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location. +""" +# Run prefill +config = AutoConfig.from_pretrained(model_id) +tokenizer = AutoTokenizer.from_pretrained(model_id) +PREFILL_SEQ_LEN = 128 +CTX_LEN = 128 * 3 + +qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id) + +decode_qpc_path = qeff_model.compile( + prefill_seq_len=1, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + offload_pt_weights=False, # Need the weights in memory for prefill-model export/compilation in the next step + retain_full_kv=True, +) + + +# Following command errors out by default, the user is supposed to run the printed command and provide the generated qpc path as prefill_qpc_path commenting out lines 55-68 +# prefill_qpc_path = "provide path here" +prefill_qpc_path = qeff_model.compile( + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + prefill_only=True, + enable_chunking=True, + use_onnx_subfunctions=True, +) + + +inputs = tokenizer(prompt, return_tensors="np", padding=True) +position_ids = inputs["attention_mask"].sum(1, keepdims=True) +generation_len = CTX_LEN - position_ids.max() +padded_len = inputs["input_ids"].shape[1] +num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float +padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len +inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) +inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) +inputs.pop("token_type_ids", None) +inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} +inputs.pop("past_key_values", None) +inputs = {k: v.detach().numpy() for k, v in inputs.items()} + + +decode_session = QAICInferenceSession(decode_qpc_path) +prefill_session = QAICInferenceSession(prefill_qpc_path) + +all_outputs = [] +for i in range(num_chunks): + chunk_inputs = inputs.copy() + chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + ins = time.time() + qpc_out = prefill_session.run(chunk_inputs) + print(f"time for this run={time.time() - ins}") + for i in range(config.num_hidden_layers): + inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] + inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] + +all_outputs.append(np.argmax(qpc_out["logits"])) +decode_inputs = { + "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1), + "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, +} +for i in range(config.num_hidden_layers): + decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] + decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] + +st = time.time() +decode_out = decode_session.run(decode_inputs) +print(f"time for first run of decode with KV as input = {time.time() - st} sec\n") 
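+# From here on, each decode step feeds the RetainedState KV outputs of the previous
+# run back in as past_key.{i} / past_value.{i}; only input_ids and position_ids change.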
+all_outputs.append(np.argmax(decode_out["logits"])) +pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1 +loop_decode_inputs = { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, +} + +for i in range(config.num_hidden_layers): + loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] + loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] + +st = time.time() +for i in range(generation_len - 2): + decode_out = decode_session.run(loop_decode_inputs) + all_outputs.append(np.argmax(decode_out["logits"])) + pos_id += 1 + for i in range(config.num_hidden_layers): + loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] + loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] + + loop_decode_inputs.update( + { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, + } + ) +ft = time.time() + +print(f"decode tok/sec={(generation_len - 2) / (ft - st)}") +print(f"input\n{prompt}\noutput\n{tokenizer.decode(all_outputs)}") diff --git a/pyproject.toml b/pyproject.toml index 8e179ab4a..77322d8df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,13 +20,13 @@ classifiers = [ requires-python = ">=3.8,<3.11" dependencies = [ "transformers==4.55.0", + "diffusers== 0.35.1", "huggingface-hub==0.34.0", "hf_transfer==0.1.9", - "peft==0.13.2", + "peft==0.17.0", "datasets==2.20.0", "fsspec==2023.6.0", "multidict==6.0.4", - "urllib3<2", "sentencepiece==0.2.0", "onnx==1.18.0", "onnxruntime==1.22", diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile index 134770638..8f95c1d98 100644 --- a/scripts/Jenkinsfile +++ b/scripts/Jenkinsfile @@ -22,6 +22,7 @@ pipeline { . preflight_qeff/bin/activate && pip install --upgrade pip setuptools && pip install .[test] && + pip install .[diffusers] && pip install junitparser pytest-xdist && pip install librosa==0.10.2 soundfile==0.13.1 && #packages needed to load example for whisper testing pip install --extra-index-url https://download.pytorch.org/whl/cpu timm==1.0.14 torchvision==0.22.0+cpu einops==0.8.1 && #packages to load VLMs @@ -41,7 +42,7 @@ pipeline { mkdir -p $PWD/Non_cli_qaic && export TOKENIZERS_PARALLELISM=false && export QEFF_HOME=$PWD/Non_cli_qaic && - pytest tests -m '(not cli) and (not on_qaic) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log1.xml && + pytest tests -m '(not cli) and (not on_qaic) and (not finetune)' --ignore tests/vllm -n 4 --junitxml=tests/tests_log1.xml && junitparser merge tests/tests_log1.xml tests/tests_log.xml && deactivate" ''' @@ -69,7 +70,7 @@ pipeline { } stage('QAIC MultiModal Tests') { steps { - timeout(time: 60, unit: 'MINUTES') { + timeout(time: 120, unit: 'MINUTES') { sh ''' sudo docker exec ${BUILD_TAG} bash -c " cd /efficient-transformers && @@ -86,7 +87,7 @@ pipeline { } stage('Inference Tests') { steps { - timeout(time: 60, unit: 'MINUTES') { + timeout(time: 120, unit: 'MINUTES') { sh ''' sudo docker exec ${BUILD_TAG} bash -c " #source /qnn_sdk/bin/envsetup.sh && @@ -162,7 +163,7 @@ pipeline { // } stage('Finetune CLI Tests') { steps { - timeout(time: 5, unit: 'MINUTES') { + timeout(time: 20, unit: 'MINUTES') { sh ''' sudo docker exec ${BUILD_TAG} bash -c " cd /efficient-transformers && diff --git a/tests/base/test_export_memory_offload.py b/tests/base/test_export_memory_offload.py index d1b7a4653..f63b18f1a 100644 --- a/tests/base/test_export_memory_offload.py +++ b/tests/base/test_export_memory_offload.py @@ 
-27,7 +27,7 @@ @pytest.fixture def tmp_cache(tmp_path, monkeypatch): - monkeypatch.setattr("QEfficient.utils._utils.QEFF_HOME", tmp_path) + monkeypatch.setattr("QEfficient.utils.export_utils.QEFF_HOME", tmp_path) yield tmp_path diff --git a/tests/diffusers/diffusers_utils.py b/tests/diffusers/diffusers_utils.py new file mode 100644 index 000000000..305116c03 --- /dev/null +++ b/tests/diffusers/diffusers_utils.py @@ -0,0 +1,175 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Common utilities for diffusion pipeline testing. +Provides essential functions for MAD validation, image validation +hash verification, and other testing utilities. +""" + +import os +from typing import Any, Dict, Tuple, Union + +import numpy as np +import torch +from PIL import Image + + +class DiffusersTestUtils: + """Essential utilities for diffusion pipeline testing""" + + @staticmethod + def validate_image_generation( + image: Image.Image, expected_size: Tuple[int, int], min_variance: float = 1.0 + ) -> Dict[str, Any]: + """ + Validate generated image properties. + Args: + image: Generated PIL Image + expected_size: Expected (width, height) tuple + min_variance: Minimum pixel variance to ensure image is not blank + + Returns: + Dict containing validation results + Raises: + AssertionError: If image validation fails + """ + # Basic image validation + assert isinstance(image, Image.Image), f"Expected PIL Image, got {type(image)}" + assert image.size == expected_size, f"Expected size {expected_size}, got {image.size}" + assert image.mode in ["RGB", "RGBA"], f"Unexpected image mode: {image.mode}" + + # Variance check (ensure image is not blank) + img_array = np.array(image) + image_variance = float(img_array.std()) + assert image_variance > min_variance, f"Generated image appears blank (variance: {image_variance:.2f})" + + return { + "size": image.size, + "mode": image.mode, + "variance": image_variance, + "mean_pixel_value": float(img_array.mean()), + "min_pixel": int(img_array.min()), + "max_pixel": int(img_array.max()), + "valid": True, + } + + @staticmethod + def check_file_exists(file_path: str, file_type: str = "file") -> bool: + """ + Check if file exists and log result. + Args: + file_path: Path to check + file_type: Description of file type for logging + Returns: + bool: True if file exists + """ + exists = os.path.exists(file_path) + status = "āœ…" if exists else "āŒ" + print(f"{status} {file_type}: {file_path}") + return exists + + @staticmethod + def print_test_header(title: str, config: Dict[str, Any]) -> None: + """ + Print formatted test header with configuration details. 
+ + Args: + title: Test title + config: Test configuration dictionary + """ + print(f"\n{'=' * 80}") + print(f"{title}") + print(f"{'=' * 80}") + + if "model_setup" in config: + setup = config["model_setup"] + for k, v in setup.items(): + print(f"{k} : {v}") + + if "functional_testing" in config: + func = config["functional_testing"] + print(f"Test Prompt: {func.get('test_prompt', 'N/A')}") + print(f"Inference Steps: {func.get('num_inference_steps', 'N/A')}") + print(f"Guidance Scale: {func.get('guidance_scale', 'N/A')}") + + print(f"{'=' * 80}") + + +class MADValidator: + """Specialized class for MAD validation - always enabled, always reports, always fails on exceed""" + + def __init__(self, tolerances: Dict[str, float] = None): + """ + Initialize MAD validator. + MAD validation is always enabled, always reports values, and always fails if tolerance is exceeded. + + Args: + tolerances: Dictionary of module_name -> tolerance mappings + """ + self.tolerances = tolerances + self.results = {} + + def calculate_mad( + self, tensor1: Union[torch.Tensor, np.ndarray], tensor2: Union[torch.Tensor, np.ndarray] + ) -> float: + """ + Calculate Max Absolute Deviation between two tensors. + + Args: + tensor1: First tensor (PyTorch or NumPy) + tensor2: Second tensor (PyTorch or NumPy) + + Returns: + float: Maximum absolute difference between tensors + """ + if isinstance(tensor1, torch.Tensor): + tensor1 = tensor1.detach().numpy() + if isinstance(tensor2, torch.Tensor): + tensor2 = tensor2.detach().numpy() + + return float(np.max(np.abs(tensor1 - tensor2))) + + def validate_module_mad( + self, + pytorch_output: Union[torch.Tensor, np.ndarray], + qaic_output: Union[torch.Tensor, np.ndarray], + module_name: str, + step_info: str = "", + ) -> bool: + """ + Validate MAD for a specific module. + Always validates, always reports, always fails if tolerance exceeded. 
+ + Args: + pytorch_output: PyTorch reference output + qaic_output: QAIC inference output + module_name: Name of the module + step_info: Additional step information for logging + + Returns: + bool: True if validation passed + + Raises: + AssertionError: If MAD exceeds tolerance + """ + mad_value = self.calculate_mad(pytorch_output, qaic_output) + + # Always report MAD value + step_str = f" {step_info}" if step_info else "" + print(f"šŸ” {module_name.upper()} MAD{step_str}: {mad_value:.8f}") + + # Always validate - fail if tolerance exceeded + tolerance = self.tolerances.get(module_name, 1e-2) + if mad_value > tolerance: + raise AssertionError(f"{module_name} MAD {mad_value:.6f} exceeds tolerance {tolerance:.6f}") + + # Store result + if module_name not in self.results: + self.results[module_name] = [] + self.results[module_name].append({"mad": mad_value, "step_info": step_info, "tolerance": tolerance}) + return True diff --git a/tests/diffusers/flux_test_config.json b/tests/diffusers/flux_test_config.json new file mode 100644 index 000000000..7d0c17d55 --- /dev/null +++ b/tests/diffusers/flux_test_config.json @@ -0,0 +1,123 @@ +{ + "model_setup": { + "height": 256, + "width": 256, + "num_transformer_layers": 2, + "num_single_layers": 2, + "use_onnx_subfunctions": false + }, + "mad_validation": { + "tolerances": { + "clip_text_encoder": 0.1, + "t5_text_encoder": 5.5, + "transformer": 2.0, + "vae_decoder": 1.0 + } + }, + "pipeline_params": { + "test_prompt": "A cat holding a sign that says hello world", + "num_inference_steps": 2, + "guidance_scale": 0.0, + "max_sequence_length": 256, + "validate_gen_img": true, + "min_image_variance": 1.0, + "custom_config_path": null + }, + "validation_checks": { + "image_generation": true, + "onnx_export": true, + "compilation": true + }, + "modules": + { + "text_encoder": + { + "specializations":{ + "batch_size": 1, + "seq_len": 77 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16 + }, + "execute": + { + "device_ids": null + } + + }, + "text_encoder_2": + { + "specializations": + { + "batch_size": 1, + "seq_len": 256 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16 + }, + "execute": + { + "device_ids": null + } + }, + "transformer": + { + "specializations": + { + "batch_size": 1, + "seq_len": 256, + "steps": 1 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": true, + "convert_to_fp16": true, + "aic_num_cores": 16, + "mos": 1, + "mdts-mos": 1, + "aic-enable-depth-first": true + }, + "execute": + { + "device_ids": null + } + }, + "vae_decoder": + { + "specializations": + { + "batch_size": 1, + "channels": 16 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16 + }, + "execute": + { + "device_ids": null + } + } + } + +} diff --git a/tests/diffusers/test_flux.py b/tests/diffusers/test_flux.py new file mode 100644 index 000000000..6f4396a20 --- /dev/null +++ b/tests/diffusers/test_flux.py @@ -0,0 +1,448 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import os +import time +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import pytest +import torch +from diffusers import FluxPipeline +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps + +from QEfficient import QEffFluxPipeline +from QEfficient.diffusers.pipelines.pipeline_utils import ( + ModulePerf, + QEffPipelineOutput, + set_module_device_ids, +) +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.utils._utils import load_json +from tests.diffusers.diffusers_utils import DiffusersTestUtils, MADValidator + +# Test Configuration for 256x256 resolution with 2 layers # update mad tolerance +CONFIG_PATH = "tests/diffusers/flux_test_config.json" +INITIAL_TEST_CONFIG = load_json(CONFIG_PATH) + + +def flux_pipeline_call_with_mad_validation( + pipeline, + pytorch_pipeline, + height: int = 256, + width: int = 256, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + true_cfg_scale: float = 1.0, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + custom_config_path: Optional[str] = None, + parallel_compile: bool = False, + mad_tolerances: Dict[str, float] = None, +): + """ + Pipeline call function that replicates the exact flow of pipeline_flux.py.__call__() + while adding comprehensive MAD validation at each step. + + This function follows the EXACT same structure as QEffFluxPipeline.__call__() + but adds MAD validation hooks throughout the process. 
+ """ + # Initialize MAD validator + mad_validator = MADValidator(tolerances=mad_tolerances) + + device = "cpu" + + # Step 1: Load configuration, compile models + pipeline.compile(compile_config=custom_config_path, parallel=parallel_compile, height=height, width=width) + + # Set device IDs for all modules based on configuration + set_module_device_ids(pipeline) + + # Validate all inputs + pipeline.model.check_inputs( + prompt, + prompt_2, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + # Set pipeline attributes + pipeline._guidance_scale = guidance_scale + pipeline._interrupt = False + batch_size = INITIAL_TEST_CONFIG["modules"]["transformer"]["specializations"]["batch_size"] + + # Step 3: Encode prompts with both text encoders + # Use pipeline's encode_prompt method + (t5_qaic_prompt_embeds, clip_qaic_pooled_prompt_embeds, text_ids, text_encoder_perf) = pipeline.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + (t5_torch_prompt_embeds, clip_torch_pooled_prompt_embeds, text_ids) = pytorch_pipeline.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + # Deactivate text encoder qpc sessions + pipeline.text_encoder.qpc_session.deactivate() + pipeline.text_encoder_2.qpc_session.deactivate() + + # MAD Validation for Text Encoders + print("šŸ” Performing MAD validation for text encoders...") + mad_validator.validate_module_mad( + clip_qaic_pooled_prompt_embeds, clip_torch_pooled_prompt_embeds, module_name="clip_text_encoder" + ) + mad_validator.validate_module_mad(t5_torch_prompt_embeds, t5_qaic_prompt_embeds, "t5_text_encoder") + + # Step 4: Prepare timesteps for denoising + timesteps, num_inference_steps = retrieve_timesteps(pipeline.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * pipeline.scheduler.order, 0) + pipeline._num_timesteps = len(timesteps) + + # Step 5: Prepare initial latents + num_channels_latents = pipeline.transformer.model.config.in_channels // 4 + latents, latent_image_ids = pipeline.model.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + t5_qaic_prompt_embeds.dtype, + device, + generator, + latents, + ) + + # Step 6: Initialize transformer inference session + if pipeline.transformer.qpc_session is None: + pipeline.transformer.qpc_session = QAICInferenceSession( + str(pipeline.transformer.qpc_path), device_ids=pipeline.transformer.device_ids + ) + + # Calculate compressed latent dimension (cl) for transformer buffer allocation + from QEfficient.diffusers.pipelines.pipeline_utils import calculate_compressed_latent_dimension + + cl, _, _ = calculate_compressed_latent_dimension(height, width, pipeline.model.vae_scale_factor) + + # Allocate output buffer for transformer + output_buffer = { + "output": np.random.rand(batch_size, cl, 
pipeline.transformer.model.config.in_channels).astype(np.float32), + } + pipeline.transformer.qpc_session.set_buffers(output_buffer) + + transformer_perf = [] + pipeline.scheduler.set_begin_index(0) + + # Step 7: Denoising loop + with pipeline.model.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if pipeline._interrupt: + continue + + # Prepare timestep embedding + timestep = t.expand(latents.shape[0]).to(latents.dtype) + temb = pipeline.transformer.model.time_text_embed(timestep, clip_qaic_pooled_prompt_embeds) + + # Compute AdaLN embeddings for dual transformer blocks + adaln_emb = [] + for block_idx in range(len(pipeline.transformer.model.transformer_blocks)): + block = pipeline.transformer.model.transformer_blocks[block_idx] + f1 = block.norm1.linear(block.norm1.silu(temb)).chunk(6, dim=1) + f2 = block.norm1_context.linear(block.norm1_context.silu(temb)).chunk(6, dim=1) + adaln_emb.append(torch.cat(list(f1) + list(f2))) + adaln_dual_emb = torch.stack(adaln_emb) + + # Compute AdaLN embeddings for single transformer blocks + adaln_emb = [] + for block_idx in range(len(pipeline.transformer.model.single_transformer_blocks)): + block = pipeline.transformer.model.single_transformer_blocks[block_idx] + f1 = block.norm.linear(block.norm.silu(temb)).chunk(3, dim=1) + adaln_emb.append(torch.cat(list(f1))) + adaln_single_emb = torch.stack(adaln_emb) + + # Compute output AdaLN embedding + temp = pipeline.transformer.model.norm_out + adaln_out = temp.linear(temp.silu(temb)) + + # Normalize timestep to [0, 1] range + timestep = timestep / 1000 + + # Prepare all inputs for transformer inference + inputs_aic = { + "hidden_states": latents.detach().numpy(), + "encoder_hidden_states": t5_qaic_prompt_embeds.detach().numpy(), + "pooled_projections": clip_qaic_pooled_prompt_embeds.detach().numpy(), + "timestep": timestep.detach().numpy(), + "img_ids": latent_image_ids.detach().numpy(), + "txt_ids": text_ids.detach().numpy(), + "adaln_emb": adaln_dual_emb.detach().numpy(), + "adaln_single_emb": adaln_single_emb.detach().numpy(), + "adaln_out": adaln_out.detach().numpy(), + } + + # MAD Validation for Transformer - PyTorch reference inference + noise_pred_torch = pytorch_pipeline.transformer( + hidden_states=latents, + encoder_hidden_states=t5_torch_prompt_embeds, + pooled_projections=clip_torch_pooled_prompt_embeds, + timestep=torch.tensor(timestep), + img_ids=latent_image_ids, + txt_ids=text_ids, + return_dict=False, + )[0] + + # Run transformer inference and measure time + start_transformer_step_time = time.time() + outputs = pipeline.transformer.qpc_session.run(inputs_aic) + end_transformer_step_time = time.time() + transformer_perf.append(end_transformer_step_time - start_transformer_step_time) + + noise_pred = torch.from_numpy(outputs["output"]) + + # Transformer MAD validation + mad_validator.validate_module_mad( + noise_pred_torch.detach().cpu().numpy(), + outputs["output"], + "transformer", + f"step {i} (t={t.item():.1f})", + ) + + # Update latents using scheduler + latents_dtype = latents.dtype + latents = pipeline.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # Handle dtype mismatch + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + + # Update progress bar + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): + progress_bar.update() + + # Step 8: Decode latents to images + if output_type == "latent": + image = latents 
+ vae_decode_perf = 0.0 # No VAE decoding for latent output + else: + # Unpack and denormalize latents + latents = pipeline.model._unpack_latents(latents, height, width, pipeline.model.vae_scale_factor) + + # Denormalize latents + latents = (latents / pipeline.vae_decode.model.scaling_factor) + pipeline.vae_decode.model.shift_factor + # Initialize VAE decoder inference session + if pipeline.vae_decode.qpc_session is None: + pipeline.vae_decode.qpc_session = QAICInferenceSession( + str(pipeline.vae_decode.qpc_path), device_ids=pipeline.vae_decode.device_ids + ) + + # Allocate output buffer for VAE decoder + output_buffer = {"sample": np.random.rand(batch_size, 3, height, width).astype(np.float32)} + pipeline.vae_decode.qpc_session.set_buffers(output_buffer) + + # MAD Validation for VAE + # PyTorch reference inference + image_torch = pytorch_pipeline.vae.decode(latents, return_dict=False)[0] + + # Run VAE decoder inference and measure time + inputs = {"latent_sample": latents.numpy()} + start_decode_time = time.time() + image = pipeline.vae_decode.qpc_session.run(inputs) + end_decode_time = time.time() + vae_decode_perf = end_decode_time - start_decode_time + + # VAE MAD validation + mad_validator.validate_module_mad(image_torch.detach().cpu().numpy(), image["sample"], "vae_decoder") + + # Post-process image + image_tensor = torch.from_numpy(image["sample"]) + image = pipeline.model.image_processor.postprocess(image_tensor, output_type=output_type) + + # Build performance metrics + perf_metrics = [ + ModulePerf(module_name="text_encoder", perf=text_encoder_perf[0]), + ModulePerf(module_name="text_encoder_2", perf=text_encoder_perf[1]), + ModulePerf(module_name="transformer", perf=transformer_perf), + ModulePerf(module_name="vae_decoder", perf=vae_decode_perf), + ] + + return QEffPipelineOutput( + pipeline_module=perf_metrics, + images=image, + ) + + +@pytest.fixture(scope="session") +def flux_pipeline(): + """Setup compiled Flux pipeline for testing""" + config = INITIAL_TEST_CONFIG["model_setup"] + + pipeline = QEffFluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", + use_onnx_subfunctions=config["use_onnx_subfunctions"], + ) + + # Reduce to 2 layers for testing + original_blocks = pipeline.transformer.model.transformer_blocks + org_single_blocks = pipeline.transformer.model.single_transformer_blocks + + pipeline.transformer.model.config["num_layers"] = config["num_transformer_layers"] + pipeline.transformer.model.config["num_single_layers"] = config["num_single_layers"] + pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList( + [original_blocks[i] for i in range(0, pipeline.transformer.model.config["num_layers"])] + ) + pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList( + [org_single_blocks[i] for i in range(0, pipeline.transformer.model.config["num_single_layers"])] + ) + + ### Pytorch pipeline + pytorch_pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") + original_blocks_pt = pytorch_pipeline.transformer.transformer_blocks + org_single_blocks_pt = pytorch_pipeline.transformer.single_transformer_blocks + pytorch_pipeline.transformer.transformer_blocks = torch.nn.ModuleList( + [original_blocks_pt[i] for i in range(0, pipeline.transformer.model.config["num_layers"])] + ) + pytorch_pipeline.transformer.single_transformer_blocks = torch.nn.ModuleList( + [org_single_blocks_pt[i] for i in range(0, pipeline.transformer.model.config["num_single_layers"])] + ) + return pipeline, pytorch_pipeline + + 
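The helper above calls `mad_validator.validate_module_mad(...)` after each module, but the validator itself comes from the shared diffusers test utilities and is not part of this diff. A minimal check of the same shape is sketched below; the name `check_mad`, the 0.05 tolerance, and the choice of *mean* (rather than max) absolute deviation are all assumptions for illustration, not the library's implementation.

import numpy as np

def check_mad(reference: np.ndarray, candidate: np.ndarray, tolerance: float, label: str) -> float:
    """Compare a PyTorch reference output against the QAIC output via mean absolute deviation."""
    # Cast to float64 so the deviation itself is not limited by fp32 rounding.
    mad = float(np.mean(np.abs(reference.astype(np.float64) - candidate.astype(np.float64))))
    assert mad <= tolerance, f"{label}: MAD {mad:.6f} exceeds tolerance {tolerance:.6f}"
    return mad

# Example use, mirroring the per-step transformer check above (tolerance is illustrative):
# check_mad(noise_pred_torch.detach().cpu().numpy(), outputs["output"], 0.05, "transformer step 0")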
+@pytest.mark.diffusion_models +@pytest.mark.on_qaic +def test_flux_pipeline(flux_pipeline): + """ + Comprehensive Flux pipeline test that follows the exact same flow as pipeline_flux.py: + - 256x256 resolution - 2 transformer layers + - MAD validation + - Functional image generation test + - Export/compilation checks + - Returns QEffPipelineOutput with performance metrics + """ + pipeline, pytorch_pipeline = flux_pipeline + config = INITIAL_TEST_CONFIG + + # Print test header + DiffusersTestUtils.print_test_header( + f"FLUX PIPELINE TEST - {config['model_setup']['height']}x{config['model_setup']['width']} Resolution, {config['model_setup']['num_transformer_layers']} Layers", + config, + ) + + # Test parameters + test_prompt = config["pipeline_params"]["test_prompt"] + num_inference_steps = config["pipeline_params"]["num_inference_steps"] + guidance_scale = config["pipeline_params"]["guidance_scale"] + max_sequence_length = config["pipeline_params"]["max_sequence_length"] + + # Generate with MAD validation + generator = torch.manual_seed(42) + start_time = time.time() + + try: + # Run the pipeline with integrated MAD validation (follows exact pipeline flow) + result = flux_pipeline_call_with_mad_validation( + pipeline, + pytorch_pipeline, + height=config["model_setup"]["height"], + width=config["model_setup"]["width"], + prompt=test_prompt, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + max_sequence_length=max_sequence_length, + custom_config_path=CONFIG_PATH, + generator=generator, + mad_tolerances=config["mad_validation"]["tolerances"], + parallel_compile=True, + return_dict=True, + ) + + execution_time = time.time() - start_time + + # Validate image generation + if config["pipeline_params"]["validate_gen_img"]: + assert result is not None, "Pipeline returned None" + assert hasattr(result, "images"), "Result missing 'images' attribute" + assert len(result.images) > 0, "No images generated" + + generated_image = result.images[0] + expected_size = (config["model_setup"]["height"], config["model_setup"]["width"]) + # Validate image properties using utilities + image_validation = DiffusersTestUtils.validate_image_generation( + generated_image, expected_size, config["pipeline_params"]["min_image_variance"] + ) + + print("\nāœ… IMAGE VALIDATION PASSED") + print(f" - Size: {image_validation['size']}") + print(f" - Mode: {image_validation['mode']}") + print(f" - Variance: {image_validation['variance']:.2f}") + print(f" - Mean pixel value: {image_validation['mean_pixel_value']:.2f}") + file_path = "test_flux_256x256_2layers.png" + # Save test image + generated_image.save(file_path) + + if os.path.exists(file_path): + print(f"Image saved successfully at: {file_path}") + else: + print("Image was not saved.") + + if config["validation_checks"]["onnx_export"]: + # Check if ONNX files exist (basic check) + print("\nšŸ” ONNX Export Validation:") + for module_name in ["text_encoder", "text_encoder_2", "transformer", "vae_decode"]: + module_obj = getattr(pipeline, module_name, None) + if module_obj and hasattr(module_obj, "onnx_path") and module_obj.onnx_path: + DiffusersTestUtils.check_file_exists(str(module_obj.onnx_path), f"{module_name} ONNX") + + if config["validation_checks"]["compilation"]: + # Check if QPC files exist (basic check) + print("\nšŸ” Compilation Validation:") + for module_name in ["text_encoder", "text_encoder_2", "transformer", "vae_decode"]: + module_obj = getattr(pipeline, module_name, None) + if module_obj and hasattr(module_obj, "qpc_path") and 
module_obj.qpc_path: + DiffusersTestUtils.check_file_exists(str(module_obj.qpc_path), f"{module_name} QPC") + + # Print test summary using utilities + print(f"\nTotal execution time: {execution_time:.4f}s") + except Exception as e: + print(f"\nTEST FAILED: {e}") + raise + + +if __name__ == "__main__": + # This allows running the test file directly for debugging + pytest.main([__file__, "-v", "-s", "-m", "flux"]) +# pytest tests/diffusers/test_flux.py -m flux -v -s --tb=short diff --git a/tests/peft/lora/test_lora_model.py b/tests/peft/lora/test_lora_model.py index 00a4216b7..46b33c60b 100644 --- a/tests/peft/lora/test_lora_model.py +++ b/tests/peft/lora/test_lora_model.py @@ -222,7 +222,7 @@ def test_auto_lora_model_for_causal_lm_noncb_export_compile_generate( # export start = perf_counter() - qeff_model.export(export_dir=tmp_path) + onnx_path = qeff_model.export(export_dir=tmp_path) end = perf_counter() export_time_0 = end - start model_path = tmp_path.with_name(tmp_path.name + "-" + qeff_model.export_hash) @@ -237,7 +237,7 @@ def test_auto_lora_model_for_causal_lm_noncb_export_compile_generate( assert export_time_1 < export_time_0 # test compile - qeff_model.compile(prefill_seq_len=32, ctx_len=64) + qeff_model.compile(onnx_path=onnx_path, prefill_seq_len=32, ctx_len=64) assert Path(qeff_model.qpc_path).is_dir() assert os.path.isfile(os.path.join(os.path.dirname(qeff_model.qpc_path), "qconfig.json")) diff --git a/tests/peft/test_peft_model.py b/tests/peft/test_peft_model.py index cc94467db..c3bb2f140 100644 --- a/tests/peft/test_peft_model.py +++ b/tests/peft/test_peft_model.py @@ -178,9 +178,9 @@ def test_auto_peft_model_for_causal_lm_activate_invalid(base_config, adapter_con def test_auto_peft_model_for_causal_lm_compile_generate(base_config, adapter_config, batch_size, tmp_path): _, lora_model = create_peft_model(base_config, adapter_config) qeff_model = QEffAutoPeftModelForCausalLM(lora_model) - qeff_model.export(tmp_path) + onnx_path = qeff_model.export(tmp_path) start = perf_counter() - qeff_model.compile(batch_size=batch_size, prefill_seq_len=32, ctx_len=128) + qeff_model.compile(onnx_path=onnx_path, batch_size=batch_size, prefill_seq_len=32, ctx_len=128) end = perf_counter() compile_time_0 = end - start @@ -197,7 +197,7 @@ def test_auto_peft_model_for_causal_lm_compile_generate(base_config, adapter_con ) start = perf_counter() - qeff_model.compile(batch_size=batch_size, prefill_seq_len=32, ctx_len=128) + qeff_model.compile(onnx_path=onnx_path, batch_size=batch_size, prefill_seq_len=32, ctx_len=128) end = perf_counter() compile_time_1 = end - start assert compile_time_1 < 0.01 * compile_time_0 diff --git a/tests/transformers/models/test_disagg_mode.py b/tests/transformers/models/test_disagg_mode.py new file mode 100644 index 000000000..6358940df --- /dev/null +++ b/tests/transformers/models/test_disagg_mode.py @@ -0,0 +1,192 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import time + +import numpy as np +import pytest +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, HybridCache + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.transformers.quantizers import replace_transformers_quantizers, undo_transformers_quantizers + +model_id = "openai/gpt-oss-120b" # weights are not required to convert to fp32 + +prompt2 = """ +Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures. + +As Alex flipped through the pages, he discovered a map that led to a hidden treasure. Excited by the prospect of a real-life treasure hunt, Alex decided to embark on a thrilling journey. He packed his backpack with snacks, a flashlight, and a compass, and set off into the unknown. + +The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location. +""" +prompt1 = "Once upon a time" + +prompts = [prompt1, prompt2] + + +@pytest.mark.on_qaic +@pytest.mark.parametrize("model_id", [model_id]) +@pytest.mark.parametrize("prompt", prompts) +def test_disagg_mode_prefill(model_id, prompt): + # Run prefill + tokenizer = AutoTokenizer.from_pretrained(model_id) + PREFILL_SEQ_LEN = 256 + CTX_LEN = 256 + inputs = tokenizer(prompt, return_tensors="np", padding=True) + padded_len = inputs["input_ids"].shape[1] + num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float + padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len + + replace_transformers_quantizers() + model = AutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) + config = model.config + inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) + inputs.pop("token_type_ids", None) + inputs = {k: torch.from_numpy(v).to(model.device) for k, v in inputs.items()} + cache = HybridCache(config=config, batch_size=1, max_cache_len=CTX_LEN) + ins = tokenizer(prompt, return_tensors="pt") + out = model(**ins, past_key_values=cache) + + undo_transformers_quantizers() + + qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) + qeff_model.prefill(True) + config = qeff_model.model.config + inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) + inputs.pop("token_type_ids", None) + inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} + past_key_values = [] + for i in range(config.num_hidden_layers): + cache_len = 128 if i % 2 == 0 else PREFILL_SEQ_LEN + pad_shape = (1, 8, cache_len, 64) + past_key = torch.zeros((pad_shape), dtype=torch.float32) + past_value = torch.zeros((pad_shape), dtype=torch.float32) + pkv = (past_key, past_value) + past_key_values.append(pkv) + inputs["past_key_values"] = past_key_values + + qeff_out = qeff_model.model(**inputs) + + # 
Check our pytorch implementation + assert (qeff_out.logits - out.logits[:, -1, :]).abs().max() < 1e-4 + + prefill_qpc_path = qeff_model.compile( + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=False, + mxint8_kv_cache=False, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + prefill_only=True, + ) + + prefill_session = QAICInferenceSession(prefill_qpc_path) + logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32) + prefill_session.set_buffers({"logits": logits_out_placeholder}) + inputs.pop("past_key_values") + inputs = {k: v.detach().numpy() for k, v in inputs.items()} + st = time.time() + qpc_out = prefill_session.run(inputs) + print(f"time for prefill_run={time.time() - st} sec\n") + del prefill_session + # Check QAIC output isclose with QEFF pytorch output + assert (torch.from_numpy(qpc_out["logits"]) - qeff_out.logits).abs().max() < 5e-2 + + +@pytest.mark.skip(reason="no way of currently testing this without the assert sdk") +@pytest.mark.on_qaic +@pytest.mark.parametrize("model_id", [model_id]) +@pytest.mark.parametrize("prompt", prompts) +def test_disagg_mode_prefill_chunked(model_id, prompt): + # Run prefill + tokenizer = AutoTokenizer.from_pretrained(model_id) + PREFILL_SEQ_LEN = 128 + CTX_LEN = 128 * 3 + inputs = tokenizer(prompt, return_tensors="np", padding=True) + padded_len = inputs["input_ids"].shape[1] + num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float + padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len + + replace_transformers_quantizers() + model = AutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) + config = model.config + inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) + inputs.pop("token_type_ids", None) + inputs = {k: torch.from_numpy(v).to(model.device) for k, v in inputs.items()} + cache = HybridCache(config=config, batch_size=1, max_cache_len=CTX_LEN) + ins = tokenizer(prompt, return_tensors="pt") + out = model(**ins, past_key_values=cache) + + undo_transformers_quantizers() + + qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) + qeff_model.prefill(True, enable_chunking=True) + config = qeff_model.model.config + inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) + inputs.pop("token_type_ids", None) + inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} + past_key_values = [] + for i in range(config.num_hidden_layers): + cache_len = CTX_LEN + pad_shape = (1, 8, cache_len, 64) + past_key = torch.zeros((pad_shape), dtype=torch.float32) + past_value = torch.zeros((pad_shape), dtype=torch.float32) + pkv = (past_key, past_value) + past_key_values.append(pkv) + inputs["past_key_values"] = past_key_values + + for i in range(num_chunks): + chunk_inputs = inputs.copy() + chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + + qeff_out = qeff_model.model(**chunk_inputs) + inputs["past_key_values"] = qeff_out["past_key_values"] + + # Check our pytorch implementation + assert (qeff_out.logits - out.logits[:, -1, :]).abs().max() < 1e-4 + + prefill_qpc_path 
= qeff_model.compile( + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=False, + mxint8_kv_cache=False, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + prefill_only=True, + enable_chunking=True, + ) + prefill_session = QAICInferenceSession(prefill_qpc_path) + prefill_session.skip_buffers( + [x for x in prefill_session.input_names + prefill_session.output_names if x.startswith("past_")] + ) + logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32) + prefill_session.set_buffers({"logits": logits_out_placeholder}) + inputs.pop("past_key_values") + inputs = {k: v.detach().numpy() for k, v in inputs.items()} + st = time.time() + for i in range(num_chunks): + chunk_inputs = inputs.copy() + chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + qpc_out = prefill_session.run(chunk_inputs) + print(f"time for prefill_run={time.time() - st} sec\n") + del prefill_session + # Check QAIC output isclose with QEFF pytorch output + assert (torch.from_numpy(qpc_out["logits"]) - qeff_out.logits).abs().max() < 8e-2 diff --git a/tests/transformers/test_causal_lm.py b/tests/transformers/test_causal_lm.py index 0810ac6ba..72477d56a 100644 --- a/tests/transformers/test_causal_lm.py +++ b/tests/transformers/test_causal_lm.py @@ -14,10 +14,11 @@ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM +from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export from QEfficient.utils import constants, get_padding_shape_from_config from QEfficient.utils.hash_utils import hash_dict_params -configs = [ +test_configs = [ # name, max_position_embeddings, num_hidden_layers, num_attention_heads, hidden_size, intermediate_size, vocab_size, additional_params ("gpt2", 256, 2, 4, 128, 512, 127, {}), ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), @@ -36,30 +37,43 @@ ("gpt_oss", 256, 3, 4, 128, 512, 127, {"num_key_value_heads": 2}), ] -configs = [ - AutoConfig.for_model( - model_name, - max_position_embeddings=max_position_embeddings, - num_hidden_layers=num_hidden_layers, - num_attention_heads=num_attention_heads, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - vocab_size=vocab_size, - **additional_params, - ) - for ( - model_name, - max_position_embeddings, - num_hidden_layers, - num_attention_heads, - hidden_size, - intermediate_size, - vocab_size, - additional_params, - ) in configs +test_prefill_only_specialized_models_configs = [ + ("gpt_oss", 256, 2, 2, 32, 32, 127, {"num_key_value_heads": 2}), ] + + +def get_auto_config_from_test_config(configs): + auto_configs = [ + AutoConfig.for_model( + model_name, + max_position_embeddings=max_position_embeddings, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + vocab_size=vocab_size, + **additional_params, + ) + for ( + model_name, + max_position_embeddings, + num_hidden_layers, + num_attention_heads, + hidden_size, + intermediate_size, + vocab_size, + additional_params, + ) in configs + ] + return auto_configs + + +configs = get_auto_config_from_test_config(test_configs) config_ids = [x.model_type for x in configs] +prefill_only_configs = 
get_auto_config_from_test_config(test_prefill_only_specialized_models_configs) +prefill_only_config_ids = [x.model_type for x in prefill_only_configs] + model_kwargs = {"attn_implementation": "eager"} @@ -144,20 +158,21 @@ def test_causal_lm_export_and_hash(config, cb, tmp_path): @pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +@pytest.mark.parametrize("subfunc", [False, True], ids=["False", "True"]) @pytest.mark.parametrize("config", configs, ids=config_ids) -def test_causal_lm_hash_creation(config, cb, tmp_path): +def test_causal_lm_hash_creation(config, cb, subfunc, tmp_path): model = AutoModelForCausalLM.from_config(config, **model_kwargs) qeff_model = QEFFAutoModelForCausalLM(model, cb) - qeff_model.export(tmp_path) + qeff_model.export(tmp_path, use_onnx_subfunctions=subfunc) hash_params = {} hash_params["config"] = qeff_model.model.config.to_diff_dict() hash_params["peft_config"] = None hash_params["applied_transform_names"] = qeff_model._transform_names() hash_params["qeff_auto_class"] = qeff_model.__class__.__name__ + hash_params["max_seq_len_cached"] = None hash_params["qaic_config"] = None # Create parameters separately for hash creation - bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS @@ -190,12 +205,12 @@ def test_causal_lm_hash_creation(config, cb, tmp_path): ) output_names = [] output_names.append("logits") - + onnx_out_name_suffix = "InternalRetainedState" if subfunc else "RetainedState" for i in range(qeff_model.num_layers): pkv_dynamic_axes[i][0] = "full_batch_size" if qeff_model.continuous_batching else "batch_size" for kv in ["key", "value"]: dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i] - output_names.append(f"past_{kv}.{i}_RetainedState") + output_names.append(f"past_{kv}.{i}_{onnx_out_name_suffix}") if qeff_model.continuous_batching: dynamic_axes["batch_index"] = {0: "batch_size"} @@ -204,14 +219,35 @@ def test_causal_lm_hash_creation(config, cb, tmp_path): export_params["output_names"] = output_names export_params["dynamic_axes"] = dynamic_axes hash_params["export_params"] = export_params + if subfunc: + hash_params["export_modules_as_functions"] = get_decoder_layer_classes_for_export(qeff_model.model) + manual_hash = hash_dict_params(hash_params) assert manual_hash == qeff_model.export_hash +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +@pytest.mark.parametrize("config", prefill_only_configs, ids=prefill_only_config_ids) +def test_prefill_only_specialized_models(config, cb, tmp_path): + model = AutoModelForCausalLM.from_config(config, **model_kwargs) + qeff_model = QEFFAutoModelForCausalLM(model, cb) + if cb: + with pytest.raises(NotImplementedError): + qeff_model.export(tmp_path, prefill_only=True, offload_pt_weights=False) + else: + with pytest.raises(ValueError): + qeff_model.export(tmp_path, prefill_only=True, offload_pt_weights=False) + qeff_model.export(tmp_path, prefill_only=True, prefill_seq_len=256, offload_pt_weights=False) + first_export_hash = qeff_model.export_hash + qeff_model.export(tmp_path, prefill_only=False, offload_pt_weights=False) + second_export_hash = qeff_model.export_hash + assert first_export_hash != second_export_hash + + @pytest.fixture def tmp_cache(tmp_path, monkeypatch): - monkeypatch.setattr("QEfficient.utils._utils.QEFF_HOME", tmp_path) + monkeypatch.setattr("QEfficient.utils.export_utils.QEFF_HOME", tmp_path) yield tmp_path diff --git a/tests/transformers/test_speech_seq2seq.py 
b/tests/transformers/test_speech_seq2seq.py index 59281b73b..bc53cb539 100644 --- a/tests/transformers/test_speech_seq2seq.py +++ b/tests/transformers/test_speech_seq2seq.py @@ -141,7 +141,7 @@ def test_seq2seq_hash_creation(config, tmp_path): @pytest.fixture def tmp_cache(tmp_path, monkeypatch): - monkeypatch.setattr("QEfficient.utils._utils.QEFF_HOME", tmp_path) + monkeypatch.setattr("QEfficient.utils.export_utils.QEFF_HOME", tmp_path) yield tmp_path diff --git a/tests/transformers/test_subfunction.py b/tests/transformers/test_subfunction.py index 36cfc0ce5..6183e1282 100644 --- a/tests/transformers/test_subfunction.py +++ b/tests/transformers/test_subfunction.py @@ -44,6 +44,7 @@ config_ids = [x.model_type for x in configs] +@pytest.mark.on_qaic @pytest.mark.parametrize("config", configs, ids=config_ids) def test_subfunction_vs_nonsubfunction(config, tmp_path): tokenizer = AutoTokenizer.from_pretrained(config.model_type) diff --git a/tests/utils/test_hash_utils.py b/tests/utils/test_hash_utils.py index fefa73973..b7a5495c6 100644 --- a/tests/utils/test_hash_utils.py +++ b/tests/utils/test_hash_utils.py @@ -41,7 +41,7 @@ def test_to_hashable_float_nan(value): def test_json_serializable(): # Test with a set - assert json_serializable({1, 2, 3}) == [1, 2, 3] + assert json_serializable({1, 2, 3}) == ["1", "2", "3"] # Test with an unsupported type with pytest.raises(TypeError): json_serializable({1, 2, 3, {4, 5}})
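The final hunk updates the expected result of `json_serializable({1, 2, 3})` from `[1, 2, 3]` to `["1", "2", "3"]`, i.e. sets are now normalized to a sorted list of strings so the JSON used for hashing is deterministic across element types. The sketch below is inferred only from that test expectation; the real helper lives in `QEfficient.utils.hash_utils`, and this signature and body are assumptions, not the library's code.

from typing import Any, List

def json_serializable(obj: Any) -> List[str]:
    # Intended for use as json.dumps(..., default=json_serializable): called only
    # for objects the encoder cannot handle natively.
    if isinstance(obj, set):
        # Stringify before sorting so mixed-type sets still order deterministically.
        return sorted(str(item) for item in obj)
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

# json_serializable({1, 2, 3}) -> ["1", "2", "3"], matching the updated assertion.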