Skip to content

Conversation

@rolfmorel
Copy link
Contributor

Friendlier wrapper for transform.foreach.

To facilitate that friendliness, makes it so that OpResult.owner returns the relevant OpView instead of Operation. For good measure, also changes Value.owner to return OpView instead of Operation, thereby ensuring consistency. That is, makes it is so that all op-returning .owner accessors return OpView (and thereby give access to all goodies available on registered OpViews.)

Reland of #171544 due to fixup for integration test.

@llvmbot
Copy link
Member

llvmbot commented Dec 14, 2025

@llvm/pr-subscribers-mlir

Author: Rolf Morel (rolfmorel)

Changes

Friendlier wrapper for transform.foreach.

To facilitate that friendliness, makes it so that OpResult.owner returns the relevant OpView instead of Operation. For good measure, also changes Value.owner to return OpView instead of Operation, thereby ensuring consistency. That is, makes it is so that all op-returning .owner accessors return OpView (and thereby give access to all goodies available on registered OpViews.)

Reland of #171544 due to fixup for integration test.


Full diff: https://github.com/llvm/llvm-project/pull/172228.diff

5 Files Affected:

  • (modified) mlir/lib/Bindings/Python/IRCore.cpp (+4-4)
  • (modified) mlir/python/mlir/dialects/memref.py (+1-2)
  • (modified) mlir/python/mlir/dialects/transform/init.py (+50)
  • (modified) mlir/test/python/dialects/transform.py (+52)
  • (modified) mlir/test/python/integration/dialects/pdl.py (+1-1)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index b0de14719ab61..168c57955af07 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1519,12 +1519,12 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
   static void bindDerived(ClassTy &c) {
     c.def_prop_ro(
         "owner",
-        [](PyOpResult &self) -> nb::typed<nb::object, PyOperation> {
+        [](PyOpResult &self) -> nb::typed<nb::object, PyOpView> {
           assert(mlirOperationEqual(self.getParentOperation()->get(),
                                     mlirOpResultGetOwner(self.get())) &&
                  "expected the owner of the value in Python to match that in "
                  "the IR");
-          return self.getParentOperation().getObject();
+          return self.getParentOperation()->createOpView();
         },
         "Returns the operation that produces this result.");
     c.def_prop_ro(
@@ -4646,7 +4646,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           kDumpDocstring)
       .def_prop_ro(
           "owner",
-          [](PyValue &self) -> nb::object {
+          [](PyValue &self) -> nb::typed<nb::object, PyOpView> {
             MlirValue v = self.get();
             if (mlirValueIsAOpResult(v)) {
               assert(mlirOperationEqual(self.getParentOperation()->get(),
@@ -4654,7 +4654,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
                      "expected the owner of the value in Python to match "
                      "that in "
                      "the IR");
-              return self.getParentOperation().getObject();
+              return self.getParentOperation()->createOpView();
             }
 
             if (mlirValueIsABlockArgument(v)) {
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index bc9a3a52728ad..c80a1b1a89358 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -14,8 +14,7 @@
 def _is_constant_int_like(i):
     return (
         isinstance(i, Value)
-        and isinstance(i.owner, Operation)
-        and isinstance(i.owner.opview, ConstantOp)
+        and isinstance(i.owner, ConstantOp)
         and _is_integer_like_type(i.type)
     )
 
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index b3dd79c7dbd79..fbe4078782997 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -310,6 +310,8 @@ def __init__(
             sym_visibility=sym_visibility,
             arg_attrs=arg_attrs,
             res_attrs=res_attrs,
+            loc=loc,
+            ip=ip,
         )
         self.regions[0].blocks.append(*input_types)
 
@@ -468,6 +470,54 @@ def apply_registered_pass(
     ).result
 
 
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ForeachOp(ForeachOp):
+    def __init__(
+        self,
+        results: Sequence[Type],
+        targets: Sequence[Union[Operation, Value, OpView]],
+        *,
+        with_zip_shortest: Optional[bool] = False,
+        loc=None,
+        ip=None,
+    ):
+        targets = [_get_op_result_or_value(target) for target in targets]
+        super().__init__(
+            results_=results,
+            targets=targets,
+            with_zip_shortest=with_zip_shortest,
+            loc=loc,
+            ip=ip,
+        )
+        self.regions[0].blocks.append(*[target.type for target in targets])
+
+    @property
+    def body(self) -> Block:
+        return self.regions[0].blocks[0]
+
+    @property
+    def bodyTargets(self) -> BlockArgumentList:
+        return self.regions[0].blocks[0].arguments
+
+
+def foreach(
+    results: Sequence[Type],
+    targets: Sequence[Union[Operation, Value, OpView]],
+    *,
+    with_zip_shortest: Optional[bool] = False,
+    loc=None,
+    ip=None,
+) -> Union[OpResult, OpResultList, ForeachOp]:
+    results = ForeachOp(
+        results=results,
+        targets=targets,
+        with_zip_shortest=with_zip_shortest,
+        loc=loc,
+        ip=ip,
+    ).results
+    return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
+
+
 AnyOpTypeT = NewType("AnyOpType", AnyOpType)
 
 
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index f58442d04fc66..dfcc890b83ffc 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -401,3 +401,55 @@ def testApplyRegisteredPassOp(module: Module):
             options={"exclude": (symbol_a, symbol_b)},
         )
         transform.YieldOp()
+
+
+# CHECK-LABEL: TEST: testForeachOp
+@run
+def testForeachOp(module: Module):
+    # CHECK: transform.sequence
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [transform.AnyOpType.get()],
+        transform.AnyOpType.get(),
+    )
+    with InsertionPoint(sequence.body):
+        # CHECK: {{.*}} = foreach %{{.*}} : !transform.any_op -> !transform.any_op
+        foreach1 = transform.ForeachOp(
+            (transform.AnyOpType.get(),), (sequence.bodyTarget,)
+        )
+        with InsertionPoint(foreach1.body):
+            # CHECK: transform.yield {{.*}} : !transform.any_op
+            transform.yield_(foreach1.bodyTargets)
+
+        a_val = transform.get_operand(
+            transform.AnyValueType.get(), foreach1.result, [0]
+        )
+        a_param = transform.param_constant(
+            transform.AnyParamType.get(), StringAttr.get("a_param")
+        )
+
+        # CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_value, !transform.any_param
+        foreach2 = transform.foreach(
+            (transform.AnyValueType.get(), transform.AnyParamType.get()),
+            (sequence.bodyTarget, a_val, a_param),
+        )
+        with InsertionPoint(foreach2.owner.body):
+            # CHECK: transform.yield {{.*}} : !transform.any_value, !transform.any_param
+            transform.yield_(foreach2.owner.bodyTargets[1:3])
+
+        another_param = transform.param_constant(
+            transform.AnyParamType.get(), StringAttr.get("another_param")
+        )
+        params = transform.merge_handles([a_param, another_param])
+
+        # CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} with_zip_shortest : !transform.any_op, !transform.any_param, !transform.any_param -> !transform.any_op
+        foreach3 = transform.foreach(
+            (transform.AnyOpType.get(),),
+            (foreach1.result, foreach2[1], params),
+            with_zip_shortest=True,
+        )
+        with InsertionPoint(foreach3.owner.body):
+            # CHECK: transform.yield {{.*}} : !transform.any_op
+            transform.yield_((foreach3.owner.bodyTargets[0],))
+
+        transform.yield_((foreach3,))
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index fe27dd4203a21..6a377a090fbb9 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -174,7 +174,7 @@ def add_fold(rewriter, results, values):
 
     def is_zero(value):
         op = value.owner
-        if isinstance(op, Operation):
+        if isinstance(op, OpView):
             return op.name == "myint.constant" and op.attributes["value"].value == 0
         return False
 

def is_zero(value):
op = value.owner
if isinstance(op, Operation):
if isinstance(op, OpView):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this the fail? I'm not sure why this is even an "integration" test - it doesn't require anything that's not available on the pre-commit runners

Copy link
Contributor Author

@rolfmorel rolfmorel Dec 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, that was it.

Also not sure why this is an integration test.

@rolfmorel rolfmorel merged commit f12fcf0 into llvm:main Dec 14, 2025
13 checks passed
anonymouspc pushed a commit to anonymouspc/llvm that referenced this pull request Dec 15, 2025
llvm#172228)

Friendlier wrapper for transform.foreach.

To facilitate that friendliness, makes it so that OpResult.owner returns
the relevant OpView instead of Operation. For good measure, also changes
Value.owner to return OpView instead of Operation, thereby ensuring
consistency. That is, makes it is so that all op-returning .owner
accessors return OpView (and thereby give access to all goodies
available on registered OpViews.)

Reland of llvm#171544 due to fixup for integration test.
makslevental added a commit to llvm/eudsl that referenced this pull request Dec 15, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:python MLIR Python bindings mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants