From a94f560b7991d32d443b85257f990db886aac60a Mon Sep 17 00:00:00 2001 From: Winston Kuo Date: Fri, 26 Jun 2026 11:35:52 +0800 Subject: [PATCH] Fixing quant rnage --- .github/workflows/pull.yml | 3 +-- .github/workflows/trunk.yml | 3 +-- backends/qualcomm/quantizer/quantizer.py | 26 +++++++++++++++++++++++- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 6f7d9f8589f..eb1a414531b 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -852,8 +852,7 @@ jobs: strategy: matrix: dtype: [fp32] - # TODO(T12345): re-enable qnn_16a16w once OOM on linux.2xlarge is resolved - pt2e_quantize: [qnn_8a8w] + pt2e_quantize: [qnn_16a16w, qnn_8a8w] mode: [qnn] fail-fast: false with: diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 7ded9e4cecc..7604ca474b0 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -951,8 +951,7 @@ jobs: strategy: matrix: dtype: [fp32] - # TODO(T12345): re-enable qnn_16a16w once OOM on linux.2xlarge is resolved - pt2e_quantize: [qnn_8a8w] + pt2e_quantize: [qnn_16a16w, qnn_8a8w] mode: [qnn] fail-fast: false with: diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index c6fbc51484f..6077aa03170 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -418,6 +418,27 @@ def _get_quant_range(self, node): ) return quant_range + def _get_input_quant_range(self, user_node, input_node): + """Return the quant range of the spec assigned to `input_node` in + `user_node.meta[quantization_annotation].input_qspec_map`. Falls back + to None if no concrete spec is registered for this input — needed + when the user's output_qspec is a SharedQuantizationSpec that hides + the dtype/qmin/qmax.""" + quant_info = user_node.meta.get(QCOM_QUANT_ANNOTATION_KEY, None) + if quant_info is None: + return + qspec = getattr(quant_info, "input_qspec_map", {}).get(input_node) + if qspec is None: + return + try: + dtype_info = torch.iinfo(qspec.dtype) + except: + return + return ( + (dtype_info.max if qspec.quant_max is None else qspec.quant_max) + - (dtype_info.min if qspec.quant_min is None else qspec.quant_min) + ) + def _get_candidates_with_infinity_args(self, graph_module: GraphModule): binary_op_sources = [ operator.add, @@ -473,7 +494,10 @@ def _replace_inf(self, graph_module: GraphModule) -> GraphModule: quant_min, quant_max = float("inf"), float("-inf") for source_node in node.users: - if quant_range := self._get_quant_range(source_node): + if quant_range := self._get_input_quant_range(source_node, node): + quant_min = min(quant_min, -quant_range) + quant_max = max(quant_max, quant_range) + elif quant_range := self._get_quant_range(source_node): quant_min = min(quant_min, -quant_range) quant_max = max(quant_max, quant_range)