88import gc
99import inspect
1010import logging
11- import re
1211import shutil
1312import subprocess
1413import warnings
2120
2221from QEfficient .base .onnx_transforms import (
2322 BaseOnnxTransform ,
24- CustomOpTransform ,
2523 OnnxTransformPipeline ,
26- RenameFunctionOutputsTransform ,
2724)
2825from QEfficient .base .pytorch_transforms import PytorchTransform
2926from QEfficient .compile .qnn_compiler import compile as qnn_compile
3027from QEfficient .generation .cloud_infer import QAICInferenceSession
31- from QEfficient .transformers .cache_utils import InvalidIndexProvider
32- from QEfficient .transformers .models .pytorch_transforms import get_decoder_layer_classes_for_export
3328from QEfficient .utils import (
3429 constants ,
3530 create_json ,
3631 create_model_params ,
3732 dump_qconfig ,
38- export_wrapper ,
3933 generate_mdp_partition_config ,
4034 hash_dict_params ,
4135 load_json ,
4236)
43- from QEfficient .utils .torch_patches import apply_torch_patches , undo_torch_patches
37+ from QEfficient .utils .export_utils import export_wrapper
4438
4539logger = logging .getLogger (__name__ )
4640
@@ -66,6 +60,7 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
6660 super ().__init__ ()
6761 self .model = model
6862 self .hash_params = create_model_params (self , ** kwargs )
63+ self .prefill_onnx_path : Optional [str ] = None
6964 self .onnx_path : Optional [str ] = None
7065 self .qpc_path : Optional [str ] = None
7166 self .qpc_session : Optional [QAICInferenceSession ] = None
@@ -125,9 +120,35 @@ def _model_offloaded_check(self) -> None:
125120 logger .error (error_msg )
126121 raise RuntimeError (error_msg )
127122
@property
def model_name(self) -> str:
    """
    Return the wrapped model's class name without a leading QEff/QEFF prefix.

    The name is read from ``self.model.__class__.__name__``. If it starts with
    ``"QEff"`` or ``"QEFF"``, that prefix is dropped; any other name is
    returned unchanged.

    Returns:
        str: Model class name (e.g., "CLIPTextModel" instead of "QEffCLIPTextModel")
    """
    mname = self.model.__class__.__name__
    # Slice by the matched prefix length rather than a hard-coded 4 so the
    # prefix list and the strip width cannot drift apart.
    for prefix in ("QEff", "QEFF"):
        if mname.startswith(prefix):
            return mname[len(prefix):]
    return mname
138+
@property
@abstractmethod
def get_model_config(self) -> Dict:
    """
    Return the model configuration as a dictionary.

    Declared as an abstract property, so every concrete subclass has to
    provide its own implementation. The base implementation has no body
    and evaluates to ``None``.

    Typically returns: self.model.config.__dict__

    Returns:
        Dict: The configuration dictionary of the underlying model
    """
131152
132153 @abstractmethod
133154 def export (self , export_dir : Optional [str ] = None ) -> Path :
@@ -184,11 +205,11 @@ def _export(
184205 example_inputs : Dict [str , torch .Tensor ],
185206 output_names : List [str ],
186207 dynamic_axes : Dict [str , Dict [int , str ]],
187- export_kwargs : Optional [Dict [str , any ]] = None ,
188208 onnx_transform_kwargs : Optional [Dict [str , any ]] = None ,
189209 export_dir : Optional [str ] = None ,
190210 offload_pt_weights : bool = True ,
191- use_onnx_subfunctions : bool = False ,
211+ prefill_only : Optional [bool ] = False ,
212+ ** export_kwargs ,
192213 ) -> str :
193214 """
194215 Export the PyTorch model to ONNX and apply ONNX transforms
@@ -213,11 +234,16 @@ def _export(
213234 instance using from_pretrained() for re-export.
214235
215236 """
237+ # TODO: Hack for retain_full_kv, handle this outside
238+ export_kwargs .pop ("retain_full_kv" , None )
216239 onnx_path = export_dir / f"{ self .model_name } .onnx"
217240
218241 # Return early if ONNX already exists
219242 if onnx_path .is_file ():
220- self .onnx_path = onnx_path
243+ if prefill_only :
244+ self .prefill_onnx_path = onnx_path
245+ else :
246+ self .onnx_path = onnx_path
221247 return onnx_path
222248
223249 # check if the model is in meta state or weights are offloaded
@@ -253,19 +279,6 @@ def _export(
253279 input_names .append (param )
254280
255281 try :
256- # Initialize the registry with your custom ops
257- export_kwargs = {} if export_kwargs is None else export_kwargs
258- if use_onnx_subfunctions :
259- warnings .warn (
260- "The subfunction feature is experimental. Please note that using compile consecutively with and without subfunction may produce inconsistent results."
261- )
262- apply_torch_patches ()
263- InvalidIndexProvider .SUBFUNC_ENABLED = True
264- output_names = [re .sub ("_RetainedState" , "_InternalRetainedState" , s ) for s in output_names ]
265- export_kwargs ["export_modules_as_functions" ] = get_decoder_layer_classes_for_export (self .model )
266- self ._onnx_transforms .append (RenameFunctionOutputsTransform )
267- self ._onnx_transforms .append (CustomOpTransform )
268-
269282 torch .onnx .export (
270283 self .model ,
271284 (example_inputs ,),
@@ -309,15 +322,42 @@ def _export(
309322 finally :
310323 shutil .rmtree (tmp_onnx_dir , ignore_errors = True )
311324
312- if use_onnx_subfunctions :
313- undo_torch_patches ()
314- InvalidIndexProvider .SUBFUNC_ENABLED = False
315- self ._onnx_transforms .remove (CustomOpTransform )
316- self ._onnx_transforms .remove (RenameFunctionOutputsTransform )
317-
318- self .onnx_path = onnx_path
325+ if prefill_only :
326+ self .prefill_onnx_path = onnx_path
327+ else :
328+ self .onnx_path = onnx_path
319329 return onnx_path
320330
def get_onnx_path(
    self,
    prefill_only: Optional[bool] = False,
    enable_chunking: Optional[bool] = False,
    specializations: Optional[List[Dict[str, int]]] = None,
    offload_pt_weights: Optional[bool] = True,
    use_onnx_subfunctions: Optional[bool] = False,
    retain_full_kv: Optional[bool] = False,
):
    """
    Return the cached ONNX path, calling ``self.export`` first if it is unset.

    Two paths are tracked separately: ``self.prefill_onnx_path`` when
    ``prefill_only`` is truthy, ``self.onnx_path`` otherwise. ``self.export``
    is invoked only when the relevant attribute is ``None``; the attribute is
    then read back and returned (this method does not assign it itself).

    Args:
        prefill_only: Select the prefill-only path instead of the regular one.
        enable_chunking: Forwarded to ``export`` on the prefill-only path.
        specializations: The ``"seq_len"`` entry of the first item is forwarded
            to ``export`` as ``prefill_seq_len`` on the prefill-only path.
            When this is ``None`` or empty, ``prefill_seq_len=None`` is
            forwarded instead of failing on the index lookup.
            NOTE(review): confirm that ``export`` accepts ``None`` here.
        offload_pt_weights: Forwarded to ``export``.
        use_onnx_subfunctions: Forwarded to ``export``.
        retain_full_kv: Forwarded to ``export``.

    Returns:
        The value of ``self.prefill_onnx_path`` or ``self.onnx_path``.
    """
    kwargs = {
        "offload_pt_weights": offload_pt_weights,
        "use_onnx_subfunctions": use_onnx_subfunctions,
        "retain_full_kv": retain_full_kv,
    }

    if prefill_only:
        if self.prefill_onnx_path is None:
            # `specializations` defaults to None; guard the lookup so the
            # default call does not raise TypeError on `None[0]`.
            prefill_seq_len = specializations[0].get("seq_len") if specializations else None
            kwargs.update(
                {
                    "prefill_only": prefill_only,
                    "prefill_seq_len": prefill_seq_len,
                    "enable_chunking": enable_chunking,
                }
            )
            self.export(**kwargs)
        return self.prefill_onnx_path

    if self.onnx_path is None:
        self.export(**kwargs)
    return self.onnx_path
360+
321361 @dump_qconfig
322362 def _compile (
323363 self ,
@@ -332,6 +372,10 @@ def _compile(
332372 enable_qnn : Optional [bool ] = False ,
333373 qnn_config : Optional [str ] = None ,
334374 use_onnx_subfunctions : bool = False ,
375+ prefill_only : Optional [str ] = None ,
376+ offload_pt_weights : Optional [bool ] = True ,
377+ enable_chunking : Optional [bool ] = False ,
378+ retain_full_kv : Optional [bool ] = None ,
335379 ** compiler_options ,
336380 ) -> str :
337381 """
@@ -357,11 +401,18 @@ def _compile(
357401
358402 For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored.
359403 """
360-
361- if onnx_path is None and self .onnx_path is None :
362- self .export (use_onnx_subfunctions = use_onnx_subfunctions )
363-
364- onnx_path = Path (onnx_path or self .onnx_path )
404+ onnx_path = Path (
405+ onnx_path
406+ if onnx_path
407+ else self .get_onnx_path (
408+ prefill_only ,
409+ enable_chunking ,
410+ specializations ,
411+ offload_pt_weights ,
412+ use_onnx_subfunctions ,
413+ retain_full_kv ,
414+ )
415+ )
365416 compile_dir = Path (compile_dir or onnx_path .parent )
366417 qpc_path = compile_dir / "qpc"
367418 if not onnx_path .is_file ():
@@ -423,6 +474,7 @@ def _compile(
423474 "mdp_ts_num_devices" : mdp_ts_num_devices ,
424475 "mdp_ts_json" : mdp_ts_json ,
425476 "num_speculative_tokens" : num_speculative_tokens ,
477+ "prefill_only" : prefill_only ,
426478 }
427479 compile_hash = hash_dict_params (compile_hash_params )
428480
@@ -462,6 +514,16 @@ def _compile(
462514
463515 command .append (f"-aic-binary-dir={ qpc_path } " )
464516 logger .info (f"Running compiler: { ' ' .join (command )} " )
517+ if use_onnx_subfunctions :
518+
519+ class FeatureNotAvailableError (Exception ):
520+ pass
521+
522+ exec_command = f'QAIC_COMPILER_OPTS_UNSUPPORTED="-loader-inline-all=0" { " " .join (command )} '
523+ raise FeatureNotAvailableError (
524+ "ONNX graph is exported with subfunctions, assert version of apps SDK should be used for compiling this model."
525+ + f"\n Run following command manually with assert compiler:\n { exec_command } "
526+ )
465527 try :
466528 subprocess .run (command , capture_output = True , check = True )
467529 except subprocess .CalledProcessError as e :
@@ -482,5 +544,4 @@ def _compile(
482544 logger .info ("Hashed parameters exported successfully." )
483545
484546 self .qpc_path = qpc_path
485-
486547 return qpc_path
0 commit comments