Skip to content

Conversation

@amrami
Copy link
Contributor

@amrami amrami commented Dec 7, 2025

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Dec 7, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-memref

Author: Maya Amrami (amrami)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/171039.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+5-5)
  • (modified) mlir/test/Dialect/MemRef/ops.mlir (+14-1)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1035d7cb46e6e..4639016fb589a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2563,6 +2563,11 @@ computeCollapsedLayoutMap(MemRefType srcType,
     for (int64_t idx : llvm::reverse(trailingReassocs)) {
       stride = stride * SaturatedInteger::wrap(srcShape[idx]);
 
+      // Dimensions of size 1 should be skipped, because their strides are
+      // meaningless and could have any arbitrary value.
+      if (srcShape[idx - 1] == 1)
+        continue;
+
       // Both source and result stride must have the same static value. In that
       // case, we can be sure, that the dimensions are collapsible (because they
       // are contiguous).
@@ -2575,11 +2580,6 @@ computeCollapsedLayoutMap(MemRefType srcType,
       if (strict && (stride.saturated || srcStride.saturated))
         return failure();
 
-      // Dimensions of size 1 should be skipped, because their strides are
-      // meaningless and could have any arbitrary value.
-      if (srcShape[idx - 1] == 1)
-        continue;
-
       if (!stride.saturated && !srcStride.saturated && stride != srcStride)
         return failure();
     }
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index a90c9505a8405..cddc79f693b11 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -440,7 +440,10 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
          %arg4: index,
          %arg5: index,
          %arg6: index,
-         %arg7: memref<4x?x4xf32>) {
+         %arg7: memref<4x?x4xf32>,
+         %arg8: memref<1x1x18x?xf32, strided<[?, ?, ?, 1], offset: ?>>,
+         %arg9: memref<3x3x1x96xf32, strided<[288, 96, 96, 1], offset: 864>>) {
+
 //       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
 //  CHECK-SAME:     memref<?x?x?xf32> into memref<?x?xf32>
   %0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
@@ -489,6 +492,16 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
 //       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
   %4 = memref.expand_shape %arg7 [[0, 1], [2], [3, 4]] output_shape [2, 2, %arg4, 2, 2]
         : memref<4x?x4xf32> into memref<2x2x?x2x2xf32>
+
+//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3]]
+//  CHECK-SAME:     memref<1x1x18x?xf32, strided<[?, ?, ?, 1], offset: ?>> into memref<1x18x?xf32, strided<[?, ?, 1], offset: ?>>
+  %5 = memref.collapse_shape %arg8 [[0, 1], [2], [3]] : memref<1x1x18x?xf32, strided<[?, ?, ?, 1], offset: ?>> into memref<1x18x?xf32, strided<[?, ?, 1], offset: ?>>
+
+//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0], [1, 2, 3]]
+//  CHECK-SAME:     memref<3x3x1x96xf32, strided<[288, 96, 96, 1], offset: 864>> into memref<3x288xf32, strided<[288, 1], offset: 864>>
+  %6 = memref.collapse_shape %arg9 [[0], [1, 2, 3]] :
+    memref<3x3x1x96xf32, strided<[288, 96, 96, 1], offset: 864>> into
+    memref<3x288xf32, strided<[288, 1], offset: 864>>
   return
 }
 

@amrami amrami self-assigned this Dec 7, 2025
@amrami
Copy link
Contributor Author

amrami commented Dec 7, 2025

@krzysz00

@amrami
Copy link
Contributor Author

amrami commented Dec 7, 2025

I added a test with the case mentioned in the revert #166448 and verified that it passes

@hanhanW hanhanW requested a review from krzysz00 December 8, 2025 09:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants