Commit ac48615

Merge remote-tracking branch 'origin/on-device-sampling-vlm' into HEAD
2 parents: 3fcd9eb + 2533262

File tree: 81 files changed (+9011, -526 lines)


QEfficient/__init__.py

Lines changed: 17 additions & 8 deletions
@@ -6,7 +6,17 @@
 # -----------------------------------------------------------------------------
 
 import os
-import warnings
+
+# ----------------------------------------------------------------------------- #
+# For faster downloads via hf_transfer
+# This code is put above import statements as this needs to be executed before
+# hf_transfer is imported (will happen on line 15 via leading imports)
+os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+# DO NOT ADD ANY CODE ABOVE THIS LINE
+# Please contact maintainers if you must edit this file above this line.
+# ----------------------------------------------------------------------------- #
+# Placeholder for all non-transformer models registered in QEfficient
+import warnings  # noqa: I001
 
 import QEfficient.utils.model_registery  # noqa: F401
 from QEfficient.base import (
@@ -18,13 +28,18 @@
     QEFFCommonLoader,
 )
 from QEfficient.compile.compile_helper import compile
+from QEfficient.diffusers.pipelines.flux.pipeline_flux import QEffFluxPipeline
 from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
 from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
 from QEfficient.peft import QEffAutoPeftModelForCausalLM
 from QEfficient.transformers.transform import transform
 from QEfficient.utils import custom_format_warning
 from QEfficient.utils.logging_utils import logger
 
+# custom warning for the better logging experience
+warnings.formatwarning = custom_format_warning
+
+
 # Users can use QEfficient.export for exporting models to ONNX
 export = qualcomm_efficient_converter
 __all__ = [
@@ -39,15 +54,9 @@
     "QEFFAutoModelForImageTextToText",
     "QEFFAutoModelForSpeechSeq2Seq",
     "QEFFCommonLoader",
+    "QEffFluxPipeline",
 ]
-# For faster downloads via hf_transfer
-# This code is put above import statements as this needs to be executed before
-# hf_transfer is imported (will happen on line 15 via leading imports)
-os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
-# Placeholder for all non-transformer models registered in QEfficient
 
-# custom warning for the better logging experience
-warnings.formatwarning = custom_format_warning
 
 # Conditionally import QAIC-related modules if the SDK is installed
 __version__ = "0.0.1.dev0"

QEfficient/base/modeling_qeff.py

Lines changed: 98 additions & 37 deletions
@@ -8,7 +8,6 @@
 import gc
 import inspect
 import logging
-import re
 import shutil
 import subprocess
 import warnings
@@ -21,26 +20,21 @@
 
 from QEfficient.base.onnx_transforms import (
     BaseOnnxTransform,
-    CustomOpTransform,
     OnnxTransformPipeline,
-    RenameFunctionOutputsTransform,
 )
 from QEfficient.base.pytorch_transforms import PytorchTransform
 from QEfficient.compile.qnn_compiler import compile as qnn_compile
 from QEfficient.generation.cloud_infer import QAICInferenceSession
-from QEfficient.transformers.cache_utils import InvalidIndexProvider
-from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export
 from QEfficient.utils import (
     constants,
     create_json,
     create_model_params,
     dump_qconfig,
-    export_wrapper,
     generate_mdp_partition_config,
     hash_dict_params,
     load_json,
 )
-from QEfficient.utils.torch_patches import apply_torch_patches, undo_torch_patches
+from QEfficient.utils.export_utils import export_wrapper
 
 logger = logging.getLogger(__name__)
 
@@ -66,6 +60,7 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
         super().__init__()
         self.model = model
         self.hash_params = create_model_params(self, **kwargs)
+        self.prefill_onnx_path: Optional[str] = None
         self.onnx_path: Optional[str] = None
         self.qpc_path: Optional[str] = None
         self.qpc_session: Optional[QAICInferenceSession] = None
@@ -125,9 +120,35 @@ def _model_offloaded_check(self) -> None:
         logger.error(error_msg)
         raise RuntimeError(error_msg)
 
+    @property
+    def model_name(self) -> str:
+        """
+        Get the model class name without QEff/QEFF prefix.
+
+        This property extracts the underlying model's class name and removes
+        any QEff or QEFF prefix that may have been added during wrapping.
+
+        Returns:
+            str: Model class name (e.g., "CLIPTextModel" instead of "QEffCLIPTextModel")
+        """
+        mname = self.model.__class__.__name__
+        if mname.startswith("QEff") or mname.startswith("QEFF"):
+            mname = mname[4:]
+        return mname
+
     @property
     @abstractmethod
-    def model_name(self) -> str: ...
+    def get_model_config(self) -> Dict:
+        """
+        Get the model configuration as a dictionary.
+
+        This is an abstract property that must be implemented by all subclasses.
+        Typically returns: self.model.config.__dict__
+
+        Returns:
+            Dict: The configuration dictionary of the underlying model
+        """
+        pass
 
     @abstractmethod
     def export(self, export_dir: Optional[str] = None) -> Path:
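Note on the new concrete model_name: both recognized prefixes, "QEff" and "QEFF", are exactly four characters long, which is why the single mname[4:] slice covers either casing. A standalone sketch of the same stripping logic:

    def strip_qeff_prefix(mname: str) -> str:
        # "QEff" and "QEFF" are both 4 characters, so one slice handles both.
        if mname.startswith("QEff") or mname.startswith("QEFF"):
            mname = mname[4:]
        return mname

    assert strip_qeff_prefix("QEffCLIPTextModel") == "CLIPTextModel"
    assert strip_qeff_prefix("QEFFCommonLoader") == "CommonLoader"
    assert strip_qeff_prefix("CLIPTextModel") == "CLIPTextModel"  # unwrapped names pass through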
@@ -184,11 +205,11 @@ def _export(
         example_inputs: Dict[str, torch.Tensor],
         output_names: List[str],
         dynamic_axes: Dict[str, Dict[int, str]],
-        export_kwargs: Optional[Dict[str, any]] = None,
         onnx_transform_kwargs: Optional[Dict[str, any]] = None,
         export_dir: Optional[str] = None,
         offload_pt_weights: bool = True,
-        use_onnx_subfunctions: bool = False,
+        prefill_only: Optional[bool] = False,
+        **export_kwargs,
     ) -> str:
         """
         Export the PyTorch model to ONNX and apply ONNX transforms
@@ -213,11 +234,16 @@
             instance using from_pretrained() for re-export.
 
         """
+        # TODO: Hack for retain_full_kv, handle this outside
+        export_kwargs.pop("retain_full_kv", None)
         onnx_path = export_dir / f"{self.model_name}.onnx"
 
         # Return early if ONNX already exists
         if onnx_path.is_file():
-            self.onnx_path = onnx_path
+            if prefill_only:
+                self.prefill_onnx_path = onnx_path
+            else:
+                self.onnx_path = onnx_path
             return onnx_path
 
         # check if the model is in meta state or weights are offloaded
@@ -253,19 +279,6 @@ def _export(
                 input_names.append(param)
 
         try:
-            # Initialize the registry with your custom ops
-            export_kwargs = {} if export_kwargs is None else export_kwargs
-            if use_onnx_subfunctions:
-                warnings.warn(
-                    "The subfunction feature is experimental. Please note that using compile consecutively with and without subfunction may produce inconsistent results."
-                )
-                apply_torch_patches()
-                InvalidIndexProvider.SUBFUNC_ENABLED = True
-                output_names = [re.sub("_RetainedState", "_InternalRetainedState", s) for s in output_names]
-                export_kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(self.model)
-                self._onnx_transforms.append(RenameFunctionOutputsTransform)
-                self._onnx_transforms.append(CustomOpTransform)
-
             torch.onnx.export(
                 self.model,
                 (example_inputs,),
@@ -309,15 +322,42 @@ def _export(
         finally:
             shutil.rmtree(tmp_onnx_dir, ignore_errors=True)
 
-        if use_onnx_subfunctions:
-            undo_torch_patches()
-            InvalidIndexProvider.SUBFUNC_ENABLED = False
-            self._onnx_transforms.remove(CustomOpTransform)
-            self._onnx_transforms.remove(RenameFunctionOutputsTransform)
-
-        self.onnx_path = onnx_path
+        if prefill_only:
+            self.prefill_onnx_path = onnx_path
+        else:
+            self.onnx_path = onnx_path
         return onnx_path
 
+    def get_onnx_path(
+        self,
+        prefill_only: Optional[bool] = False,
+        enable_chunking: Optional[bool] = False,
+        specializations: Optional[List[Dict[str, int]]] = None,
+        offload_pt_weights: Optional[bool] = True,
+        use_onnx_subfunctions: Optional[bool] = False,
+        retain_full_kv: Optional[bool] = False,
+    ):
+        kwargs = {
+            "offload_pt_weights": offload_pt_weights,
+            "use_onnx_subfunctions": use_onnx_subfunctions,
+            "retain_full_kv": retain_full_kv,
+        }
+        if prefill_only:
+            if self.prefill_onnx_path is None:
+                kwargs.update(
+                    {
+                        "prefill_only": prefill_only,
+                        "prefill_seq_len": specializations[0].get("seq_len"),
+                        "enable_chunking": enable_chunking,
+                    }
+                )
+                self.export(**kwargs)
+            return self.prefill_onnx_path
+        else:
+            if self.onnx_path is None:
+                self.export(**kwargs)
+            return self.onnx_path
+
     @dump_qconfig
     def _compile(
         self,
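Note: get_onnx_path centralizes the export-or-reuse decision that _compile previously inlined, caching prefill-only exports on prefill_onnx_path and everything else on onnx_path. A condensed, self-contained sketch of that caching behaviour (internals stubbed out; file names illustrative only):

    from typing import Optional

    class OnnxPathCache:
        # Stand-in for the two-slot caching in QEFFBaseModel.get_onnx_path.
        def __init__(self):
            self.onnx_path: Optional[str] = None
            self.prefill_onnx_path: Optional[str] = None

        def export(self, prefill_only: bool = False, **kwargs) -> str:
            # Stub for _export: fills one of the two path slots.
            if prefill_only:
                self.prefill_onnx_path = "model_prefill.onnx"
                return self.prefill_onnx_path
            self.onnx_path = "model.onnx"
            return self.onnx_path

        def get_onnx_path(self, prefill_only=False, specializations=None, **kwargs):
            if prefill_only:
                if self.prefill_onnx_path is None:
                    # Prefill exports also need the prefill sequence length,
                    # read from the first specialization.
                    self.export(prefill_only=True,
                                prefill_seq_len=specializations[0].get("seq_len"), **kwargs)
                return self.prefill_onnx_path
            if self.onnx_path is None:
                self.export(**kwargs)
            return self.onnx_path

    cache = OnnxPathCache()
    assert cache.get_onnx_path() == "model.onnx"
    assert cache.get_onnx_path(True, [{"seq_len": 128}]) == "model_prefill.onnx"
    assert cache.get_onnx_path() == "model.onnx"  # second call reuses the cached path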
@@ -332,6 +372,10 @@ def _compile(
         enable_qnn: Optional[bool] = False,
         qnn_config: Optional[str] = None,
         use_onnx_subfunctions: bool = False,
+        prefill_only: Optional[str] = None,
+        offload_pt_weights: Optional[bool] = True,
+        enable_chunking: Optional[bool] = False,
+        retain_full_kv: Optional[bool] = None,
         **compiler_options,
     ) -> str:
         """
@@ -357,11 +401,18 @@ def _compile(
 
         For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored.
         """
-
-        if onnx_path is None and self.onnx_path is None:
-            self.export(use_onnx_subfunctions=use_onnx_subfunctions)
-
-        onnx_path = Path(onnx_path or self.onnx_path)
+        onnx_path = Path(
+            onnx_path
+            if onnx_path
+            else self.get_onnx_path(
+                prefill_only,
+                enable_chunking,
+                specializations,
+                offload_pt_weights,
+                use_onnx_subfunctions,
+                retain_full_kv,
+            )
+        )
         compile_dir = Path(compile_dir or onnx_path.parent)
         qpc_path = compile_dir / "qpc"
         if not onnx_path.is_file():
@@ -423,6 +474,7 @@ def _compile(
             "mdp_ts_num_devices": mdp_ts_num_devices,
             "mdp_ts_json": mdp_ts_json,
             "num_speculative_tokens": num_speculative_tokens,
+            "prefill_only": prefill_only,
         }
         compile_hash = hash_dict_params(compile_hash_params)
 
@@ -462,6 +514,16 @@ def _compile(
 
         command.append(f"-aic-binary-dir={qpc_path}")
         logger.info(f"Running compiler: {' '.join(command)}")
+        if use_onnx_subfunctions:
+
+            class FeatureNotAvailableError(Exception):
+                pass
+
+            exec_command = f'QAIC_COMPILER_OPTS_UNSUPPORTED="-loader-inline-all=0" {" ".join(command)}'
+            raise FeatureNotAvailableError(
+                "ONNX graph is exported with subfunctions, assert version of apps SDK should be used for compiling this model."
+                + f"\nRun following command manually with assert compiler:\n{exec_command}"
+            )
         try:
             subprocess.run(command, capture_output=True, check=True)
         except subprocess.CalledProcessError as e:
@@ -482,5 +544,4 @@ def _compile(
         logger.info("Hashed parameters exported successfully.")
 
         self.qpc_path = qpc_path
-
        return qpc_path
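Note: the new "prefill_only" entry in compile_hash_params matters for QPC caching, since the hash keys the compiled-artifact directory; without it, a prefill-only build and a full build of the same ONNX could collide in one qpc path. A toy illustration of the mechanism, with hash_dict_params stood in by hashlib:

    import hashlib
    import json

    def hash_dict_params(params: dict) -> str:
        # Illustrative stand-in: any stable digest of the parameter dict works.
        return hashlib.sha256(json.dumps(params, sort_keys=True).encode()).hexdigest()[:16]

    base = {"onnx_path": "model.onnx", "mdp_ts_num_devices": 1}
    h_full = hash_dict_params({**base, "prefill_only": None})
    h_prefill = hash_dict_params({**base, "prefill_only": "prefill"})
    assert h_full != h_prefill  # distinct hashes -> distinct cached QPC builds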

QEfficient/base/onnx_transforms.py

Lines changed: 8 additions & 2 deletions
@@ -19,16 +19,20 @@
 from QEfficient.customop.ctx_scatter_gather import (
     CtxGather,
     CtxGather3D,
+    CtxGatherBlockedKV,
     CtxGatherFunc,
     CtxGatherFunc3D,
+    CtxGatherFuncBlockedKV,
     CtxScatter,
     CtxScatter3D,
     CtxScatterFunc,
     CtxScatterFunc3D,
 )
 from QEfficient.customop.ctx_scatter_gather_cb import (
+    CtxGatherBlockedKVCB,
     CtxGatherCB,
     CtxGatherCB3D,
+    CtxGatherFuncBlockedKVCB,
     CtxGatherFuncCB,
     CtxGatherFuncCB3D,
     CtxScatterCB,
@@ -91,10 +95,12 @@ class CustomOpTransform(BaseOnnxTransform):
         "CtxScatterFunc3D": (CtxScatterFunc3D, CtxScatter3D),
         "CtxGatherFunc": (CtxGatherFunc, CtxGather),
         "CtxGatherFunc3D": (CtxGatherFunc3D, CtxGather3D),
-        "CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB),
         "CtxScatterFuncCB3D": (CtxScatterFuncCB3D, CtxScatterCB3D),
-        "CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB),
         "CtxGatherFuncCB3D": (CtxGatherFuncCB3D, CtxGatherCB3D),
+        "CtxGatherFuncBlockedKV": (CtxGatherFuncBlockedKV, CtxGatherBlockedKV),
+        "CtxGatherFuncBlockedKVCB": (CtxGatherFuncBlockedKVCB, CtxGatherBlockedKVCB),
+        "CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB),
+        "CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB),
     }
 
     @classmethod
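Note: the reorder in CustomOpTransform's registry is suggestive: the short keys "CtxScatterFuncCB" and "CtxGatherFuncCB" now come after their longer "...CB3D" and "...BlockedKV..." siblings. If the transform resolves exported node names by prefix match, that ordering is load-bearing, because "CtxGatherFuncCB" is a prefix of "CtxGatherFuncCB3D". A sketch of why iteration order matters under that assumption (the resolve helper is hypothetical, not QEfficient API):

    OP_REGISTRY = {
        "CtxGatherFuncCB3D": "gather_cb_3d",
        "CtxGatherFuncBlockedKVCB": "gather_blocked_kv_cb",
        "CtxGatherFuncCB": "gather_cb",  # shorter key deliberately last
    }

    def resolve(node_name: str) -> str:
        # First registry key that prefixes the node name wins, so keys that
        # are prefixes of other keys must be ordered after them.
        for key, op in OP_REGISTRY.items():
            if node_name.startswith(key):
                return op
        raise KeyError(node_name)

    assert resolve("CtxGatherFuncCB3D_0") == "gather_cb_3d"
    assert resolve("CtxGatherFuncCB_1") == "gather_cb"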

QEfficient/customop/ctx_scatter_gather.py

Lines changed: 1 addition & 0 deletions
@@ -136,6 +136,7 @@ class CtxGatherFunc(torch.autograd.Function):
     def forward(data: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
         batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
         head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
+        ctx_indices = torch.where(ctx_indices == torch.iinfo(torch.int32).max, 0, ctx_indices)
         return data[batch_indices, head_indices, ctx_indices]
 
     @staticmethod

QEfficient/customop/ctx_scatter_gather_cb.py

Lines changed: 1 addition & 0 deletions
@@ -126,6 +126,7 @@ class CtxGatherFuncCB(torch.autograd.Function):
     def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
         batch_indices = batch_index.view(-1, 1, 1)
         head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
+        ctx_indices = torch.where(ctx_indices >= data.shape[2], 0, ctx_indices)
         return data[batch_indices, head_indices, ctx_indices]
 
     @staticmethod
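Note: both one-line gather fixes address the same hazard from opposite ends: invalid cache slots arrive encoded as out-of-range indices (the INT32_MAX sentinel in CtxGatherFunc, anything >= the cached context length in the continuous-batching CtxGatherFuncCB), and advanced indexing with them would read out of bounds, so they are redirected to position 0 before the gather; the row fetched from position 0 is presumably masked out downstream. A minimal reproduction of the sentinel guard:

    import torch

    # (batch, heads, ctx_len, head_dim) KV-cache-like tensor
    data = torch.arange(2 * 1 * 4 * 3, dtype=torch.float32).view(2, 1, 4, 3)
    sentinel = torch.iinfo(torch.int32).max
    ctx_indices = torch.tensor([[[0, 1, sentinel]],
                                [[2, 3, sentinel]]])  # sentinel marks invalid slots

    batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
    head_indices = torch.arange(data.shape[1]).view(1, -1, 1)

    # Without the guard, indexing with INT32_MAX raises an out-of-bounds error.
    ctx_indices = torch.where(ctx_indices == sentinel, 0, ctx_indices)
    gathered = data[batch_indices, head_indices, ctx_indices]
    print(gathered.shape)  # torch.Size([2, 1, 3, 3])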
