Skip to content

Commit 3e8be0e

Browse files
authored
Adjust Quantization Preprocessing on Opset Handling (#26503)
### Description <!-- Describe your changes. --> - Moved `ReplaceUpsampleWithResize` in the quantization preprocessing pipeline to occur before `SymbolicShapeInference`, due to the current lack of shape inference support for the `Upsample` operator. Prevented unnecessary modifications to `model.opset_import`. Setting `model.opset_import` prior to invoking `onnx.version_converter` can interfere with successful opset conversion. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> - This change ensures the quantization preprocessing functions correctly by addressing limitations in shape inference for `Upsample`. - It also avoids potential issues with opset conversion caused by premature modification of `model.opset_import`.
1 parent 948bba1 commit 3e8be0e

File tree

2 files changed

+175
-18
lines changed

2 files changed

+175
-18
lines changed

onnxruntime/python/tools/quantization/shape_inference.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -74,34 +74,29 @@ def quant_pre_process(
7474

7575
with tempfile.TemporaryDirectory(prefix="pre.quant.") as quant_tmp_dir:
7676
temp_path = Path(quant_tmp_dir)
77-
model = None
77+
model = input_model if isinstance(input_model, onnx.ModelProto) else onnx.load(input_model)
78+
79+
# Since Upsample is deprecated after opset v10, and the model's opset will
80+
# be upgraded to at least v11 during quantization, we need to replace Upsample
81+
# with Resize first to avoid generating an invalid model.
82+
ai_onnx_domain = [opset for opset in model.opset_import if not opset.domain or opset.domain == "ai.onnx"]
83+
if len(ai_onnx_domain) == 1:
84+
opset_version = ai_onnx_domain[0].version
85+
if opset_version <= 10:
86+
ReplaceUpsampleWithResize(ONNXModel(model), opset_version).apply()
87+
model = onnx.version_converter.convert_version(model, 11)
88+
model = save_and_reload_model_with_shape_infer(model)
7889

7990
if not skip_symbolic_shape:
8091
logger.info("Performing symbolic shape inference...")
81-
loaded_model = input_model if isinstance(input_model, onnx.ModelProto) else onnx.load(input_model)
8292
model = SymbolicShapeInference.infer_shapes(
83-
loaded_model,
93+
model,
8494
int_max,
8595
auto_merge,
8696
guess_output_rank,
8797
verbose,
8898
)
8999

90-
# Since Upsample is deprecated after opset v10, and the model's opset will
91-
# be upgraded to at least v11 during quantization, we need to replace Upsample
92-
# with Resize first to avoid generating an invalid model.
93-
if model:
94-
ai_onnx_domain = [opset for opset in model.opset_import if not opset.domain or opset.domain == "ai.onnx"]
95-
if len(ai_onnx_domain) == 1:
96-
opset_version = ai_onnx_domain[0].version
97-
if opset_version < 10:
98-
ReplaceUpsampleWithResize(ONNXModel(model), opset_version).apply()
99-
model.opset_import.remove(ai_onnx_domain[0])
100-
opset_version = 11
101-
model.opset_import.extend([onnx.helper.make_opsetid("", opset_version)])
102-
model = onnx.version_converter.convert_version(model, opset_version)
103-
model = save_and_reload_model_with_shape_infer(model)
104-
105100
if not skip_optimization:
106101
# Use ORT optimizers (native code) to optimize model
107102
if not skip_symbolic_shape:
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
#!/usr/bin/env python
2+
# -------------------------------------------------------------------------
3+
# Copyright (c) Microsoft Corporation. All rights reserved.
4+
# Licensed under the MIT License. See License.txt in the project root for
5+
# license information.
6+
# --------------------------------------------------------------------------
7+
8+
import tempfile
9+
import unittest
10+
from pathlib import Path
11+
12+
import numpy as np
13+
import onnx
14+
15+
from onnxruntime.quantization.shape_inference import quant_pre_process
16+
17+
18+
class TestUpsample(unittest.TestCase):
19+
def setUp(self):
20+
self.temp_dir = tempfile.TemporaryDirectory(prefix="ort.quant_preprocess_")
21+
self.temp_path = Path(self.temp_dir.name)
22+
23+
def tearDown(self):
24+
self.temp_dir.cleanup()
25+
26+
def build_upsample_model(self, input_shape=(1, 3, 32, 32)):
27+
"""
28+
Build a model with deprecated Upsample op (opset <= 10) for testing version conversion.
29+
"""
30+
input_tensor = onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, input_shape)
31+
output_shape = (input_shape[0], input_shape[1], input_shape[2] * 2, input_shape[3] * 2)
32+
output_tensor = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, output_shape)
33+
34+
# Create scales for upsample
35+
scales = np.array([1.0, 1.0, 2.0, 2.0], dtype=np.float32)
36+
scales_initializer = onnx.numpy_helper.from_array(scales, "scales")
37+
38+
upsample_node = onnx.helper.make_node(
39+
"Upsample",
40+
["input", "scales"],
41+
["output"],
42+
name="upsample_node",
43+
mode="nearest",
44+
)
45+
46+
graph = onnx.helper.make_graph(
47+
[upsample_node],
48+
"upsample_graph",
49+
[input_tensor],
50+
[output_tensor],
51+
initializer=[scales_initializer],
52+
)
53+
# Use opset 10 to trigger Upsample -> Resize conversion
54+
opset_imports = [onnx.helper.make_opsetid("", 10)]
55+
model = onnx.helper.make_model(graph, opset_imports=opset_imports)
56+
return model
57+
58+
def test_upsample_to_resize_conversion(self):
59+
"""
60+
Test that deprecated Upsample ops are converted to Resize ops.
61+
"""
62+
model = self.build_upsample_model()
63+
input_path = self.temp_path / "input_model.onnx"
64+
output_path = self.temp_path / "preprocessed_model.onnx"
65+
66+
onnx.save_model(model, input_path)
67+
68+
# Verify original model has Upsample op
69+
self.assertEqual(model.graph.node[0].op_type, "Upsample")
70+
self.assertEqual(model.opset_import[0].version, 10)
71+
72+
quant_pre_process(
73+
input_model=str(input_path),
74+
output_model_path=str(output_path),
75+
skip_optimization=True,
76+
skip_onnx_shape=True,
77+
skip_symbolic_shape=True,
78+
)
79+
80+
self.assertTrue(output_path.exists())
81+
preprocessed_model = onnx.load(str(output_path))
82+
83+
# Verify Upsample was converted to Resize and opset was upgraded
84+
node_types = [node.op_type for node in preprocessed_model.graph.node]
85+
assert "Resize" in node_types
86+
assert "Upsample" not in node_types
87+
assert preprocessed_model.opset_import[0].version >= 11
88+
89+
90+
class TestClip(unittest.TestCase):
91+
def setUp(self):
92+
self.temp_dir = tempfile.TemporaryDirectory(prefix="ort.quant_preprocess_")
93+
self.temp_path = Path(self.temp_dir.name)
94+
95+
def tearDown(self):
96+
self.temp_dir.cleanup()
97+
98+
def build_clip_model(self, input_shape=(1, 3, 32, 32)):
99+
"""
100+
Build a model with Clip op using ai.onnx v6 for testing version conversion.
101+
"""
102+
input_tensor = onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, input_shape)
103+
output_tensor = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, input_shape)
104+
105+
# Create min and max values for clip
106+
min_val = np.array(0.0, dtype=np.float32)
107+
max_val = np.array(6.0, dtype=np.float32)
108+
min_initializer = onnx.numpy_helper.from_array(min_val, "min")
109+
max_initializer = onnx.numpy_helper.from_array(max_val, "max")
110+
111+
clip_node = onnx.helper.make_node(
112+
"Clip",
113+
["input", "min", "max"],
114+
["output"],
115+
name="clip_node",
116+
)
117+
118+
graph = onnx.helper.make_graph(
119+
[clip_node],
120+
"clip_graph",
121+
[input_tensor],
122+
[output_tensor],
123+
initializer=[min_initializer, max_initializer],
124+
)
125+
# Use opset 6 to trigger version conversion
126+
opset_imports = [onnx.helper.make_opsetid("", 6)]
127+
model = onnx.helper.make_model(graph, opset_imports=opset_imports)
128+
return model
129+
130+
def test_clip_version_conversion(self):
131+
"""
132+
Test that Clip op from ai.onnx v6 is upgraded to v11 after quant_pre_process.
133+
"""
134+
model = self.build_clip_model()
135+
input_path = self.temp_path / "input_clip_model.onnx"
136+
output_path = self.temp_path / "preprocessed_clip_model.onnx"
137+
138+
onnx.save_model(model, input_path)
139+
140+
# Verify original model has Clip op with opset 6
141+
self.assertEqual(model.graph.node[0].op_type, "Clip")
142+
self.assertEqual(model.opset_import[0].version, 6)
143+
144+
quant_pre_process(
145+
input_model=str(input_path),
146+
output_model_path=str(output_path),
147+
skip_optimization=True,
148+
skip_onnx_shape=True,
149+
skip_symbolic_shape=True,
150+
)
151+
152+
self.assertTrue(output_path.exists())
153+
preprocessed_model = onnx.load(str(output_path))
154+
155+
# Verify Clip op is still present and opset was upgraded to v11 or higher
156+
node_types = [node.op_type for node in preprocessed_model.graph.node]
157+
assert "Clip" in node_types
158+
assert preprocessed_model.opset_import[0].version >= 11
159+
160+
161+
if __name__ == "__main__":
162+
unittest.main()

0 commit comments

Comments
 (0)