
Commit 5f531f5

Rebase main
1 parent e8e270f

4 files changed: 34 additions & 67 deletions


lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 14 additions & 47 deletions
@@ -62,21 +62,6 @@ struct ZeroInsertionResult {
   bool trimmedTail;
 };
 
-static std::pair<Value, Value>
-getOrCreateConvZeroPoints(PatternRewriter &rewriter, Location loc, Value input,
-                          Type inputElemTy, Value weight, Type weightElemTy) {
-  auto zps = tosa::createZPsAsConst(rewriter, input, weight);
-  Value inputZp = zps.first;
-  if (!inputZp)
-    inputZp =
-        tosa::createZeroPointTensor(rewriter, loc, inputElemTy, 0).value();
-  Value weightZp = zps.second;
-  if (!weightZp)
-    weightZp =
-        tosa::createZeroPointTensor(rewriter, loc, weightElemTy, 0).value();
-  return {inputZp, weightZp};
-}
-
 static FailureOr<ZeroInsertionResult>
 insertZerosAlongAxis(Value input, int axis, int64_t stride,
                      ConversionPatternRewriter &rewriter, Location loc) {
@@ -2650,6 +2635,18 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "failed to get accumulator type for convolution ops");
 
+  // Get zero-points for input and weight
+  // We need to get input/weights from the op instead of adaptor so that the
+  // connection to quantization detail carrying ops are preserved
+  auto inputZp =
+      tosa::getZeroPointValue(rewriter, op, op.getInput(), inputElemTy);
+  auto weightZp =
+      tosa::getZeroPointValue(rewriter, op, op.getWeight(), weightElemTy);
+
+  if (failed(inputZp) || failed(weightZp)) {
+    return rewriter.notifyMatchFailure(op, "failed to get zero point values");
+  }
+
   // TOSA works in NHWC (2D) / NDHWC (3D) and takes OHWI / ODHWI weights for
   // convolution. Perform the necessary transformations.
   SmallVector<int32_t, 5> torchToTosaDims;
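
The comment in the hunk above states the key constraint: the zero points are looked up through the op's own operands (op.getInput(), op.getWeight()) rather than the adaptor's already-converted values, so the chain back to the ops carrying the quantization details stays reachable, and the results come back as FailureOr<Value> that are later dereferenced as *inputZp / *weightZp. The body of tosa::getZeroPointValue is not part of this diff; the sketch below is only a hedged guess at the shape of such a helper, reusing tosa::createZeroPointTensor from the removed getOrCreateConvZeroPoints path. The name sketchGetZeroPoint, its parameter list, and the type checks are illustrative assumptions, not the committed implementation.

// Hypothetical sketch, not the code added by this commit; assumes the
// includes already present in TorchToTosa.cpp (PatternMatch, Quant dialect
// types, TOSA conversion utilities).
// `torchOperand` stands for op.getInput()/op.getWeight(); a real helper would
// presumably inspect its defining ops for quantization parameters, which is
// why the op operands rather than the adaptor operands are passed in.
static FailureOr<Value> sketchGetZeroPoint(PatternRewriter &rewriter,
                                           Operation *op, Value torchOperand,
                                           Type elemTy) {
  (void)torchOperand; // quantization-detail traversal elided in this sketch
  int64_t zp = 0;     // non-quantized operands fall back to a zero-valued ZP
  if (auto qTy = dyn_cast<quant::UniformQuantizedType>(elemTy))
    zp = qTy.getZeroPoint(); // quantized element types carry their zero point
  auto zpTensor =
      tosa::createZeroPointTensor(rewriter, op->getLoc(), elemTy, zp);
  if (!zpTensor)
    return failure();
  return *zpTensor;
}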
@@ -2814,13 +2811,10 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     auto convOpTy = RankedTensorType::get(
         makeShapeLLVMCompatible(outTosaShape), biasElemTy);
 
-    auto [inputZp, weightZp] = getOrCreateConvZeroPoints(
-        rewriter, loc, input, inputElemTy, weight, weightElemTy);
-
     auto convResult =
         tosa::Conv3DOp::create(
             rewriter, loc, getTypeConverter()->convertType(convOpTy),
-            paddedInput, flippedWeight, bias, inputZp, weightZp,
+            paddedInput, flippedWeight, bias, *inputZp, *weightZp,
             rewriter.getDenseI64ArrayAttr({0, 0, 0, 0, 0, 0}),
             rewriter.getDenseI64ArrayAttr({1, 1, 1}),
             rewriter.getDenseI64ArrayAttr(dilation), accType)
@@ -2837,14 +2831,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
                 rewriter.getDenseI32ArrayAttr(tosaToTorchDims))
                 .getResult();
 
-    Value rescaledResult = transposedOutput;
-    if (isa<quant::QuantizedType>(inputElemTy)) {
-      rescaledResult = tosa::buildRescaleOpConvOutput(
-          rewriter, op, transposedOutput, inputTy, weightTy, outputTy);
-    }
-
     rewriter.replaceOp(
-        op, {tosa::tosaCastTensorToType(rewriter, rescaledResult, outputTy)
+        op, {tosa::tosaCastTensorToType(rewriter, transposedOutput, outputTy)
                 .value()});
     return success();
   }
@@ -2892,16 +2880,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     auto transConvOpTy = RankedTensorType::get(
         makeShapeLLVMCompatible(outTosaShape), biasElemTy);
 
-    // Zero-points.
-    auto inputZp =
-        tosa::getZeroPointValue(rewriter, op, op.getInput(), inputElemTy);
-    auto weightZp =
-        tosa::getZeroPointValue(rewriter, op, op.getWeight(), weightElemTy);
-
-    if (failed(inputZp) || failed(weightZp)) {
-      return rewriter.notifyMatchFailure(op, "failed to get zero point values");
-    }
-
     Value convTOut = tosa::TransposeConv2DOp::create(
         rewriter, op->getLoc(),
         getTypeConverter()->convertType(transConvOpTy),
@@ -3116,17 +3094,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
                  outputCDim};
   }
 
-  // create zero-point tensors for input and weight
-  // Zero-points.
-  auto inputZp =
-      tosa::getZeroPointValue(rewriter, op, op.getInput(), inputElemTy);
-  auto weightZp =
-      tosa::getZeroPointValue(rewriter, op, op.getWeight(), weightElemTy);
-
-  if (failed(inputZp) || failed(weightZp)) {
-    return rewriter.notifyMatchFailure(op, "failed to get zero point values");
-  }
-
   Value bias;
   Type biasElemTy;
   Value convOpResult;

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 14 additions & 14 deletions
@@ -3711,10 +3711,10 @@ func.func @torch.aten.constant_pad_nd$basic(%arg0: !torch.vtensor<[1,1,20,20,4,4
 // CHECK: %[[PAD_LIST:.*]] = torch.prim.ListConstruct %[[KERNEL]], %[[KERNEL]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[DILATION_LIST:.*]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[OUTPUT_PAD:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
-// CHECK: %[[NHWC_INPUT:.*]] = tosa.transpose %[[INPUT]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<5x2x10x20xf32>) -> tensor<5x10x20x2xf32>
-// CHECK: %[[NHWC_WEIGHT:.*]] = tosa.transpose %[[WEIGHT]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<10x2x3x3xf32>) -> tensor<10x3x3x2xf32>
 // CHECK: %[[INPUT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
 // CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[NHWC_INPUT:.*]] = tosa.transpose %[[INPUT]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<5x2x10x20xf32>) -> tensor<5x10x20x2xf32>
+// CHECK: %[[NHWC_WEIGHT:.*]] = tosa.transpose %[[WEIGHT]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<10x2x3x3xf32>) -> tensor<10x3x3x2xf32>
 // CHECK: %[[BIAS:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<10xf32>}> : () -> tensor<10xf32>
 // CHECK: %[[CONV:.*]] = tosa.conv2d %[[NHWC_INPUT]], %[[NHWC_WEIGHT]], %[[BIAS]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 3, 3, 3, 3>, stride = array<i64: 1, 1>} : (tensor<5x10x20x2xf32>, tensor<10x3x3x2xf32>, tensor<10xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x14x24x10xf32>
 // CHECK: %[[RESULT_NCHW:.*]] = tosa.transpose %[[CONV]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<5x14x24x10xf32>) -> tensor<5x10x14x24xf32>
@@ -3750,12 +3750,12 @@ func.func @torch.aten.convolution$basic(%arg0: !torch.vtensor<[5,2,10,20],f32>)
 // CHECK: %[[PAD_LIST:.*]] = torch.prim.ListConstruct %[[KERNEL]], %[[KERNEL]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[DILATION_LIST:.*]] = torch.prim.ListConstruct %[[KERNEL]], %[[KERNEL]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[OUTPUT_PAD:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+// CHECK: %[[INPUT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
 // CHECK: %[[NHWC_INPUT:.*]] = tosa.transpose %[[INPUT]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<5x4x10x20xf32>) -> tensor<5x10x20x4xf32>
 // CHECK: %[[HW_WEIGHT:.*]] = tosa.transpose %[[WEIGHT]] {perms = array<i32: 2, 3, 0, 1>} : (tensor<4x1x3x3xf32>) -> tensor<3x3x4x1xf32>
 // CHECK: %[[RESHAPE_SHAPE:.*]] = tosa.const_shape {values = dense<[3, 3, 4, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
 // CHECK: %[[FILTER:.*]] = tosa.reshape %[[HW_WEIGHT]], %[[RESHAPE_SHAPE]] : (tensor<3x3x4x1xf32>, !tosa.shape<4>) -> tensor<3x3x4x1xf32>
-// CHECK: %[[INPUT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
-// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
 // CHECK: %[[BIAS:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<4xf32>}> : () -> tensor<4xf32>
 // CHECK: %[[DEPTHWISE:.*]] = tosa.depthwise_conv2d %[[NHWC_INPUT]], %[[FILTER]], %[[BIAS]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = f32, dilation = array<i64: 3, 3>, pad = array<i64: 3, 2, 3, 2>, stride = array<i64: 2, 2>} : (tensor<5x10x20x4xf32>, tensor<3x3x4x1xf32>, tensor<4xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x5x10x4xf32>
 // CHECK: %[[RESULT_NCHW:.*]] = tosa.transpose %[[DEPTHWISE]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<5x5x10x4xf32>) -> tensor<5x4x5x10xf32>
@@ -3790,6 +3790,8 @@ func.func @torch.aten.convolution$depthwise(%arg0: !torch.vtensor<[5,4,10,20],f3
 // CHECK: %[[PAD_LIST:.*]] = torch.prim.ListConstruct %[[ZERO]], %[[ZERO]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[DILATION_LIST:.*]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[OUTPUT_PAD:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+// CHECK: %[[INPUT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
 // CHECK: %[[NHWC_INPUT:.*]] = tosa.transpose %[[INPUT]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x64x56x56xf32>) -> tensor<1x56x56x64xf32>
 // CHECK: %[[NHWC_WEIGHT:.*]] = tosa.transpose %[[WEIGHT]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<128x64x1x1xf32>) -> tensor<128x1x1x64xf32>
 // CHECK-DAG: %[[SLICE0_START:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
@@ -3798,8 +3800,6 @@ func.func @torch.aten.convolution$depthwise(%arg0: !torch.vtensor<[5,4,10,20],f3
 // CHECK-DAG: %[[SLICE1_START:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
 // CHECK-DAG: %[[SLICE1_SIZE:.*]] = tosa.const_shape {values = dense<[1, 55, 55, 64]> : tensor<4xindex>} : () -> !tosa.shape<4>
 // CHECK: %[[TRIMMED_HW:.*]] = tosa.slice %[[TRIMMED_H]], %[[SLICE1_START]], %[[SLICE1_SIZE]] : (tensor<1x55x56x64xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x55x55x64xf32>
-// CHECK: %[[INPUT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
-// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
 // CHECK: %[[BIAS:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<128xf32>}> : () -> tensor<128xf32>
 // CHECK: %[[CONV:.*]] = tosa.conv2d %[[TRIMMED_HW]], %[[NHWC_WEIGHT]], %[[BIAS]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<1x55x55x64xf32>, tensor<128x1x1x64xf32>, tensor<128xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x28x28x128xf32>
 // CHECK: %[[RESULT_NCHW:.*]] = tosa.transpose %[[CONV]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x28x28x128xf32>) -> tensor<1x128x28x28xf32>
@@ -3835,10 +3835,10 @@ func.func @torch.aten.convolution$zero_pad_with_sliced_input(%arg0: !torch.vtens
 // CHECK: %[[PAD_LIST:.*]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[DILATION_LIST:.*]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[OUTPUT_PAD:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
-// CHECK: %[[NHWC_INPUT:.*]] = tosa.transpose %[[INPUT]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x3x224x224xf32>) -> tensor<1x224x224x3xf32>
-// CHECK: %[[NHWC_WEIGHT:.*]] = tosa.transpose %[[WEIGHT]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
 // CHECK: %[[INPUT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
 // CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[NHWC_INPUT:.*]] = tosa.transpose %[[INPUT]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x3x224x224xf32>) -> tensor<1x224x224x3xf32>
+// CHECK: %[[NHWC_WEIGHT:.*]] = tosa.transpose %[[WEIGHT]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
 // CHECK: %[[BIAS:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
 // CHECK: %[[CONV:.*]] = tosa.conv2d %[[NHWC_INPUT]], %[[NHWC_WEIGHT]], %[[BIAS]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 0, 1, 0>, stride = array<i64: 2, 2>} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x112x112x32xf32>
 // CHECK: %[[RESULT_NCHW:.*]] = tosa.transpose %[[CONV]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x112x112x32xf32>) -> tensor<1x32x112x112xf32>
@@ -3873,6 +3873,8 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_
 // CHECK: %[[PAD_LIST:.*]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[DILATION_LIST:.*]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[OUTPUT_PAD:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+// CHECK: %[[INPUT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
 // CHECK: %[[NHWC_INPUT:.*]] = tosa.transpose %[[INPUT]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x3x225x225xf32>) -> tensor<1x225x225x3xf32>
 // CHECK: %[[NHWC_WEIGHT:.*]] = tosa.transpose %[[WEIGHT]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
 // CHECK-DAG: %[[SLICE0_START:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
@@ -3881,8 +3883,6 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_
 // CHECK-DAG: %[[SLICE1_START:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
 // CHECK-DAG: %[[SLICE1_SIZE:.*]] = tosa.const_shape {values = dense<[1, 224, 224, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
 // CHECK: %[[TRIMMED_HW:.*]] = tosa.slice %[[TRIMMED_H]], %[[SLICE1_START]], %[[SLICE1_SIZE]] : (tensor<1x224x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x224x224x3xf32>
-// CHECK: %[[INPUT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
-// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
 // CHECK: %[[BIAS:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
 // CHECK: %[[CONV:.*]] = tosa.conv2d %[[TRIMMED_HW]], %[[NHWC_WEIGHT]], %[[BIAS]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 0, 1, 0>, stride = array<i64: 3, 3>} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x75x75x32xf32>
 // CHECK: %[[RESULT_NCHW:.*]] = tosa.transpose %[[CONV]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x75x75x32xf32>) -> tensor<1x32x75x75xf32>
@@ -3918,10 +3918,10 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_inp
 // CHECK: %[[PAD_LIST:.*]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[DILATION_LIST:.*]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[OUTPUT_PAD:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
-// CHECK: %[[NHWC_INPUT:.*]] = tosa.transpose %[[INPUT]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<?x3x224x224xf32>) -> tensor<?x224x224x3xf32>
-// CHECK: %[[NHWC_WEIGHT:.*]] = tosa.transpose %[[WEIGHT]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
 // CHECK: %[[INPUT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
 // CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[NHWC_INPUT:.*]] = tosa.transpose %[[INPUT]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<?x3x224x224xf32>) -> tensor<?x224x224x3xf32>
+// CHECK: %[[NHWC_WEIGHT:.*]] = tosa.transpose %[[WEIGHT]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
 // CHECK: %[[BIAS:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
 // CHECK: %[[CONV:.*]] = tosa.conv2d %[[NHWC_INPUT]], %[[NHWC_WEIGHT]], %[[BIAS]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 0, 1, 0>, stride = array<i64: 2, 2>} : (tensor<?x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x112x112x32xf32>
 // CHECK: %[[RESULT_NCHW:.*]] = tosa.transpose %[[CONV]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<?x112x112x32xf32>) -> tensor<?x32x112x112xf32>
@@ -3958,6 +3958,8 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_
 // CHECK: %[[PAD_LIST:.*]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[DILATION_LIST:.*]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[OUTPUT_PAD:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+// CHECK: %[[INPUT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
 // CHECK: %[[NHWC_INPUT:.*]] = tosa.transpose %[[INPUT]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<?x3x225x225xf32>) -> tensor<?x225x225x3xf32>
 // CHECK: %[[NHWC_WEIGHT:.*]] = tosa.transpose %[[WEIGHT]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
 // CHECK-DAG: %[[SLICE0_START:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
@@ -3966,8 +3968,6 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_
 // CHECK-DAG: %[[SLICE1_START:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
 // CHECK-DAG: %[[SLICE1_SIZE:.*]] = tosa.const_shape {values = dense<[-1, 224, 224, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
 // CHECK: %[[TRIMMED_HW:.*]] = tosa.slice %[[TRIMMED_H]], %[[SLICE1_START]], %[[SLICE1_SIZE]] : (tensor<?x224x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x224x224x3xf32>
-// CHECK: %[[INPUT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
-// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
 // CHECK: %[[BIAS:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
 // CHECK: %[[CONV:.*]] = tosa.conv2d %[[TRIMMED_HW]], %[[NHWC_WEIGHT]], %[[BIAS]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 0, 1, 0>, stride = array<i64: 3, 3>} : (tensor<?x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x75x75x32xf32>
 // CHECK: %[[RESULT_NCHW:.*]] = tosa.transpose %[[CONV]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<?x75x75x32xf32>) -> tensor<?x32x75x75xf32>
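
All of the basic.mlir updates above are reorderings rather than new checks: the expected %[[INPUT_ZP]] / %[[WEIGHT_ZP]] constants move ahead of the tosa.transpose and tosa.slice lines because the new code path in TorchToTosa.cpp obtains the zero points before the layout transformations, and plain // CHECK: directives in FileCheck must match in the order they are written. The neighboring // CHECK-DAG: lines are untouched, since CHECK-DAG matches order-independently within its block. A minimal illustration of that distinction, using made-up IR that is not part of this test file:

// Suppose a pass emits, in this order:
//   %zp = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
//   %t  = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>} : ...
//
// These strictly ordered directives pass only when listed in emission order:
// CHECK: "tosa.const"
// CHECK: tosa.transpose
//
// Swapping the two CHECK lines would fail, whereas the same patterns written
// as CHECK-DAG match in either order:
// CHECK-DAG: tosa.transpose
// CHECK-DAG: "tosa.const"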
