|
 // CHECK-LABEL: func.func @AtenMmQint8(
 // CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[3,4],si8>,
 // CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[4,3],si8>) -> !torch.vtensor<[3,3],f32> {
-// CHECK: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
-// CHECK: %[[OUT_SCALE:.*]] = "tosa.const"() <{values = dense<3.784000e-04> : tensor<3x3xf32>}> : () -> tensor<3x3xf32>
+// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK-DAG: %[[OUT_SCALE:.*]] = "tosa.const"() <{values = dense<3.784000e-04> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
 // CHECK-DAG: %[[MUL_OUT_SHAPE:.*]] = tosa.const_shape {values = dense<3> : tensor<2xindex>} : () -> !tosa.shape<2>
 // CHECK-DAG: %[[RHS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 4, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
 // CHECK-DAG: %[[LHS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 3, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
|
 // CHECK: %[[MATMUL:.*]] = tosa.matmul %[[LHS_RESHAPED]], %[[RHS_RESHAPED]], %[[LHS_ZP]], %[[RHS_ZP]] : (tensor<1x3x4xi8>, tensor<1x4x3xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x3x3xi32>
 // CHECK: %[[MATMUL_RESHAPE:.*]] = tosa.reshape %[[MATMUL]], %[[MUL_OUT_SHAPE]] : (tensor<1x3x3xi32>, !tosa.shape<2>) -> tensor<3x3xi32>
 // CHECK: %[[MATMUL_FP32:.*]] = tosa.cast %[[MATMUL_RESHAPE]] : (tensor<3x3xi32>) -> tensor<3x3xf32>
-// CHECK: %[[OUT_SCALED:.*]] = tosa.mul %[[MATMUL_FP32]], %[[OUT_SCALE]], %[[SHIFT]] : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<1xi8>) -> tensor<3x3xf32>
+// CHECK: %[[OUT_SCALED:.*]] = tosa.mul %[[MATMUL_FP32]], %[[OUT_SCALE]], %[[SHIFT]] : (tensor<3x3xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<3x3xf32>
 // CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[OUT_SCALED]] : tensor<3x3xf32> -> !torch.vtensor<[3,3],f32>
 // CHECK: return %[[RES]]
 func.func @AtenMmQint8(%arg0: !torch.vtensor<[3,4],si8>, %arg1: !torch.vtensor<[4,3],si8>) -> !torch.vtensor<[3,3],f32>
@@ -76,3 +76,65 @@ func.func @quantization_per_tensor(%arg0: !torch.vtensor<[2,4,4],f32>) -> !torch |
   %0 = torch.aten.quantize_per_tensor %arg0, %scale, %zp, %dtype : !torch.vtensor<[2,4,4],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[2,4,4],!torch.qint8>
   return %0 : !torch.vtensor<[2,4,4],!torch.qint8>
 }
+
+
+// -----
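+// Per-channel dequantization: the per-channel scale and zero point (axis 0)
+// are reshaped to [3, 1, 1, 1] so they broadcast over the activation in the
+// cast/sub/mul sequence checked below.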
+// CHECK-LABEL: func.func @dequantize.self(
+// CHECK-SAME: %[[IN:.*]]: !torch.vtensor<[3,4,3,2],si8>,
+// CHECK-SAME: %[[SCALE:.*]]: !torch.vtensor<[3],f32>,
+// CHECK-SAME: %[[ZP:.*]]: !torch.vtensor<[3],si8>) -> !torch.vtensor<[3,4,3,2],f32> {
+// CHECK: %[[MUL_SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK: %[[QUANT_PARAM_SHAPE:.*]] = tosa.const_shape {values = dense<[3, 1, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[IN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[IN]] : !torch.vtensor<[3,4,3,2],si8> -> tensor<3x4x3x2xi8>
+// CHECK: %[[IN_I32:.*]] = tosa.cast %[[IN_TENSOR]] : (tensor<3x4x3x2xi8>) -> tensor<3x4x3x2xi32>
+// CHECK: %[[ZP_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ZP]] : !torch.vtensor<[3],si8> -> tensor<3xi8>
+// CHECK: %[[ZP_I32:.*]] = tosa.cast %[[ZP_TENSOR]] : (tensor<3xi8>) -> tensor<3xi32>
+// CHECK: %[[ZP_RESHAPED:.*]] = tosa.reshape %[[ZP_I32]], %[[QUANT_PARAM_SHAPE]] : (tensor<3xi32>, !tosa.shape<4>) -> tensor<3x1x1x1xi32>
+// CHECK: %[[SUB:.*]] = tosa.sub %[[IN_I32]], %[[ZP_RESHAPED]] : (tensor<3x4x3x2xi32>, tensor<3x1x1x1xi32>) -> tensor<3x4x3x2xi32>
+// CHECK: %[[SUB_CAST:.*]] = tosa.cast %[[SUB]] : (tensor<3x4x3x2xi32>) -> tensor<3x4x3x2xf32>
+// CHECK: %[[SCALE_TENSOR:.*]] = torch_c.to_builtin_tensor %[[SCALE]] : !torch.vtensor<[3],f32> -> tensor<3xf32>
+// CHECK: %[[SCALE_RESHAPED:.*]] = tosa.reshape %[[SCALE_TENSOR]], %[[QUANT_PARAM_SHAPE]] : (tensor<3xf32>, !tosa.shape<4>) -> tensor<3x1x1x1xf32>
+// CHECK: %[[MUL:.*]] = tosa.mul %[[SUB_CAST]], %[[SCALE_RESHAPED]], %[[MUL_SHIFT]] : (tensor<3x4x3x2xf32>, tensor<3x1x1x1xf32>, tensor<1xi8>) -> tensor<3x4x3x2xf32>
+// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[MUL]]
+func.func @dequantize.self(%arg0: !torch.vtensor<[3,4,3,2],si8>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],si8>) -> !torch.vtensor<[3,4,3,2],f32> {
+  %int0 = torch.constant.int 0
+  %0 = torch.aten._make_per_channel_quantized_tensor %arg0, %arg1, %arg2, %int0 : !torch.vtensor<[3,4,3,2],si8>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],si8>, !torch.int -> !torch.vtensor<[3,4,3,2],!torch.qint8>
+  %1 = torch.aten.dequantize.self %0 : !torch.vtensor<[3,4,3,2],!torch.qint8> -> !torch.vtensor<[3,4,3,2],f32>
+  return %1 : !torch.vtensor<[3,4,3,2],f32>
+}
+
+
+// -----
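+// Quantized convolution: the qint8 activation/weight wrappers and the qint32
+// bias should lower to a single tosa.conv2d carrying explicit input and weight
+// zero points, with no quantize/dequantize ops left behind.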
+// CHECK-LABEL: func.func @quantized_conv(
+// CHECK: %[[WTS_ZP:.*]] = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK: %[[IN_ZP:.*]] = "tosa.const"() <{values = dense<7> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK: %[[CONV:.*]] = tosa.conv2d
+// CHECK-SAME: %[[IN_ZP]], %[[WTS_ZP]] {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x7x8x4xi8>, tensor<3x3x2x4xi8>, tensor<?xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x5x7x3xi32>
+// CHECK-NOT: torch.aten.quantize_per_tensor
+// CHECK-NOT: torch.aten.dequantize.self
+// CHECK-NOT: torch.aten._make_per_tensor_quantized_tensor
+// CHECK-NOT: torch.aten.dequantize.tensor
+
+func.func @quantized_conv(%arg0: !torch.vtensor<[?,4,7,8],si8>, %arg1: !torch.vtensor<[3,4,3,2],si8>, %arg2: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,3,5,7],f32> {
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %int0 = torch.constant.int 0
+  %float1.000000e-04 = torch.constant.float 1.000000e-04
+  %int3 = torch.constant.int 3
+  %int7 = torch.constant.int 7
+  %float1.000000e-02 = torch.constant.float 1.000000e-02
+  %int14 = torch.constant.int 14
+  %0 = torch.aten.quantize_per_tensor %arg2, %float1.000000e-04, %int0, %int14 : !torch.vtensor<[?],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[?],!torch.qint32>
+  %1 = torch.aten.dequantize.self %0 : !torch.vtensor<[?],!torch.qint32> -> !torch.vtensor<[?],f32>
+  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.prim.ListConstruct : () -> !torch.list<int>
+  %5 = torch.aten._make_per_tensor_quantized_tensor %arg0, %float1.000000e-02, %int7 : !torch.vtensor<[?,4,7,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[?,4,7,8],!torch.qint8>
+  %6 = torch.aten._make_per_tensor_quantized_tensor %arg1, %float1.000000e-02, %int3 : !torch.vtensor<[3,4,3,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,4,3,2],!torch.qint8>
+  %7 = torch.aten.quantize_per_tensor %1, %float1.000000e-04, %int0, %int14 : !torch.vtensor<[?],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[?],!torch.qint32>
+  %8 = torch.aten.int_repr %7 : !torch.vtensor<[?],!torch.qint32> -> !torch.vtensor<[?],si32>
+  %9 = torch.aten.convolution %5, %6, %8, %2, %3, %2, %false, %4, %int1 : !torch.vtensor<[?,4,7,8],!torch.qint8>, !torch.vtensor<[3,4,3,2],!torch.qint8>, !torch.vtensor<[?],si32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,3,5,7],si32>
+  %10 = torch.aten._make_per_tensor_quantized_tensor %9, %float1.000000e-04, %int0 : !torch.vtensor<[?,3,5,7],si32>, !torch.float, !torch.int -> !torch.vtensor<[?,3,5,7],!torch.qint32>
+  %11 = torch.aten.dequantize.tensor %10 : !torch.vtensor<[?,3,5,7],!torch.qint32> -> !torch.vtensor<[?,3,5,7],f32>
+  return %11 : !torch.vtensor<[?,3,5,7],f32>
+}