
Commit eaf21c0

Merge remote-tracking branch 'origin/on-device-sampling-vlm' into HEAD

2 parents 5457075 + a2d4fb4

3 files changed (+92, -122 lines)

QEfficient/transformers/sampler/sampler.py

Lines changed: 7 additions & 3 deletions
@@ -37,6 +37,7 @@ def prefill_path(
     batch_index: torch.LongTensor,
     batch_index_reshaped: torch.LongTensor,
     past_repetition_penalty_buffer: torch.Tensor,
+    past_presence_penalty_buffer: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Initialize or update RetainedState buffers for prefill stage based on `input_ids`.
@@ -59,7 +60,10 @@ def prefill_path(
         input_ids,
         torch.ones(input_ids.shape, dtype=torch.bool),
     )
-    return past_repetition_penalty_buffer
+
+    mul_value = torch.zeros(past_presence_penalty_buffer.shape[0], 1, dtype=torch.bool)
+    past_presence_penalty_buffer *= mul_value
+    return past_repetition_penalty_buffer, past_presence_penalty_buffer


 def decode_path(
@@ -234,14 +238,14 @@ def sampler_forward(
         logits = torch.where(token_bitmasks == 1, logits, torch.finfo(torch.float16).min)

     # Prefill
-    past_repetition_penalty_buffer_prefill = prefill_path(
+    past_repetition_penalty_buffer_prefill, past_presence_penalty_buffer_prefill = prefill_path(
         input_ids=input_ids,
         position_ids=position_ids,
         batch_index=batch_index,
         batch_index_reshaped=batch_index_reshaped,
         past_repetition_penalty_buffer=past_repetition_penalty_buffer.clone(),
+        past_presence_penalty_buffer=past_presence_penalty_buffer.clone(),
     )
-    past_presence_penalty_buffer_prefill = torch.zeros(past_presence_penalty_buffer.shape, dtype=torch.bool)
     # Decode
     past_repetition_penalty_buffer_decode, past_presence_penalty_buffer_decode = decode_path(
         last_accepted_output_tokens=last_accepted_output_tokens,
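For context on the sampler change: instead of building a fresh `torch.zeros(...)` tensor outside `prefill_path`, the prefill path now clears the retained presence-penalty buffer in place by multiplying it with an all-False `(batch, 1)` column and returns it together with the repetition-penalty buffer, presumably so the cleared buffer stays derived from the RetainedState input rather than from an unrelated constant. Below is a minimal sketch of that reset pattern with made-up shapes; it is illustrative code, not the repository's.

import torch

# Hypothetical shapes for illustration only: 4 sequences, 8-token vocabulary.
past_presence_penalty_buffer = torch.ones(4, 8, dtype=torch.bool)  # stand-in for the retained buffer

# Same reset pattern as the new prefill_path: a (batch, 1) column of False
# broadcasts across the vocabulary axis and zeroes the whole buffer in place.
mul_value = torch.zeros(past_presence_penalty_buffer.shape[0], 1, dtype=torch.bool)
past_presence_penalty_buffer *= mul_value

assert not past_presence_penalty_buffer.any()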

tests/transformers/sampler/test_sampler.py

Lines changed: 77 additions & 119 deletions
@@ -5,11 +5,11 @@
 #
 # -----------------------------------------------------------------------------

-from typing import List, Union
+from typing import List, Optional, Tuple, Union

 import numpy as np
 import pytest
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer

 from QEfficient import QEFFAutoModelForCausalLM, QEFFAutoModelForImageTextToText
 from QEfficient.generation.cloud_infer import QAICInferenceSession
@@ -117,31 +117,42 @@
         True,  # is_vlm
     ),
 ]
-guided_decoding_configs = [
-    pytest.param(
-        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # model
-        Constants.INPUT_STR * 4,  # prompts
-        32,  # prefill_seq_len
-        64,  # ctx_len
-        20,  # generation_len
-        4,  # full_batch_size
-        1,  # spec_length
-        False,  # is_vlm
-    ),
-    pytest.param(
-        "OpenGVLab/InternVL2_5-1B",  # model
-        (
-            ["https://picsum.photos/id/237/536/354"] * 4,
-            ["Can you describe the image in detail."] * 4,
-        ),  # images and prompts
-        128,  # prefill_seq_len
-        4096,  # ctx_len
-        20,  # generation_len
-        4,  # full_batch_size
-        None,  # spec_length
-        True,  # is_vlm
-    ),
-]
+
+
+def prepare_model_setup(
+    model: str, is_vlm: bool, num_hidden_layers: Optional[int], prompts: Union[List, Tuple], spec_length: Optional[int]
+):
+    additional_configs = {}
+    additional_params = {}
+    if is_vlm:
+        config = AutoConfig.from_pretrained(model, trust_remote_code=True)
+        if num_hidden_layers is not None:
+            config.llm_config.num_hidden_layers = num_hidden_layers
+        additional_configs["config"] = config
+        additional_configs["kv_offload"] = True
+        assert isinstance(prompts, tuple), "For VLMs, both image and text prompts must be provided."
+        additional_params["images"] = prompts[0]
+        prompts = prompts[1]
+
+        if "InternVL" in model:
+            additional_configs["trust_remote_code"] = True
+            model_hf = AutoModelForCausalLM.from_pretrained(
+                model,
+                config=config,
+                trust_remote_code=True,
+            )
+            tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True, use_fast=False)
+            additional_params["processor"] = InternProcessor(model_hf, tokenizer)
+            qeff_class = QEFFAutoModelForCausalLM
+        else:
+            additional_params["processor"] = AutoProcessor.from_pretrained(model)
+            qeff_class = QEFFAutoModelForImageTextToText
+    else:
+        if num_hidden_layers is not None:
+            additional_configs["num_hidden_layers"] = num_hidden_layers
+        spec_length = (spec_length or 1) - 1
+        qeff_class = QEFFAutoModelForCausalLM
+    return additional_configs, additional_params, prompts, spec_length, qeff_class


 @pytest.mark.on_qaic
@@ -156,7 +167,7 @@ def test_sampler_transform(
     ctx_len: int,
     generation_len: int,
     full_batch_size: int,
-    spec_length: int,
+    spec_length: Optional[int],
     is_vlm: bool,
 ):
     """
@@ -165,22 +176,10 @@ def test_sampler_transform(
     next tokens and/or probability distributions.
     """
     # Export and compile QEfficient models
-    additional_configs = {}
-    if is_vlm:
-        config = AutoConfig.from_pretrained(model, trust_remote_code=True)
-        if "InternVL" in model:
-            config.llm_config.num_hidden_layers = 2
-            qeff_class = QEFFAutoModelForCausalLM
-            additional_configs["trust_remote_code"] = True
-        else:
-            config.text_config.num_hidden_layers = 2
-            qeff_class = QEFFAutoModelForImageTextToText
-        additional_configs["config"] = config
-        additional_configs["kv_offload"] = True
-    else:
-        additional_configs["num_hidden_layers"] = 2
-        qeff_class = QEFFAutoModelForCausalLM
-        spec_length -= 1
+    num_hidden_layers = 2
+    additional_configs, additional_params, prompts, spec_length, qeff_class = prepare_model_setup(
+        model, is_vlm, num_hidden_layers, prompts, spec_length
+    )
     model_w_sampler = qeff_class.from_pretrained(
         model,
         continuous_batching=True,
@@ -298,33 +297,17 @@ def test_greedy_sampling(
     ctx_len: int,
     generation_len: int,
     full_batch_size: int,
-    spec_length: int,
+    spec_length: Optional[int],
     is_vlm: bool,
 ):
     """
     Test greedy sampling with QPCs compiled with and without On Device Sampling.
     """
     # Export and compile QEfficient models
-    additional_configs = {}
-    additional_params = {}
-    if is_vlm:
-        config = AutoConfig.from_pretrained(model, trust_remote_code=True)
-        config.llm_config.num_hidden_layers = 4
-        additional_configs["config"] = config
-        additional_configs["kv_offload"] = True
-        additional_configs["trust_remote_code"] = True
-        model_hf = AutoModelForCausalLM.from_pretrained(
-            model,
-            config=config,
-        )
-        tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True, use_fast=False)
-        additional_params["processor"] = InternProcessor(model_hf, tokenizer)
-        assert isinstance(prompts, tuple)
-        additional_params["images"] = prompts[0]
-        prompts = prompts[1]
-    else:
-        additional_configs["num_hidden_layers"] = 4
-        spec_length -= 1
+    num_hidden_layers = 4
+    additional_configs, additional_params, prompts, spec_length, qeff_class = prepare_model_setup(
+        model, is_vlm, num_hidden_layers, prompts, spec_length
+    )
     model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained(
         model,
         continuous_batching=True,
@@ -416,30 +399,17 @@ def test_random_sampling(
     ctx_len: int,
     generation_len: int,
     full_batch_size: int,
-    spec_length: int,
+    spec_length: Optional[int],
    is_vlm: bool,
 ):
     """
     Test random sampling with QPCs compiled with and without On Device Sampling.
     """
     # Export and compile QEfficient models
-    additional_configs = {}
-    additional_params = {}
-    if is_vlm:
-        additional_configs["kv_offload"] = True
-        additional_configs["trust_remote_code"] = True
-        config = AutoConfig.from_pretrained(model, trust_remote_code=True)
-        model_hf = AutoModelForCausalLM.from_pretrained(
-            model,
-            config=config,
-        )
-        tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True, use_fast=False)
-        additional_params["processor"] = InternProcessor(model_hf, tokenizer)
-        assert isinstance(prompts, tuple)
-        additional_params["images"] = prompts[0]
-        prompts = prompts[1]
-    else:
-        spec_length -= 1
+    num_hidden_layers = None
+    additional_configs, additional_params, prompts, spec_length, qeff_class = prepare_model_setup(
+        model, is_vlm, num_hidden_layers, prompts, spec_length
+    )
     model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained(
         model,
         continuous_batching=True,
@@ -571,7 +541,7 @@ def test_random_sampling(
         }
     elif model == "OpenGVLab/InternVL2_5-1B":
         golden_texts = {
-            "w_sampler": "The description of this black puppy:\n\nThis scene features a small, young dog with smooth and shiny fur",
+            "w_sampler": "The description of this picture would be as follows:\n\nAn adorable black puppy is sitting on a wooden surface",
             "wo_sampler": "The image features a black puppy sitting on a wooden surface. The puppy has a shiny, glossy coat",
         }
         golden_ids = {
@@ -581,22 +551,22 @@ def test_random_sampling(
                     4008,
                     315,
                     419,
+                    6802,
+                    1035,
+                    387,
+                    438,
+                    11017,
+                    1447,
+                    2082,
+                    40608,
                     3691,
                     41189,
-                    1447,
-                    1986,
-                    6109,
-                    4419,
+                    374,
+                    11699,
+                    389,
                     264,
-                    2613,
-                    11,
-                    3908,
-                    5562,
-                    448,
-                    10876,
-                    323,
-                    41199,
-                    18241,
+                    22360,
+                    7329,
                 ]
             ],
             "wo_sampler": [
@@ -651,33 +621,17 @@ def test_guided_decoding(
     ctx_len: int,
     generation_len: int,
     full_batch_size: int,
-    spec_length: int,
+    spec_length: Optional[int],
     is_vlm: bool,
 ):
     """
     Test QPCs compiled with and without guided decoding.
     """
     # Export and compile QEfficient models
-    additional_configs = {}
-    additional_params = {}
-    if is_vlm:
-        config = AutoConfig.from_pretrained(model, trust_remote_code=True)
-        config.llm_config.num_hidden_layers = 2
-        additional_configs["config"] = config
-        additional_configs["kv_offload"] = True
-        additional_configs["trust_remote_code"] = True
-        model_hf = AutoModelForCausalLM.from_pretrained(
-            model,
-            config=config,
-        )
-        tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True, use_fast=False)
-        additional_params["processor"] = InternProcessor(model_hf, tokenizer)
-        assert isinstance(prompts, tuple)
-        additional_params["images"] = prompts[0]
-        prompts = prompts[1]
-    else:
-        additional_configs["num_hidden_layers"] = 2
-        spec_length -= 1
+    num_hidden_layers = 2
+    additional_configs, additional_params, prompts, spec_length, qeff_class = prepare_model_setup(
+        model, is_vlm, num_hidden_layers, prompts, spec_length
+    )
     model_w_sampler_w_guided_decoding = QEFFAutoModelForCausalLM.from_pretrained(
         model,
         continuous_batching=True,
@@ -733,6 +687,10 @@ def test_guided_decoding(
         "min_ps": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
         "random_numbers": np.zeros((full_batch_size, 1024), dtype=np.float32),
     }
+    if is_vlm:
+        vocab_size = model_w_sampler_w_guided_decoding.model.language_model.config.vocab_size
+    else:
+        vocab_size = model_w_sampler_w_guided_decoding.model.config.vocab_size
     model_w_sampler_w_guided_decoding_exec_info = model_w_sampler_w_guided_decoding.generate(
         tokenizer=tokenizer,
         prompts=prompts,
@@ -744,7 +702,7 @@ def test_guided_decoding(
         **sampling_params,
         **{
             "token_bitmasks": np.tile(
-                np.random.choice([True, False], size=(model_w_sampler_w_guided_decoding.model.config.vocab_size,)),
+                np.random.choice([True, False], size=(vocab_size,)),
                 (full_batch_size, 1),
            )
         },
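The guided-decoding test now resolves the vocabulary size from `model.language_model.config` for VLMs (and from `model.config` otherwise) and reuses it when building the `token_bitmasks` input. A small standalone sketch of that bitmask construction follows, with made-up sizes; it mirrors the test's `np.tile(np.random.choice(...))` call but is illustrative only.

import numpy as np

# Hypothetical sizes for illustration only.
full_batch_size, vocab_size = 4, 32

# One random allow/deny mask over the vocabulary, tiled so every sequence
# in the batch shares the same constraint (same construction as the test).
token_bitmasks = np.tile(
    np.random.choice([True, False], size=(vocab_size,)),
    (full_batch_size, 1),
)
assert token_bitmasks.shape == (full_batch_size, vocab_size)

On the sampler side (see the context line in sampler.py above), positions where the mask is not set have their logits replaced with `torch.finfo(torch.float16).min`, so only allowed tokens can be sampled.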

tests/transformers/test_subfunction.py

Lines changed: 8 additions & 0 deletions
@@ -5,6 +5,8 @@
 #
 # ----------------------------------------------------------------------------

+import hashlib
+
 import pytest
 import torch
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
@@ -57,8 +59,14 @@ def test_subfunction_vs_nonsubfunction(config, tmp_path):
     without_sub_func_onnx = model_0_0.export(tmp_path, use_onnx_subfunctions=False)
     hash_0_1 = model_0_0.export_hash

+    # Check that the export hash changes when use_onnx_subfunctions is toggled, indicating different export parameters are used
     assert hash_0_0 != hash_0_1

+    # Check that the exported ONNX files differ by comparing their SHA-256 hashes when use_onnx_subfunctions is toggled
+    with_sub_func_onnx_hash = hashlib.sha256(open(with_sub_func_onnx, "rb").read()).hexdigest()
+    without_sub_func_onnx_hash = hashlib.sha256(open(without_sub_func_onnx, "rb").read()).hexdigest()
+    assert with_sub_func_onnx_hash != without_sub_func_onnx_hash
+
     compile_params = {"prefill_seq_len": 8, "ctx_len": 16}
     model_0_0.compile(onnx_path=with_sub_func_onnx, **compile_params)
     generation_00 = model_0_0.generate(prompts=["Help me with this"], tokenizer=tokenizer)
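The new assertion hashes the two exported ONNX artifacts with `hashlib.sha256` and requires the digests to differ when `use_onnx_subfunctions` is toggled. Below is a sketch of the same comparison; `sha256_of` is a hypothetical helper, not part of the repository (the test inlines the `hashlib.sha256(...).hexdigest()` call directly).

import hashlib

def sha256_of(path: str) -> str:
    # Stream the file in 1 MiB chunks so large ONNX exports need not fit in memory at once.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()

# Usage (paths are placeholders returned by model.export(...)):
# assert sha256_of(with_sub_func_onnx) != sha256_of(without_sub_func_onnx)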
