
Commit b4ff8a0

Merge remote-tracking branch 'origin/on-device-sampling-vlm' into HEAD
2 parents: 5457075 + a2d4fb4

File tree: 3 files changed (+92 −97 lines)

QEfficient/transformers/sampler/sampler.py

Lines changed: 7 additions & 3 deletions
@@ -37,6 +37,7 @@ def prefill_path(
     batch_index: torch.LongTensor,
     batch_index_reshaped: torch.LongTensor,
     past_repetition_penalty_buffer: torch.Tensor,
+    past_presence_penalty_buffer: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Initialize or update RetainedState buffers for prefill stage based on `input_ids`.
@@ -59,7 +60,10 @@ def prefill_path(
         input_ids,
         torch.ones(input_ids.shape, dtype=torch.bool),
     )
-    return past_repetition_penalty_buffer
+
+    mul_value = torch.zeros(past_presence_penalty_buffer.shape[0], 1, dtype=torch.bool)
+    past_presence_penalty_buffer *= mul_value
+    return past_repetition_penalty_buffer, past_presence_penalty_buffer
 
 
 def decode_path(
@@ -234,14 +238,14 @@ def sampler_forward(
     logits = torch.where(token_bitmasks == 1, logits, torch.finfo(torch.float16).min)
 
     # Prefill
-    past_repetition_penalty_buffer_prefill = prefill_path(
+    past_repetition_penalty_buffer_prefill, past_presence_penalty_buffer_prefill = prefill_path(
         input_ids=input_ids,
         position_ids=position_ids,
         batch_index=batch_index,
         batch_index_reshaped=batch_index_reshaped,
         past_repetition_penalty_buffer=past_repetition_penalty_buffer.clone(),
+        past_presence_penalty_buffer=past_presence_penalty_buffer.clone(),
     )
-    past_presence_penalty_buffer_prefill = torch.zeros(past_presence_penalty_buffer.shape, dtype=torch.bool)
     # Decode
     past_repetition_penalty_buffer_decode, past_presence_penalty_buffer_decode = decode_path(
         last_accepted_output_tokens=last_accepted_output_tokens,
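
Note: prefill_path now clears the presence-penalty RetainedState buffer in place by multiplying it with an all-False per-row mask and returns it together with the repetition-penalty buffer, instead of sampler_forward allocating a fresh zero tensor on every call. A minimal standalone sketch of that masking idiom (buffer name and shape are illustrative, not taken from the repository):

import torch

# Illustrative buffer: one row per decode slot, one column per vocabulary id.
past_presence_penalty_buffer = torch.ones(4, 8, dtype=torch.bool)

# An all-False (batch, 1) mask broadcasts across each row; the in-place multiply
# clears every entry while reusing the same tensor object.
mul_value = torch.zeros(past_presence_penalty_buffer.shape[0], 1, dtype=torch.bool)
past_presence_penalty_buffer *= mul_value

assert not past_presence_penalty_buffer.any()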

tests/transformers/sampler/test_sampler.py

Lines changed: 77 additions & 94 deletions
@@ -5,11 +5,11 @@
 #
 # -----------------------------------------------------------------------------
 
-from typing import List, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import pytest
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer
 
 from QEfficient import QEFFAutoModelForCausalLM, QEFFAutoModelForImageTextToText
 from QEfficient.generation.cloud_infer import QAICInferenceSession
@@ -144,6 +144,42 @@
 ]
 
 
+def prepare_model_setup(
+    model: str, is_vlm: bool, num_hidden_layers: Optional[int], prompts: Union[List, Tuple], spec_length: Optional[int]
+):
+    additional_configs = {}
+    additional_params = {}
+    if is_vlm:
+        config = AutoConfig.from_pretrained(model, trust_remote_code=True)
+        if num_hidden_layers is not None:
+            config.llm_config.num_hidden_layers = num_hidden_layers
+        additional_configs["config"] = config
+        additional_configs["kv_offload"] = True
+        assert isinstance(prompts, tuple), "For VLMs, both image and text prompts must be provided."
+        additional_params["images"] = prompts[0]
+        prompts = prompts[1]
+
+        if "InternVL" in model:
+            additional_configs["trust_remote_code"] = True
+            model_hf = AutoModelForCausalLM.from_pretrained(
+                model,
+                config=config,
+                trust_remote_code=True,
+            )
+            tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True, use_fast=False)
+            additional_params["processor"] = InternProcessor(model_hf, tokenizer)
+            qeff_class = QEFFAutoModelForCausalLM
+        else:
+            additional_params["processor"] = AutoProcessor.from_pretrained(model)
+            qeff_class = QEFFAutoModelForImageTextToText
+    else:
+        if num_hidden_layers is not None:
+            additional_configs["num_hidden_layers"] = num_hidden_layers
+        spec_length = (spec_length or 1) - 1
+        qeff_class = QEFFAutoModelForCausalLM
+    return additional_configs, additional_params, prompts, spec_length, qeff_class
+
+
 @pytest.mark.on_qaic
 @pytest.mark.parametrize(
     "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm",
@@ -156,7 +192,7 @@ def test_sampler_transform(
     ctx_len: int,
     generation_len: int,
     full_batch_size: int,
-    spec_length: int,
+    spec_length: Optional[int],
     is_vlm: bool,
 ):
     """
@@ -165,22 +201,10 @@
     next tokens and/or probability distributions.
     """
     # Export and compile QEfficient models
-    additional_configs = {}
-    if is_vlm:
-        config = AutoConfig.from_pretrained(model, trust_remote_code=True)
-        if "InternVL" in model:
-            config.llm_config.num_hidden_layers = 2
-            qeff_class = QEFFAutoModelForCausalLM
-            additional_configs["trust_remote_code"] = True
-        else:
-            config.text_config.num_hidden_layers = 2
-            qeff_class = QEFFAutoModelForImageTextToText
-        additional_configs["config"] = config
-        additional_configs["kv_offload"] = True
-    else:
-        additional_configs["num_hidden_layers"] = 2
-        qeff_class = QEFFAutoModelForCausalLM
-        spec_length -= 1
+    num_hidden_layers = 2
+    additional_configs, additional_params, prompts, spec_length, qeff_class = prepare_model_setup(
+        model, is_vlm, num_hidden_layers, prompts, spec_length
+    )
     model_w_sampler = qeff_class.from_pretrained(
         model,
         continuous_batching=True,
@@ -298,33 +322,17 @@ def test_greedy_sampling(
     ctx_len: int,
     generation_len: int,
     full_batch_size: int,
-    spec_length: int,
+    spec_length: Optional[int],
    is_vlm: bool,
 ):
     """
     Test greedy sampling with QPCs compiled with and without On Device Sampling.
     """
     # Export and compile QEfficient models
-    additional_configs = {}
-    additional_params = {}
-    if is_vlm:
-        config = AutoConfig.from_pretrained(model, trust_remote_code=True)
-        config.llm_config.num_hidden_layers = 4
-        additional_configs["config"] = config
-        additional_configs["kv_offload"] = True
-        additional_configs["trust_remote_code"] = True
-        model_hf = AutoModelForCausalLM.from_pretrained(
-            model,
-            config=config,
-        )
-        tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True, use_fast=False)
-        additional_params["processor"] = InternProcessor(model_hf, tokenizer)
-        assert isinstance(prompts, tuple)
-        additional_params["images"] = prompts[0]
-        prompts = prompts[1]
-    else:
-        additional_configs["num_hidden_layers"] = 4
-        spec_length -= 1
+    num_hidden_layers = 4
+    additional_configs, additional_params, prompts, spec_length, qeff_class = prepare_model_setup(
+        model, is_vlm, num_hidden_layers, prompts, spec_length
+    )
     model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained(
         model,
         continuous_batching=True,
@@ -416,30 +424,17 @@ def test_random_sampling(
     ctx_len: int,
     generation_len: int,
     full_batch_size: int,
-    spec_length: int,
+    spec_length: Optional[int],
     is_vlm: bool,
 ):
     """
     Test random sampling with QPCs compiled with and without On Device Sampling.
     """
     # Export and compile QEfficient models
-    additional_configs = {}
-    additional_params = {}
-    if is_vlm:
-        additional_configs["kv_offload"] = True
-        additional_configs["trust_remote_code"] = True
-        config = AutoConfig.from_pretrained(model, trust_remote_code=True)
-        model_hf = AutoModelForCausalLM.from_pretrained(
-            model,
-            config=config,
-        )
-        tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True, use_fast=False)
-        additional_params["processor"] = InternProcessor(model_hf, tokenizer)
-        assert isinstance(prompts, tuple)
-        additional_params["images"] = prompts[0]
-        prompts = prompts[1]
-    else:
-        spec_length -= 1
+    num_hidden_layers = None
+    additional_configs, additional_params, prompts, spec_length, qeff_class = prepare_model_setup(
+        model, is_vlm, num_hidden_layers, prompts, spec_length
+    )
     model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained(
         model,
         continuous_batching=True,
@@ -571,7 +566,7 @@ def test_random_sampling(
         }
     elif model == "OpenGVLab/InternVL2_5-1B":
         golden_texts = {
-            "w_sampler": "The description of this black puppy:\n\nThis scene features a small, young dog with smooth and shiny fur",
+            "w_sampler": "The description of this picture would be as follows:\n\nAn adorable black puppy is sitting on a wooden surface",
             "wo_sampler": "The image features a black puppy sitting on a wooden surface. The puppy has a shiny, glossy coat",
         }
         golden_ids = {
@@ -581,22 +576,22 @@
                     4008,
                     315,
                     419,
+                    6802,
+                    1035,
+                    387,
+                    438,
+                    11017,
+                    1447,
+                    2082,
+                    40608,
                     3691,
                     41189,
-                    1447,
-                    1986,
-                    6109,
-                    4419,
+                    374,
+                    11699,
+                    389,
                     264,
-                    2613,
-                    11,
-                    3908,
-                    5562,
-                    448,
-                    10876,
-                    323,
-                    41199,
-                    18241,
+                    22360,
+                    7329,
                 ]
             ],
             "wo_sampler": [
@@ -651,33 +646,17 @@ def test_guided_decoding(
     ctx_len: int,
     generation_len: int,
     full_batch_size: int,
-    spec_length: int,
+    spec_length: Optional[int],
     is_vlm: bool,
 ):
     """
     Test QPCs compiled with and without guided decoding.
     """
     # Export and compile QEfficient models
-    additional_configs = {}
-    additional_params = {}
-    if is_vlm:
-        config = AutoConfig.from_pretrained(model, trust_remote_code=True)
-        config.llm_config.num_hidden_layers = 2
-        additional_configs["config"] = config
-        additional_configs["kv_offload"] = True
-        additional_configs["trust_remote_code"] = True
-        model_hf = AutoModelForCausalLM.from_pretrained(
-            model,
-            config=config,
-        )
-        tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True, use_fast=False)
-        additional_params["processor"] = InternProcessor(model_hf, tokenizer)
-        assert isinstance(prompts, tuple)
-        additional_params["images"] = prompts[0]
-        prompts = prompts[1]
-    else:
-        additional_configs["num_hidden_layers"] = 2
-        spec_length -= 1
+    num_hidden_layers = 2
+    additional_configs, additional_params, prompts, spec_length, qeff_class = prepare_model_setup(
+        model, is_vlm, num_hidden_layers, prompts, spec_length
+    )
     model_w_sampler_w_guided_decoding = QEFFAutoModelForCausalLM.from_pretrained(
         model,
         continuous_batching=True,
@@ -733,6 +712,10 @@ def test_guided_decoding(
         "min_ps": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
         "random_numbers": np.zeros((full_batch_size, 1024), dtype=np.float32),
     }
+    if is_vlm:
+        vocab_size = model_w_sampler_w_guided_decoding.model.language_model.config.vocab_size
+    else:
+        vocab_size = model_w_sampler_w_guided_decoding.model.config.vocab_size
     model_w_sampler_w_guided_decoding_exec_info = model_w_sampler_w_guided_decoding.generate(
         tokenizer=tokenizer,
         prompts=prompts,
@@ -744,7 +727,7 @@ def test_guided_decoding(
         **sampling_params,
         **{
             "token_bitmasks": np.tile(
-                np.random.choice([True, False], size=(model_w_sampler_w_guided_decoding.model.config.vocab_size,)),
+                np.random.choice([True, False], size=(vocab_size,)),
                 (full_batch_size, 1),
             )
         },
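
Note: for guided decoding the test now takes the vocabulary size from the language_model sub-config when the model is a VLM, and builds one random allow/deny bitmask over that vocabulary, tiled across the batch. A small sketch of the mask construction with illustrative sizes (32 and 2 are placeholders, not values from the test):

import numpy as np

# Illustrative sizes; in the test, vocab_size comes from the model (or language_model)
# config and full_batch_size from the pytest parameters.
vocab_size = 32
full_batch_size = 2

# One random True/False mask over the vocabulary, repeated for every batch slot,
# giving the expected (full_batch_size, vocab_size) token_bitmasks shape.
token_bitmasks = np.tile(
    np.random.choice([True, False], size=(vocab_size,)),
    (full_batch_size, 1),
)
assert token_bitmasks.shape == (full_batch_size, vocab_size)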

tests/transformers/test_subfunction.py

Lines changed: 8 additions & 0 deletions
@@ -5,6 +5,8 @@
 #
 # ----------------------------------------------------------------------------
 
+import hashlib
+
 import pytest
 import torch
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
@@ -57,8 +59,14 @@ def test_subfunction_vs_nonsubfunction(config, tmp_path):
     without_sub_func_onnx = model_0_0.export(tmp_path, use_onnx_subfunctions=False)
     hash_0_1 = model_0_0.export_hash
 
+    # Test that the export hash changes when use_onnx_subfunctions is toggled, indicating different export parameters are used
     assert hash_0_0 != hash_0_1
 
+    # Test that the exported ONNX files themselves differ when use_onnx_subfunctions is toggled, by comparing their SHA-256 hashes
+    with_sub_func_onnx_hash = hashlib.sha256(open(with_sub_func_onnx, "rb").read()).hexdigest()
+    without_sub_func_onnx_hash = hashlib.sha256(open(without_sub_func_onnx, "rb").read()).hexdigest()
+    assert with_sub_func_onnx_hash != without_sub_func_onnx_hash
+
     compile_params = {"prefill_seq_len": 8, "ctx_len": 16}
     model_0_0.compile(onnx_path=with_sub_func_onnx, **compile_params)
     generation_00 = model_0_0.generate(prompts=["Help me with this"], tokenizer=tokenizer)
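
Note: the new assertion hashes each exported ONNX file by reading it into memory in one call. A chunked variant (a sketch, not part of this commit; sha256_of_file is a hypothetical helper) performs the same comparison with flat memory use for large exports:

import hashlib


def sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    # Stream the file through SHA-256 in 1 MiB chunks instead of one large read.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Hypothetical usage with the two artifacts exported in the test above:
# assert sha256_of_file(with_sub_func_onnx) != sha256_of_file(without_sub_func_onnx)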
