diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 1035d7cb46e6e..4639016fb589a 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2563,6 +2563,11 @@ computeCollapsedLayoutMap(MemRefType srcType, for (int64_t idx : llvm::reverse(trailingReassocs)) { stride = stride * SaturatedInteger::wrap(srcShape[idx]); + // Dimensions of size 1 should be skipped, because their strides are + // meaningless and could have any arbitrary value. + if (srcShape[idx - 1] == 1) + continue; + // Both source and result stride must have the same static value. In that // case, we can be sure, that the dimensions are collapsible (because they // are contiguous). @@ -2575,11 +2580,6 @@ computeCollapsedLayoutMap(MemRefType srcType, if (strict && (stride.saturated || srcStride.saturated)) return failure(); - // Dimensions of size 1 should be skipped, because their strides are - // meaningless and could have any arbitrary value. - if (srcShape[idx - 1] == 1) - continue; - if (!stride.saturated && !srcStride.saturated && stride != srcStride) return failure(); } diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir index a90c9505a8405..cddc79f693b11 100644 --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -440,7 +440,10 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref, %arg4: index, %arg5: index, %arg6: index, - %arg7: memref<4x?x4xf32>) { + %arg7: memref<4x?x4xf32>, + %arg8: memref<1x1x18x?xf32, strided<[?, ?, ?, 1], offset: ?>>, + %arg9: memref<3x3x1x96xf32, strided<[288, 96, 96, 1], offset: 864>>) { + // CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]] // CHECK-SAME: memref into memref %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : @@ -489,6 +492,16 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref, // CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]] %4 = memref.expand_shape %arg7 [[0, 1], [2], [3, 4]] output_shape [2, 2, %arg4, 2, 2] : memref<4x?x4xf32> into memref<2x2x?x2x2xf32> + +// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3]] +// CHECK-SAME: memref<1x1x18x?xf32, strided<[?, ?, ?, 1], offset: ?>> into memref<1x18x?xf32, strided<[?, ?, 1], offset: ?>> + %5 = memref.collapse_shape %arg8 [[0, 1], [2], [3]] : memref<1x1x18x?xf32, strided<[?, ?, ?, 1], offset: ?>> into memref<1x18x?xf32, strided<[?, ?, 1], offset: ?>> + +// CHECK: memref.collapse_shape {{.*}} {{\[}}[0], [1, 2, 3]] +// CHECK-SAME: memref<3x3x1x96xf32, strided<[288, 96, 96, 1], offset: 864>> into memref<3x288xf32, strided<[288, 1], offset: 864>> + %6 = memref.collapse_shape %arg9 [[0], [1, 2, 3]] : + memref<3x3x1x96xf32, strided<[288, 96, 96, 1], offset: 864>> into + memref<3x288xf32, strided<[288, 1], offset: 864>> return }