Merged
46 commits
409da24  Extend on-device sampling support for dual QPC VLMs (quic-xiyushi, Oct 23, 2025)
e06e175  Fix random_numbers shape (quic-xiyushi, Oct 30, 2025)
3e242ce  Update example with new random sampling logic (quic-xiyushi, Oct 30, 2025)
1a01d57  Update to align with recent VLM CB changes (quic-xiyushi, Nov 11, 2025)
30d6061  Update tests with new random sampling logic (Nov 11, 2025)
78ef180  Add code to perform guided decoding (Nov 11, 2025)
1fafcdb  Add bitmask to example inputs and dynamic axes (Nov 12, 2025)
18ab856  Rename bitmask to token_bitmasks (Nov 12, 2025)
b1c049c  Fix typo (Nov 12, 2025)
e16e846  Merge branch 'main' into guided_decoding_simple (Nov 19, 2025)
1515497  Add flag to enable guided decoding (Nov 19, 2025)
d02d04d  Merge remote-tracking branch 'origin/main' into HEAD (quic-xiyushi, Nov 19, 2025)
97e4baf  Add flag to enable guided decoding (Nov 19, 2025)
7b7677b  Update test_sampler_transform for guided decoding (Nov 19, 2025)
7cf106e  Refactor (quic-xiyushi, Nov 19, 2025)
45aed11  Add unit tests (quic-xiyushi, Nov 20, 2025)
6273ab5  Clean up (quic-xiyushi, Nov 20, 2025)
ef9ae14  Merge remote-tracking branch 'origin/main' into HEAD (quic-xiyushi, Nov 20, 2025)
60312b3  Add test for guided decoding (Nov 20, 2025)
3789d5a  Update test_sampler.py (quic-xiyushi, Nov 20, 2025)
251099f  Merge branch 'on-device-sampling-vlm' into guided_decoding_simple (Nov 20, 2025)
a24a55d  Enable guided decoding in vlm generation (Nov 20, 2025)
55e76e9  Fix bug (Nov 20, 2025)
f9355d4  Fix bug (Nov 20, 2025)
5e2afb7  Fix hash for VLM's language decoder to include qaic_config (quic-xiyushi, Nov 21, 2025)
e672701  Merge branch 'on-device-sampling-vlm' into guided_decoding_simple (Nov 21, 2025)
eee5314  Enable guided decoding test for vlms (Nov 21, 2025)
60cf5ec  Use different config for each vlm (Nov 21, 2025)
a71ee65  Update type (Nov 21, 2025)
df06617  Merge remote-tracking branch 'origin/main' into HEAD (quic-xiyushi, Nov 25, 2025)
10990a9  Fix bug in getting vocab_size and missing ccl in forward (quic-xiyushi, Nov 25, 2025)
b47b633  Merge branch 'on-device-sampling-vlm' into guided_decoding_simple (Nov 25, 2025)
b5a7b99  Merge branch 'main' into guided_decoding_simple (quic-mamta, Dec 5, 2025)
3fcd9eb  Merge branch 'main' into guided_decoding_simple (quic-mamta, Dec 6, 2025)
98cfadf  Merge branch 'main' into on-device-sampling-vlm (quic-mamta, Dec 10, 2025)
a60e7ce  Merge branch 'main' into on-device-sampling-vlm (quic-xiyushi, Dec 16, 2025)
b22af54  Support prefix-caching with on-device sampling (quic-xiyushi, Dec 16, 2025)
2533262  Modify tests to use internvl 1b for quicker CI (quic-xiyushi, Dec 16, 2025)
5457075  Merge remote-tracking branch 'origin/on-device-sampling-vlm' into HEAD (quic-xiyushi, Dec 16, 2025)
8698651  Merge branch 'main' into on-device-sampling-vlm (quic-xiyushi, Dec 16, 2025)
86aaad2  Fix compilation error on Llama3.1 8B due to changes in presence penalty (quic-xiyushi, Dec 16, 2025)
a2d4fb4  Update tests (quic-xiyushi, Dec 16, 2025)
eaf21c0  Merge remote-tracking branch 'origin/on-device-sampling-vlm' into HEAD (quic-xiyushi, Dec 16, 2025)
feeaa37  Extend on-device sampling support to llava, garnite, gemma, and llama4 (quic-xiyushi, Dec 17, 2025)
5f716ef  Merge remote-tracking branch 'origin/main' into HEAD (quic-xiyushi, Dec 17, 2025)
96e13a8  Merge branch 'main' into guided_decoding_simple (quic-hemagnih, Dec 18, 2025)
16 changes: 13 additions & 3 deletions QEfficient/generation/text_generation_inference.py
@@ -329,6 +329,7 @@ def cloud_ai_100_exec_kv(
is_tlm: bool = False,
include_sampler: bool = False,
return_pdfs: bool = False,
include_guided_decoding: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
):
"""
@@ -356,6 +357,8 @@ def cloud_ai_100_exec_kv(
next tokens. For Speculative Decoding Target Language Model,
`return_pdfs`=True always. Otherwise, `return_pdfs`=True for Speculative
Decoding Draft Language Model and `return_pdfs`=False for regular model.
:include_guided_decoding (bool, default=False): If True, enables guided token-level filtering
during decoding. Only works when `include_sampler`=True.
sampling_params (Dict[str, Any], default=None): A dictionary of sampling parameters supported by the QAIC backend.
The dictionary should contain the following keys:
`repetition_penalties`, `presence_penalties`, `temperatures`, `top_ks`, `top_ps`,
@@ -394,6 +397,7 @@ def cloud_ai_100_exec_kv(
is_tlm=is_tlm,
include_sampler=include_sampler,
return_pdfs=return_pdfs,
include_guided_decoding=include_guided_decoding,
sampling_params=sampling_params,
)

@@ -442,6 +446,7 @@ def __init__(
is_tlm: Optional[int] = None,
include_sampler: bool = False,
return_pdfs: bool = False,
include_guided_decoding: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
activate: bool = True,
) -> None:
@@ -451,6 +456,7 @@ def __init__(
self._write_io_dir = write_io_dir
self.is_tlm = is_tlm
self.return_pdfs = return_pdfs
self.include_guided_decoding = include_guided_decoding
self.sampling_params = sampling_params
self._qpc_path = qpc_path # Store qpc_path for later use

@@ -461,7 +467,9 @@ def __init__(

# Validate sampler inputs for On-Device Sampling
self.include_sampler = validate_sampler_inputs(
session_inputs=set(self._session.input_names), include_sampler=include_sampler
session_inputs=set(self._session.input_names),
include_sampler=include_sampler,
include_guided_decoding=include_guided_decoding,
)

# Fetch the variables from the QPC
@@ -628,7 +636,7 @@ def prepare_decode_inputs(self):
decode_inputs["batch_index"] = self.batch_index
if self.include_sampler:
decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]
for op in Constants.SAMPLER_OPS:
for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()):
if self.batch_index is not None:
decode_inputs[op] = self.sampling_params[op][self.batch_index.flatten()]
else:
@@ -795,7 +803,7 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
inputs["num_logits_to_keep"] = np.zeros((1, 1))
if self.include_sampler:
inputs["last_accepted_output_tokens"] = inputs["input_ids"]
for op in Constants.SAMPLER_OPS:
for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()):
if decode_batch_id is not None:
inputs[op] = self.sampling_params[op][decode_batch_id.flatten()]
else:
@@ -1067,6 +1075,7 @@ def __init__(
is_tlm: bool = False,
include_sampler: bool = False,
return_pdfs: bool = False,
include_guided_decoding: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
) -> None:
self._qaic_model = QEffTextGenerationBase(
@@ -1082,6 +1091,7 @@ def __init__(
is_tlm=is_tlm,
include_sampler=include_sampler,
return_pdfs=return_pdfs,
include_guided_decoding=include_guided_decoding,
sampling_params=sampling_params,
)
self._full_batch_size = self._qaic_model.full_batch_size
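
With the new flag, token_bitmasks travels through the same per-slot indexing as every other sampler tensor. A minimal sketch of what the two loops above do, outside the session; the set contents and shapes below are illustrative stand-ins, not the real Constants.SAMPLER_OPS:

import numpy as np

SAMPLER_OPS = {"temperatures", "top_ks"}                     # stand-in for Constants.SAMPLER_OPS
include_guided_decoding = True
ops = SAMPLER_OPS | ({"token_bitmasks"} if include_guided_decoding else set())

full_batch_size, vocab_size = 4, 32000                       # illustrative sizes only
sampling_params = {
    "temperatures": np.full((full_batch_size, 1), 0.7, dtype=np.float32),
    "top_ks": np.full((full_batch_size, 1), 50, dtype=np.int32),
    "token_bitmasks": np.ones((full_batch_size, vocab_size), dtype=bool),
}
batch_index = np.array([[1], [3]])                           # continuous-batching slots being decoded
decode_inputs = {op: sampling_params[op][batch_index.flatten()] for op in ops}
# Every entry now has leading dimension 2, matching the two active slots.
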
8 changes: 6 additions & 2 deletions QEfficient/generation/vlm_generation.py
@@ -94,6 +94,7 @@ def __init__(
is_tlm: bool = False,
include_sampler: bool = False,
return_pdfs: bool = False,
include_guided_decoding: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
):
"""
@@ -115,6 +116,7 @@ def __init__(
is_tlm: Target language model flag
include_sampler: Enable on-device sampling (new feature)
return_pdfs: Return probability distributions
include_guided_decoding: Enable guided decoding in on-device sampling
sampling_params: Sampling parameters for on-device sampling
"""
# Validate required parameters
@@ -138,6 +140,7 @@ def __init__(
is_tlm=is_tlm,
include_sampler=include_sampler,
return_pdfs=return_pdfs,
include_guided_decoding=include_guided_decoding,
sampling_params=sampling_params,
activate=False, # vision components need to be initialized first
)
@@ -315,7 +318,7 @@ def _execute_chunked_prefill(
lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]

if self.include_sampler:
for op in Constants.SAMPLER_OPS:
for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()):
if decode_batch_id is not None:
lang_inputs[op] = self.sampling_params[op][decode_batch_id.flatten()]
else:
@@ -348,7 +351,7 @@ def _execute_chunked_prefill(

if self.include_sampler:
chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"]
for op in Constants.SAMPLER_OPS:
for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()):
chunk_inputs[op] = lang_inputs[op]

outputs = self._session.run(chunk_inputs)
@@ -803,6 +806,7 @@ def generate_stream_tokens(
is_tlm=self.is_tlm,
include_sampler=self.include_sampler,
return_pdfs=self.return_pdfs,
include_guided_decoding=self.include_guided_decoding,
sampling_params=self.sampling_params,
)

5 changes: 4 additions & 1 deletion QEfficient/transformers/models/modeling_auto.py
@@ -2251,7 +2251,6 @@ def from_pretrained(
logger.warning("Updating low_cpu_mem_usage=False")

kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})

model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
return cls(
model,
@@ -2347,6 +2346,8 @@ def __init__(
- **return_pdfs** (bool): If True, returns probability distributions along with sampled tokens.
For Speculative Decoding Target Language Models, this is always True.
- **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
- **include_guided_decoding** (bool): If True, enables guided token-level filtering
during decoding. Only works when include_sampler=True.
- **num_kv_blocks** (int): Number of K/V blocks for BlockedKV attention implementation.
**kwargs :
Additional keyword arguments passed to the base class constructor.
@@ -2443,6 +2444,8 @@ def from_pretrained(
and ``return_pdfs=False`` for regular model.
- **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
The values provided in ``top_ks`` tensor must be less than this maximum limit.
- **include_guided_decoding** (bool): If True, enables guided token-level filtering
during decoding. Only works when include_sampler=True.

*args :
Positional arguments passed directly to `cls._hf_auto_class.from_pretrained`.
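
Putting the documented keys together, a hedged sketch of opting into guided decoding at load time; the import path and the exact set of companion keys are assumptions drawn from the example script later in this PR, not prescribed by this diff:

from QEfficient import AutoModelForCausalLM  # import path assumed from QEfficient's examples

qeff_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    qaic_config={
        "include_sampler": True,             # guided decoding only works with the on-device sampler
        "return_pdfs": False,
        "max_top_k_ids": 512,
        "include_guided_decoding": True,     # new key documented above
    },
)
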
8 changes: 8 additions & 0 deletions QEfficient/transformers/models/pytorch_transforms.py
@@ -242,6 +242,7 @@
QEffGemma3Attention,
QEffGemma3CustomRMSNormAIC,
QEffGemma3DecoderLayer,
QEffGemma3DecoderWrapper,
QEffGemma3ForCausalLMModel,
QEffGemma3ForConditionalGeneration,
QEffGemma3TextModel,
@@ -313,6 +314,7 @@
QEffLlamaRotaryEmbedding,
)
from QEfficient.transformers.models.llama4.modeling_llama4 import (
QEffLlama4DecoderWrapper,
QEffLlama4ForCausalLM,
QEffLlama4ForConditionalGeneration,
QEffLlama4Router,
@@ -325,9 +327,11 @@
QEffLlama4VisionModel,
)
from QEfficient.transformers.models.llava.modeling_llava import (
QEFFLlavaDecoderWrapper,
QEffLlavaForConditionalGeneration,
)
from QEfficient.transformers.models.llava_next.modeling_llava_next import (
QEffLlavaNextDecoderWrapper,
QEffLlavaNextForConditionalGeneration,
)
from QEfficient.transformers.models.mistral.modeling_mistral import (
@@ -755,12 +759,16 @@ class SamplerTransform:
_module_mapping = {
QEffFalconForCausalLM,
QEffGemmaForCausalLM,
QEffGemma3DecoderWrapper,
QEffGPT2LMHeadModel,
QEffGPTJForCausalLM,
QEffGraniteForCausalLM,
QEffGraniteMoeForCausalLM,
QEffInternDecoderWrapper,
QEffLlamaForCausalLM,
QEffLlama4DecoderWrapper,
QEFFLlavaDecoderWrapper,
QEffLlavaNextDecoderWrapper,
QEffMptForCausalLM,
QEffPhi3ForCausalLM,
QEffQwen2ForCausalLM,
13 changes: 13 additions & 0 deletions QEfficient/transformers/sampler/sampler.py
@@ -129,6 +129,7 @@ def sampler_forward(
top_ps: Optional[torch.Tensor] = None,
min_ps: Optional[torch.Tensor] = None,
random_numbers: Optional[torch.Tensor] = None,
token_bitmasks: Optional[torch.Tensor] = None,
) -> Union[Tuple, SamplerOutput]:
r"""
Perform the sampling of next tokens on the QAIC device (instead of the host)
@@ -179,6 +180,11 @@
random_numbers (`torch.Tensor`, *optional*):
Sampling parameter that represents the random seeds to use for random sampling.
Must be in [-1, 1].

token_bitmasks (`torch.Tensor`, *optional*):
Boolean mask used to guide token-level filtering during decoding. Each
element of this tensor indicates whether the corresponding token should be
kept (1) or masked (0). Shape: (batch_size, vocab_size)
"""
if vision_embeds is not None:
forward_kwargs = dict(
@@ -224,6 +230,13 @@
batch_index = torch.arange(batch_size).view(-1, 1)

batch_index_reshaped = batch_index.view(-1)

# Guided decoding
if token_bitmasks is not None and (token_bitmasks != 1).any():
assert spec_length == 1, "Currently, guided decoding is not supported with Speculative Decoding"
# Mask logits where token_bitmasks is 0 with -inf
logits = torch.where(token_bitmasks == 1, logits, torch.finfo(torch.float16).min)

# Prefill
past_repetition_penalty_buffer_prefill, past_presence_penalty_buffer_prefill = prefill_path(
input_ids=input_ids,
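
The guard above skips masking when the bitmask is all ones, and otherwise pins disallowed logits to the float16 minimum. A self-contained sketch of the effect, with a toy vocabulary and run outside the exported graph:

import torch

logits = torch.randn(1, 1, 8)                                   # (batch_size, spec_length=1, vocab_size)
token_bitmasks = torch.tensor([[1, 0, 1, 0, 1, 0, 0, 1]], dtype=torch.bool)

masked = torch.where(token_bitmasks == 1, logits, torch.finfo(torch.float16).min)
probs = torch.softmax(masked, dim=-1)
print(probs[0, 0, 1].item())  # ~0.0: a masked token can no longer be sampled
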
15 changes: 10 additions & 5 deletions QEfficient/utils/sampler_utils.py
@@ -14,7 +14,9 @@
from QEfficient.utils.logging_utils import logger


def validate_sampler_inputs(session_inputs: Set[str], include_sampler: Optional[bool] = None) -> bool:
def validate_sampler_inputs(
session_inputs: Set[str], include_sampler: Optional[bool] = None, include_guided_decoding: Optional[bool] = None
) -> bool:
"""
Validates whether the `QAICInferenceSession` inputs match inputs required for on-device sampling.

@@ -31,7 +33,7 @@ def validate_sampler_inputs(session_inputs: Set[str], include_sampler: Optional[
ValueError if partial support is detected or if user intent conflicts with QPC capabilities.
"""

sampler_inputs = Constants.SAMPLER_INPUTS
sampler_inputs = Constants.SAMPLER_INPUTS | ({"token_bitmasks"} if include_guided_decoding else set())
count = len(sampler_inputs & session_inputs)

session_includes_sampler = True
@@ -96,10 +98,9 @@ def get_sampling_inputs_and_outputs(
"""
bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS
seq_len: int = example_inputs["input_ids"].shape[-1]

example_inputs["last_accepted_output_tokens"] = torch.zeros(
(bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64
)
example_inputs["last_accepted_output_tokens"] = torch.zeros((bs, seq_len), dtype=torch.int64)
dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"}

example_inputs["past_repetition_penalty_buffer"] = torch.zeros(
@@ -144,4 +145,8 @@
example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float)
dynamic_axes["random_numbers"] = {0: "batch_size"}

if qaic_config.get("include_guided_decoding", False):
example_inputs["token_bitmasks"] = torch.zeros((bs, vocab_size), dtype=torch.bool)
dynamic_axes["token_bitmasks"] = {0: "batch_size"}

return example_inputs, output_names, dynamic_axes
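
On the validation side, requesting guided decoding simply adds token_bitmasks to the set of inputs the QPC must expose. A sketch of that set arithmetic; the stand-in set below is smaller than the real Constants.SAMPLER_INPUTS:

SAMPLER_INPUTS = {"repetition_penalties", "temperatures", "top_ks", "top_ps"}  # stand-in only

include_guided_decoding = True
required = SAMPLER_INPUTS | ({"token_bitmasks"} if include_guided_decoding else set())

session_inputs = {"input_ids", *SAMPLER_INPUTS, "token_bitmasks"}
count = len(required & session_inputs)
assert count == len(required)  # full support; 0 < count < len(required) is the partial-support error case
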
43 changes: 39 additions & 4 deletions examples/performance/on_device_sampling.py
@@ -21,6 +21,7 @@ def main(args, **kwargs):
include_sampler = None
return_pdfs = None
max_top_k_ids = None
include_guided_decoding = None
sampling_params = None
bs = args.full_batch_size if args.full_batch_size is not None else args.batch_size
if args.override_qaic_config is not None:
@@ -29,6 +30,7 @@
return_pdfs = args.override_qaic_config.get("aic_return_pdfs", None) == "true"
max_top_k_ids = int(args.override_qaic_config.get("max_top_k_ids", 512))
np.random.seed(int(args.random_number))
include_guided_decoding = args.override_qaic_config.get("aic_include_guided_decoding", None) == "true"
sampling_params = {
"repetition_penalties": np.array(args.repetition_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1),
"presence_penalties": np.array(args.presence_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1),
@@ -47,13 +49,12 @@
"include_sampler": include_sampler,
"return_pdfs": return_pdfs,
"max_top_k_ids": max_top_k_ids,
"include_guided_decoding": include_guided_decoding,
}.items()
if v is not None
}
print("qaic_config:")
pprint(qaic_config)
print("sampling_params:")
pprint(sampling_params)

# Load model with On Device Sampler enabled
qeff_model = AutoModelForCausalLM.from_pretrained(
@@ -63,6 +64,19 @@
)
print(f"{args.model_name} optimized for AI 100 \n", qeff_model)

if include_guided_decoding:
# Ideally this should come from a logits processor like xgrammar, but for the sake of the
# example, we generate a random bitmask
sampling_params.update(
{
"token_bitmasks": np.tile(
np.random.choice([True, False], size=(qeff_model.model.config.vocab_size,)), (bs, 1)
)
}
)
print("sampling_params:")
pprint(sampling_params)

# Compile the model for inference
generated_qpc_path = qeff_model.compile(
prefill_seq_len=args.prompt_len,
@@ -91,6 +105,7 @@ def main(args, **kwargs):
generation_len=args.generation_len,
include_sampler=include_sampler,
return_pdfs=return_pdfs,
include_guided_decoding=include_guided_decoding,
sampling_params=sampling_params,
)

@@ -109,7 +124,7 @@
--num-cores 16 \
--mxint8-kv-cache \
--mxfp6-matmul \
--override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512" \
--override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512 aic_include_guided_decoding:false" \
--repetition-penalty 1.9 \
--presence-penalty 0.8 \
--temperature 0.67 \
@@ -129,7 +144,27 @@
--num-cores 16 \
--mxint8-kv-cache \
--mxfp6-matmul \
--override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512" \
--override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512 aic_include_guided_decoding:false" \
--repetition-penalty 1.9 \
--presence-penalty 0.8 \
--temperature 0.67 \
--top-k 54 \
--top-p 0.89 \
--min-p 0.6 \
--random-number 26

3. With guided decoding:
python3.10 examples/on_device_sampling.py \
--model-name 'meta-llama/Llama-3.1-8B' \
--prompt-len 128 \
--ctx-len 256 \
--generation-len 20 \
--full-batch-size 2 \
--device-group [0,1,2,3] \
--num-cores 16 \
--mxint8-kv-cache \
--mxfp6-matmul \
--override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512 aic_include_guided_decoding:true" \
--repetition-penalty 1.9 \
--presence-penalty 0.8 \
--temperature 0.67 \
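
The random bitmask in the example is only a stand-in; in a full integration the mask would typically be derived from whichever tokens the grammar or schema currently allows. A hedged sketch of that substitution, where the ids, sizes, and the bare sampling_params dict are placeholders for the values built in main():

import numpy as np

bs, vocab_size = 2, 128256                  # illustrative sizes only
allowed_token_ids = [128000, 198, 271]      # hypothetical ids permitted by the constraint

mask = np.zeros(vocab_size, dtype=bool)
mask[allowed_token_ids] = True              # True = keep, False = filter out

sampling_params = {}                        # placeholder for the dict assembled in main()
sampling_params["token_bitmasks"] = np.tile(mask, (bs, 1))  # same constraint applied to every sequence
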