Skip to content

Commit e8e270f

Browse files
committed
[tosa] : Extend quantization support in tosa backend.
1 parent 837e498 commit e8e270f

File tree

9 files changed

+366
-230
lines changed

9 files changed

+366
-230
lines changed

include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,6 @@ Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
3535
Value input_val, double input_scale,
3636
int64_t input_zp);
3737

38-
// Creates a TOSA rescale op based on conv2d parameters.
39-
Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
40-
Value conv_val, ShapedType input_type,
41-
ShapedType weight_type, ShapedType output_type);
42-
4338
// Check if scale32 mode is used for given output_element_type
4439
bool isScale32(mlir::quant::UniformQuantizedType output_element_type);
4540

@@ -114,6 +109,13 @@ Value emitExplicitZeroPadNHWC(Location loc, PatternRewriter &rewriter,
114109
Operation *op, Value inputNHWC,
115110
ArrayRef<int64_t> padExtents);
116111

112+
// Get the zero point from a torch.tensor or torch.qtensor value.
113+
// If the value is a quantized tensor, it extracts the zero point as a
114+
// scalar integer value. If the value is a float tensor, it returns a
115+
// constant 0.
116+
FailureOr<Value> getZeroPointValue(PatternRewriter &rewriter, Operation *op,
117+
Value tensor, Type elemType);
118+
117119
} // namespace tosa
118120
} // namespace mlir
119121

include/torch-mlir/Conversion/Utils/Utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,
119119

120120
void getZeroPoint(Value value, Value &zeropoint);
121121

122+
LogicalResult getQuantizationParams(Value value, Value &zeropoint, Value &scale,
123+
int64_t &axis);
124+
122125
} // namespace Torch
123126
} // namespace torch
124127
} // namespace mlir

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 177 additions & 99 deletions
Large diffs are not rendered by default.

lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp

Lines changed: 28 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
1212
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
1313
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project
14+
#include "torch-mlir/Conversion/Utils/Utils.h"
15+
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
1416
#include "llvm/ADT/ArrayRef.h"
1517

1618
namespace mlir {
@@ -91,120 +93,6 @@ Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
9193
input_zp, 0, tosa::RoundingMode::SINGLE_ROUND, true);
9294
}
9395

94-
// Creates a TOSA rescale op based on conv2d parameters.
95-
Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
96-
Value conv_val, ShapedType input_type,
97-
ShapedType weight_type, ShapedType output_type) {
98-
auto input_qtype =
99-
dyn_cast<mlir::quant::UniformQuantizedType>(input_type.getElementType());
100-
auto output_qtype =
101-
dyn_cast<mlir::quant::UniformQuantizedType>(output_type.getElementType());
102-
103-
double input_scale = input_qtype.getScale();
104-
105-
int64_t output_zp = output_qtype.getZeroPoint();
106-
double output_scale = output_qtype.getScale();
107-
108-
bool scale32 = isScale32(output_qtype);
109-
int32_t scale_width = scale32 ? 32 : 16;
110-
111-
bool input_unsigned = input_qtype.isUnsignedInteger();
112-
bool output_unsigned = output_qtype.isUnsignedInteger();
113-
114-
const auto input_zp_val = tosa::createZeroPointTensor(
115-
rewriter, op->getLoc(), input_type, static_cast<int64_t>(0));
116-
if (!input_zp_val.has_value())
117-
op->emitError("Failed to create input zero-point tensor for RescaleOp.");
118-
119-
const auto output_zp_val = tosa::createZeroPointTensor(
120-
rewriter, op->getLoc(), output_type, output_zp);
121-
if (!output_zp_val.has_value())
122-
op->emitError("Failed to create output zero-point tensor for RescaleOp.");
123-
124-
if (auto weight_per_tensor_qtype =
125-
dyn_cast<mlir::quant::UniformQuantizedType>(
126-
weight_type.getElementType())) {
127-
// Per-tensor quantization
128-
double weight_scale = weight_per_tensor_qtype.getScale();
129-
130-
int32_t multiplier;
131-
int32_t shift;
132-
133-
double op_tensor_scale = (input_scale * weight_scale) / output_scale;
134-
135-
if (!computeMultiplierAndShift(op_tensor_scale, multiplier, shift,
136-
scale_width))
137-
op->emitError(
138-
"buildRescaleOpConvOutput: shift must be in the range 2 <= shift <= "
139-
"62");
140-
141-
Value multiplier_val =
142-
buildRescaleMultiplier(scale32, rewriter, op, {multiplier});
143-
auto shift_val = tosa::getConstTensor<int8_t>(
144-
rewriter, op, {static_cast<int8_t>(shift)}, {1})
145-
.value();
146-
147-
auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
148-
rewriter, op->getLoc(), output_type, conv_val, multiplier_val,
149-
shift_val, input_zp_val.value(), output_zp_val.value(),
150-
rewriter.getBoolAttr(scale32),
151-
tosa::RoundingModeAttr::get(rewriter.getContext(),
152-
tosa::RoundingMode::DOUBLE_ROUND),
153-
rewriter.getBoolAttr(false), rewriter.getBoolAttr(input_unsigned),
154-
rewriter.getBoolAttr(output_unsigned));
155-
156-
return rescale_op.getResult();
157-
158-
} else if (auto weight_per_channel_qtype =
159-
dyn_cast<mlir::quant::UniformQuantizedPerAxisType>(
160-
weight_type.getElementType())) {
161-
// Per-channel quantization
162-
SmallVector<int32_t> multiplier_arr;
163-
SmallVector<int8_t> shift_arr;
164-
165-
SmallVector<double> weight_scale_arr(
166-
weight_per_channel_qtype.getScales().begin(),
167-
weight_per_channel_qtype.getScales().end());
168-
169-
for (double weight_scale : weight_scale_arr) {
170-
int32_t multiplier;
171-
int32_t shift;
172-
173-
double op_channel_scale = (input_scale * weight_scale) / output_scale;
174-
175-
if (!computeMultiplierAndShift(op_channel_scale, multiplier, shift, 32))
176-
op->emitError(
177-
"buildRescaleOpConvOutput: shift must be in the range 2 <= shift "
178-
"<= 62");
179-
180-
multiplier_arr.push_back(multiplier);
181-
shift_arr.push_back(static_cast<int8_t>(shift));
182-
}
183-
184-
Value multiplier_val =
185-
buildRescaleMultiplier(scale32, rewriter, op, multiplier_arr);
186-
auto shift_val =
187-
tosa::getConstTensor<int8_t>(rewriter, op, shift_arr,
188-
{static_cast<int64_t>(shift_arr.size())})
189-
.value();
190-
191-
auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
192-
rewriter, op->getLoc(), output_type, conv_val, multiplier_val,
193-
shift_val, input_zp_val.value(), output_zp_val.value(),
194-
rewriter.getBoolAttr(scale32),
195-
tosa::RoundingModeAttr::get(rewriter.getContext(),
196-
tosa::RoundingMode::DOUBLE_ROUND),
197-
rewriter.getBoolAttr(true), rewriter.getBoolAttr(input_unsigned),
198-
rewriter.getBoolAttr(output_unsigned));
199-
200-
return rescale_op.getResult();
201-
202-
} else {
203-
op->emitOpError("buildConvRescaleOp: unknown weight quantized type");
204-
return nullptr;
205-
}
206-
}
207-
20896
// Check if scale32 mode is used for given output_element_type
20997
bool isScale32(mlir::quant::UniformQuantizedType output_element_type) {
21098
return (output_element_type.getStorageTypeIntegralWidth() == 8);
@@ -666,5 +554,31 @@ Value emitExplicitZeroPadNHWC(Location loc, PatternRewriter &rewriter,
666554
.getResult();
667555
}
668556

557+
FailureOr<Value> getZeroPointValue(PatternRewriter &rewriter, Operation *op,
558+
Value tensor, Type elemType) {
559+
Location loc = op->getLoc();
560+
561+
Value zp;
562+
// Torch::getZeroPoint looks at the defining op of `tensor` to find
563+
// the quantization parameters.
564+
torch::Torch::getZeroPoint(tensor, zp);
565+
566+
if (!zp) {
567+
// Initialize zero constant values as zero-points, if the input tensor isn't
568+
// quantized
569+
zp = tosa::createZeroPointTensor(rewriter, loc, elemType, 0).value();
570+
} else {
571+
572+
int64_t zpConst;
573+
if (!matchPattern(zp, torch::Torch::m_TorchConstantInt(&zpConst)))
574+
return rewriter.notifyMatchFailure(
575+
op, "zero point must be a scalar constant");
576+
577+
zp = tosa::createZeroPointTensor(rewriter, loc, elemType, zpConst).value();
578+
}
579+
580+
return zp;
581+
}
582+
669583
} // namespace tosa
670584
} // namespace mlir

lib/Conversion/Utils/Utils.cpp

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1717
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
1818
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
19+
#include "llvm/ADT/TypeSwitch.h"
1920

2021
namespace mlir {
2122
namespace torch {
@@ -566,9 +567,45 @@ FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,
566567
}
567568

568569
void getZeroPoint(Value value, Value &zeropoint) {
569-
if (auto make = value.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
570-
zeropoint = make.getZeroPoint();
571-
}
570+
Operation *definingOp = value.getDefiningOp();
571+
if (!definingOp)
572+
return;
573+
574+
// Extract and set the zero point from a given op.
575+
auto getZp = [&](auto op) { zeropoint = op.getZeroPoint(); };
576+
577+
llvm::TypeSwitch<Operation *>(definingOp)
578+
.Case<Aten_MakePerTensorQuantizedTensorOp>(getZp)
579+
.Case<AtenQuantizePerTensorOp>(getZp)
580+
.Case<Aten_MakePerChannelQuantizedTensorOp>(getZp);
581+
}
582+
583+
LogicalResult getQuantizationParams(Value value, Value &zeropoint, Value &scale,
584+
int64_t &axis) {
585+
Operation *definingOp = value.getDefiningOp();
586+
if (!definingOp)
587+
return failure();
588+
589+
// Extract and set the common parameters from a given op.
590+
auto setParams = [&](auto op) -> LogicalResult {
591+
zeropoint = op.getZeroPoint();
592+
scale = op.getScale();
593+
// Axis must be constant scalar int for Aten_MakePerChannelQuantizedTensorOp
594+
if constexpr (std::is_same_v<decltype(op),
595+
Aten_MakePerChannelQuantizedTensorOp>) {
596+
return success(matchPattern(op.getAxis(), m_TorchConstantInt(&axis)));
597+
} else {
598+
// Other ops don't have axis parameter
599+
axis = -1;
600+
return success();
601+
}
602+
};
603+
604+
return llvm::TypeSwitch<Operation *, LogicalResult>(definingOp)
605+
.Case<Aten_MakePerTensorQuantizedTensorOp>(setParams)
606+
.Case<AtenQuantizePerTensorOp>(setParams)
607+
.Case<Aten_MakePerChannelQuantizedTensorOp>(setParams)
608+
.Default([](auto) { return failure(); });
572609
}
573610

574611
} // namespace Torch

lib/Dialect/TorchConversion/Transforms/Passes.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
115115
void TorchConversion::createTorchBackendToTosaBackendPipeline(
116116
OpPassManager &pm,
117117
const TorchConversion::TosaBackendPipelineOptions &options) {
118+
119+
// We want to fuse quantized operations together before lowering to tosa.
120+
pm.addNestedPass<func::FuncOp>(Torch::createFuseQuantizedOpsPass());
121+
118122
pm.addNestedPass<func::FuncOp>(
119123
createConvertTorchToTosaPass(options.requireFullTosaConversion));
120124
// Fold full-layer operations on TOSA constants

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3566,9 +3566,6 @@
35663566
"ViewDtypeStaticModule_basic",
35673567
"Unfold_Module_Rank_Zero_Size_Zero_basic",
35683568
"ArangeZeroElementOutputModule_basic",
3569-
"SliceOutOfUpperBoundIndexModule_basic",
3570-
"SliceOutOfUpperBoundIndexStaticModule_basic",
3571-
"SliceStartEqEndModule_basic",
35723569
"ElementwiseCreateComplexModule_basic",
35733570
"AtenPolarDoubleModule_basic",
35743571
"AtenPolarFloatModule_basic",
@@ -3681,8 +3678,6 @@
36813678
"Conv1dWithValidPaddingModule_basic",
36823679
"Conv1dGroupModule_basic",
36833680
"Conv2dQInt8Module_grouped",
3684-
"Conv2dQInt8PerChannelModule_basic",
3685-
"Conv2dQInt8PerChannelModule_depthwise",
36863681
"Conv2dQInt8PerChannelModule_grouped",
36873682
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
36883683
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
@@ -3903,7 +3898,6 @@
39033898
"SignAndLogarithmOfDeterminantDynamicModule_F32",
39043899
"SliceStaticComplexInputModule_basic",
39053900
"SliceCopyStartGreaterThanDimSize_Module_basic",
3906-
"SliceEndSleStartModule_basic",
39073901
"SliceOutOfLowerBoundEndIndexModule_basic",
39083902
"SortIntListReverse_basic",
39093903
"SortIntList_basic",

test/Conversion/TorchToTosa/quantization.mlir

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
// CHECK-LABEL: func.func @AtenMmQint8(
66
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[3,4],si8>,
77
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[4,3],si8>) -> !torch.vtensor<[3,3],f32> {
8-
// CHECK: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
9-
// CHECK: %[[OUT_SCALE:.*]] = "tosa.const"() <{values = dense<3.784000e-04> : tensor<3x3xf32>}> : () -> tensor<3x3xf32>
8+
// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
9+
// CHECK-DAG: %[[OUT_SCALE:.*]] = "tosa.const"() <{values = dense<3.784000e-04> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
1010
// CHECK-DAG: %[[MUL_OUT_SHAPE:.*]] = tosa.const_shape {values = dense<3> : tensor<2xindex>} : () -> !tosa.shape<2>
1111
// CHECK-DAG: %[[RHS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 4, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
1212
// CHECK-DAG: %[[LHS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 3, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
@@ -19,7 +19,7 @@
1919
// CHECK: %[[MATMUL:.*]] = tosa.matmul %[[LHS_RESHAPED]], %[[RHS_RESHAPED]], %[[LHS_ZP]], %[[RHS_ZP]] : (tensor<1x3x4xi8>, tensor<1x4x3xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x3x3xi32>
2020
// CHECK: %[[MATMUL_RESHAPE:.*]] = tosa.reshape %[[MATMUL]], %[[MUL_OUT_SHAPE]] : (tensor<1x3x3xi32>, !tosa.shape<2>) -> tensor<3x3xi32>
2121
// CHECK: %[[MATMUL_FP32:.*]] = tosa.cast %[[MATMUL_RESHAPE]] : (tensor<3x3xi32>) -> tensor<3x3xf32>
22-
// CHECK: %[[OUT_SCALED:.*]] = tosa.mul %[[MATMUL_FP32]], %[[OUT_SCALE]], %[[SHIFT]] : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<1xi8>) -> tensor<3x3xf32>
22+
// CHECK: %[[OUT_SCALED:.*]] = tosa.mul %[[MATMUL_FP32]], %[[OUT_SCALE]], %[[SHIFT]] : (tensor<3x3xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<3x3xf32>
2323
// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[OUT_SCALED]] : tensor<3x3xf32> -> !torch.vtensor<[3,3],f32>
2424
// CHECK: return %[[RES]]
2525
func.func @AtenMmQint8(%arg0: !torch.vtensor<[3,4],si8>, %arg1: !torch.vtensor<[4,3],si8>) -> !torch.vtensor<[3,3],f32>
@@ -76,3 +76,65 @@ func.func @quantization_per_tensor(%arg0: !torch.vtensor<[2,4,4],f32>) -> !torch
7676
%0 = torch.aten.quantize_per_tensor %arg0, %scale, %zp, %dtype : !torch.vtensor<[2,4,4],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[2,4,4],!torch.qint8>
7777
return %0 : !torch.vtensor<[2,4,4],!torch.qint8>
7878
}
79+
80+
81+
// -----
82+
// CHECK-LABEL: func.func @dequantize.self(
83+
// CHECK-SAME: %[[IN:.*]]: !torch.vtensor<[3,4,3,2],si8>,
84+
// CHECK-SAME: %[[SCALE:.*]]: !torch.vtensor<[3],f32>,
85+
// CHECK-SAME: %[[ZP:.*]]: !torch.vtensor<[3],si8>) -> !torch.vtensor<[3,4,3,2],f32> {
86+
// CHECK: %[[MUL_SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
87+
// CHECK: %[[QUANT_PARAM_SHAPE:.*]] = tosa.const_shape {values = dense<[3, 1, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
88+
// CHECK: %[[IN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[IN]] : !torch.vtensor<[3,4,3,2],si8> -> tensor<3x4x3x2xi8>
89+
// CHECK: %[[IN_I32:.*]] = tosa.cast %[[IN_TENSOR]] : (tensor<3x4x3x2xi8>) -> tensor<3x4x3x2xi32>
90+
// CHECK: %[[ZP_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ZP]] : !torch.vtensor<[3],si8> -> tensor<3xi8>
91+
// CHECK: %[[ZP_I32:.*]] = tosa.cast %[[ZP_TENSOR]] : (tensor<3xi8>) -> tensor<3xi32>
92+
// CHECK: %[[ZP_RESHAPED:.*]] = tosa.reshape %[[ZP_I32]], %[[QUANT_PARAM_SHAPE]] : (tensor<3xi32>, !tosa.shape<4>) -> tensor<3x1x1x1xi32>
93+
// CHECK: %[[SUB:.*]] = tosa.sub %[[IN_I32]], %[[ZP_RESHAPED]] : (tensor<3x4x3x2xi32>, tensor<3x1x1x1xi32>) -> tensor<3x4x3x2xi32>
94+
// CHECK: %[[SUB_CAST:.*]] = tosa.cast %[[SUB]] : (tensor<3x4x3x2xi32>) -> tensor<3x4x3x2xf32>
95+
// CHECK: %[[SCALE_TENSOR:.*]] = torch_c.to_builtin_tensor %[[SCALE]] : !torch.vtensor<[3],f32> -> tensor<3xf32>
96+
// CHECK: %[[SCALE_RESHAPED:.*]] = tosa.reshape %[[SCALE_TENSOR]], %[[QUANT_PARAM_SHAPE]] : (tensor<3xf32>, !tosa.shape<4>) -> tensor<3x1x1x1xf32>
97+
// CHECK: %[[MUL:.*]] = tosa.mul %[[SUB_CAST]], %[[SCALE_RESHAPED]], %[[MUL_SHIFT]] : (tensor<3x4x3x2xf32>, tensor<3x1x1x1xf32>, tensor<1xi8>) -> tensor<3x4x3x2xf32>
98+
// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[MUL]]
99+
func.func @dequantize.self(%arg0: !torch.vtensor<[3,4,3,2],si8>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],si8>) -> !torch.vtensor<[3,4,3,2],f32> {
100+
%int0 = torch.constant.int 0
101+
%0 = torch.aten._make_per_channel_quantized_tensor %arg0, %arg1, %arg2, %int0 : !torch.vtensor<[3,4,3,2],si8>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],si8>, !torch.int -> !torch.vtensor<[3,4,3,2],!torch.qint8>
102+
%1 = torch.aten.dequantize.self %0 : !torch.vtensor<[3,4,3,2],!torch.qint8> -> !torch.vtensor<[3,4,3,2],f32>
103+
return %1 : !torch.vtensor<[3,4,3,2],f32>
104+
}
105+
106+
107+
// -----
108+
// CHECK-LABEL: func.func @quantized_conv(
109+
// CHECK: %[[WTS_ZP:.*]] = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
110+
// CHECK: %[[IN_ZP:.*]] = "tosa.const"() <{values = dense<7> : tensor<1xi8>}> : () -> tensor<1xi8>
111+
// CHECK: %[[CONV:.*]] = tosa.conv2d
112+
// CHECK-SAME: %[[IN_ZP]], %[[WTS_ZP]] {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x7x8x4xi8>, tensor<3x3x2x4xi8>, tensor<?xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x5x7x3xi32>
113+
// CHECK-NOT: torch.aten.quantize_per_tensor
114+
// CHECK-NOT: torch.aten.dequantize.self
115+
// CHECK-NOT: torch.aten._make_per_tensor_quantized_tensor
116+
// CHECK-NOT: torch.aten.dequantize.tensor
117+
118+
func.func @quantized_conv(%arg0: !torch.vtensor<[?,4,7,8],si8>, %arg1: !torch.vtensor<[3,4,3,2],si8>, %arg2: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,3,5,7],f32> {
119+
%false = torch.constant.bool false
120+
%int1 = torch.constant.int 1
121+
%int0 = torch.constant.int 0
122+
%float1.000000e-04 = torch.constant.float 1.000000e-04
123+
%int3 = torch.constant.int 3
124+
%int7 = torch.constant.int 7
125+
%float1.000000e-02 = torch.constant.float 1.000000e-02
126+
%int14 = torch.constant.int 14
127+
%0 = torch.aten.quantize_per_tensor %arg2, %float1.000000e-04, %int0, %int14 : !torch.vtensor<[?],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[?],!torch.qint32>
128+
%1 = torch.aten.dequantize.self %0 : !torch.vtensor<[?],!torch.qint32> -> !torch.vtensor<[?],f32>
129+
%2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
130+
%3 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
131+
%4 = torch.prim.ListConstruct : () -> !torch.list<int>
132+
%5 = torch.aten._make_per_tensor_quantized_tensor %arg0, %float1.000000e-02, %int7 : !torch.vtensor<[?,4,7,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[?,4,7,8],!torch.qint8>
133+
%6 = torch.aten._make_per_tensor_quantized_tensor %arg1, %float1.000000e-02, %int3 : !torch.vtensor<[3,4,3,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,4,3,2],!torch.qint8>
134+
%7 = torch.aten.quantize_per_tensor %1, %float1.000000e-04, %int0, %int14 : !torch.vtensor<[?],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[?],!torch.qint32>
135+
%8 = torch.aten.int_repr %7 : !torch.vtensor<[?],!torch.qint32> -> !torch.vtensor<[?],si32>
136+
%9 = torch.aten.convolution %5, %6, %8, %2, %3, %2, %false, %4, %int1 : !torch.vtensor<[?,4,7,8],!torch.qint8>, !torch.vtensor<[3,4,3,2],!torch.qint8>, !torch.vtensor<[?],si32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,3,5,7],si32>
137+
%10 = torch.aten._make_per_tensor_quantized_tensor %9, %float1.000000e-04, %int0 : !torch.vtensor<[?,3,5,7],si32>, !torch.float, !torch.int -> !torch.vtensor<[?,3,5,7],!torch.qint32>
138+
%11 = torch.aten.dequantize.tensor %10 : !torch.vtensor<[?,3,5,7],!torch.qint32> -> !torch.vtensor<[?,3,5,7],f32>
139+
return %11 : !torch.vtensor<[?,3,5,7],f32>
140+
}

0 commit comments

Comments
 (0)