Skip to content

Commit 837e498

Browse files
authored
[TOSA] Conv3d legalization (#4383)
- Extend Torch to TOSA conversion with full aten.conv3d coverage, mapping padding/stride/dilation configs and bias handling into canonical TOSA conv ops. - Refactor TorchToTosa lowering utilities to share more code between conv variants. - Update conversion tests plus PT1 xfail lists to reflect the newly supported 3D convolution paths. Change-Id: Iaf1261e121dec1ddb84b814e9058bb7ebadd4de7 Signed-off-by: Cathal Corbett <[email protected]>
1 parent 42c3c29 commit 837e498

File tree

6 files changed

+887
-407
lines changed

6 files changed

+887
-407
lines changed

include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
105105
FailureOr<Value> getConvBiasForNoneType(Operation *op,
106106
PatternRewriter &rewriter,
107107
Type inputElemTy, Type outputElemTy,
108-
ArrayRef<int64_t> weightShape);
108+
int64_t numOutputChannels);
109109

110110
// Emit an explicit zero-valued `tosa.pad` around an NHWC tensor so that later
111111
// avg_pool lowering can run with `pad = 0`. `padExtents` is ordered as

0 commit comments

Comments
 (0)