Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1519,12 +1519,12 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
static void bindDerived(ClassTy &c) {
c.def_prop_ro(
"owner",
[](PyOpResult &self) -> nb::typed<nb::object, PyOperation> {
[](PyOpResult &self) -> nb::typed<nb::object, PyOpView> {
assert(mlirOperationEqual(self.getParentOperation()->get(),
mlirOpResultGetOwner(self.get())) &&
"expected the owner of the value in Python to match that in "
"the IR");
return self.getParentOperation().getObject();
return self.getParentOperation()->createOpView();
},
"Returns the operation that produces this result.");
c.def_prop_ro(
Expand Down Expand Up @@ -4646,15 +4646,15 @@ void mlir::python::populateIRCore(nb::module_ &m) {
kDumpDocstring)
.def_prop_ro(
"owner",
[](PyValue &self) -> nb::object {
[](PyValue &self) -> nb::typed<nb::object, PyOpView> {
MlirValue v = self.get();
if (mlirValueIsAOpResult(v)) {
assert(mlirOperationEqual(self.getParentOperation()->get(),
mlirOpResultGetOwner(self.get())) &&
"expected the owner of the value in Python to match "
"that in "
"the IR");
return self.getParentOperation().getObject();
return self.getParentOperation()->createOpView();
}

if (mlirValueIsABlockArgument(v)) {
Expand Down
3 changes: 1 addition & 2 deletions mlir/python/mlir/dialects/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
def _is_constant_int_like(i):
return (
isinstance(i, Value)
and isinstance(i.owner, Operation)
and isinstance(i.owner.opview, ConstantOp)
and isinstance(i.owner, ConstantOp)
and _is_integer_like_type(i.type)
)

Expand Down
50 changes: 50 additions & 0 deletions mlir/python/mlir/dialects/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,8 @@ def __init__(
sym_visibility=sym_visibility,
arg_attrs=arg_attrs,
res_attrs=res_attrs,
loc=loc,
ip=ip,
)
self.regions[0].blocks.append(*input_types)

Expand Down Expand Up @@ -468,6 +470,54 @@ def apply_registered_pass(
).result


@_ods_cext.register_operation(_Dialect, replace=True)
class ForeachOp(ForeachOp):
def __init__(
self,
results: Sequence[Type],
targets: Sequence[Union[Operation, Value, OpView]],
*,
with_zip_shortest: Optional[bool] = False,
loc=None,
ip=None,
):
targets = [_get_op_result_or_value(target) for target in targets]
super().__init__(
results_=results,
targets=targets,
with_zip_shortest=with_zip_shortest,
loc=loc,
ip=ip,
)
self.regions[0].blocks.append(*[target.type for target in targets])

@property
def body(self) -> Block:
return self.regions[0].blocks[0]

@property
def bodyTargets(self) -> BlockArgumentList:
return self.regions[0].blocks[0].arguments


def foreach(
results: Sequence[Type],
targets: Sequence[Union[Operation, Value, OpView]],
*,
with_zip_shortest: Optional[bool] = False,
loc=None,
ip=None,
) -> Union[OpResult, OpResultList, ForeachOp]:
results = ForeachOp(
results=results,
targets=targets,
with_zip_shortest=with_zip_shortest,
loc=loc,
ip=ip,
).results
return results if len(results) > 1 else (results[0] if len(results) == 1 else op)


AnyOpTypeT = NewType("AnyOpType", AnyOpType)


Expand Down
52 changes: 52 additions & 0 deletions mlir/test/python/dialects/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,3 +401,55 @@ def testApplyRegisteredPassOp(module: Module):
options={"exclude": (symbol_a, symbol_b)},
)
transform.YieldOp()


# CHECK-LABEL: TEST: testForeachOp
@run
def testForeachOp(module: Module):
# CHECK: transform.sequence
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[transform.AnyOpType.get()],
transform.AnyOpType.get(),
)
with InsertionPoint(sequence.body):
# CHECK: {{.*}} = foreach %{{.*}} : !transform.any_op -> !transform.any_op
foreach1 = transform.ForeachOp(
(transform.AnyOpType.get(),), (sequence.bodyTarget,)
)
with InsertionPoint(foreach1.body):
# CHECK: transform.yield {{.*}} : !transform.any_op
transform.yield_(foreach1.bodyTargets)

a_val = transform.get_operand(
transform.AnyValueType.get(), foreach1.result, [0]
)
a_param = transform.param_constant(
transform.AnyParamType.get(), StringAttr.get("a_param")
)

# CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_value, !transform.any_param
foreach2 = transform.foreach(
(transform.AnyValueType.get(), transform.AnyParamType.get()),
(sequence.bodyTarget, a_val, a_param),
)
with InsertionPoint(foreach2.owner.body):
# CHECK: transform.yield {{.*}} : !transform.any_value, !transform.any_param
transform.yield_(foreach2.owner.bodyTargets[1:3])

another_param = transform.param_constant(
transform.AnyParamType.get(), StringAttr.get("another_param")
)
params = transform.merge_handles([a_param, another_param])

# CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} with_zip_shortest : !transform.any_op, !transform.any_param, !transform.any_param -> !transform.any_op
foreach3 = transform.foreach(
(transform.AnyOpType.get(),),
(foreach1.result, foreach2[1], params),
with_zip_shortest=True,
)
with InsertionPoint(foreach3.owner.body):
# CHECK: transform.yield {{.*}} : !transform.any_op
transform.yield_((foreach3.owner.bodyTargets[0],))

transform.yield_((foreach3,))
2 changes: 1 addition & 1 deletion mlir/test/python/integration/dialects/pdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def add_fold(rewriter, results, values):

def is_zero(value):
op = value.owner
if isinstance(op, Operation):
if isinstance(op, OpView):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this the fail? I'm not sure why this is even an "integration" test - it doesn't require anything that's not available on the pre-commit runners

Copy link
Contributor Author

@rolfmorel rolfmorel Dec 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, that was it.

Also not sure why this is an integration test.

return op.name == "myint.constant" and op.attributes["value"].value == 0
return False

Expand Down