Skip to content

Commit f12fcf0

Browse files
authored
[MLIR][Transform][Python] transform.foreach wrapper and .owner OpViews (#172228)
Friendlier wrapper for transform.foreach. To facilitate that friendliness, makes it so that OpResult.owner returns the relevant OpView instead of Operation. For good measure, also changes Value.owner to return OpView instead of Operation, thereby ensuring consistency. That is, makes it is so that all op-returning .owner accessors return OpView (and thereby give access to all goodies available on registered OpViews.) Reland of #171544 due to fixup for integration test.
1 parent 423919d commit f12fcf0

File tree

5 files changed

+108
-7
lines changed

5 files changed

+108
-7
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1519,12 +1519,12 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
15191519
static void bindDerived(ClassTy &c) {
15201520
c.def_prop_ro(
15211521
"owner",
1522-
[](PyOpResult &self) -> nb::typed<nb::object, PyOperation> {
1522+
[](PyOpResult &self) -> nb::typed<nb::object, PyOpView> {
15231523
assert(mlirOperationEqual(self.getParentOperation()->get(),
15241524
mlirOpResultGetOwner(self.get())) &&
15251525
"expected the owner of the value in Python to match that in "
15261526
"the IR");
1527-
return self.getParentOperation().getObject();
1527+
return self.getParentOperation()->createOpView();
15281528
},
15291529
"Returns the operation that produces this result.");
15301530
c.def_prop_ro(
@@ -4646,15 +4646,15 @@ void mlir::python::populateIRCore(nb::module_ &m) {
46464646
kDumpDocstring)
46474647
.def_prop_ro(
46484648
"owner",
4649-
[](PyValue &self) -> nb::object {
4649+
[](PyValue &self) -> nb::typed<nb::object, PyOpView> {
46504650
MlirValue v = self.get();
46514651
if (mlirValueIsAOpResult(v)) {
46524652
assert(mlirOperationEqual(self.getParentOperation()->get(),
46534653
mlirOpResultGetOwner(self.get())) &&
46544654
"expected the owner of the value in Python to match "
46554655
"that in "
46564656
"the IR");
4657-
return self.getParentOperation().getObject();
4657+
return self.getParentOperation()->createOpView();
46584658
}
46594659

46604660
if (mlirValueIsABlockArgument(v)) {

mlir/python/mlir/dialects/memref.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
def _is_constant_int_like(i):
1515
return (
1616
isinstance(i, Value)
17-
and isinstance(i.owner, Operation)
18-
and isinstance(i.owner.opview, ConstantOp)
17+
and isinstance(i.owner, ConstantOp)
1918
and _is_integer_like_type(i.type)
2019
)
2120

mlir/python/mlir/dialects/transform/__init__.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,8 @@ def __init__(
310310
sym_visibility=sym_visibility,
311311
arg_attrs=arg_attrs,
312312
res_attrs=res_attrs,
313+
loc=loc,
314+
ip=ip,
313315
)
314316
self.regions[0].blocks.append(*input_types)
315317

@@ -468,6 +470,54 @@ def apply_registered_pass(
468470
).result
469471

470472

473+
@_ods_cext.register_operation(_Dialect, replace=True)
474+
class ForeachOp(ForeachOp):
475+
def __init__(
476+
self,
477+
results: Sequence[Type],
478+
targets: Sequence[Union[Operation, Value, OpView]],
479+
*,
480+
with_zip_shortest: Optional[bool] = False,
481+
loc=None,
482+
ip=None,
483+
):
484+
targets = [_get_op_result_or_value(target) for target in targets]
485+
super().__init__(
486+
results_=results,
487+
targets=targets,
488+
with_zip_shortest=with_zip_shortest,
489+
loc=loc,
490+
ip=ip,
491+
)
492+
self.regions[0].blocks.append(*[target.type for target in targets])
493+
494+
@property
495+
def body(self) -> Block:
496+
return self.regions[0].blocks[0]
497+
498+
@property
499+
def bodyTargets(self) -> BlockArgumentList:
500+
return self.regions[0].blocks[0].arguments
501+
502+
503+
def foreach(
504+
results: Sequence[Type],
505+
targets: Sequence[Union[Operation, Value, OpView]],
506+
*,
507+
with_zip_shortest: Optional[bool] = False,
508+
loc=None,
509+
ip=None,
510+
) -> Union[OpResult, OpResultList, ForeachOp]:
511+
results = ForeachOp(
512+
results=results,
513+
targets=targets,
514+
with_zip_shortest=with_zip_shortest,
515+
loc=loc,
516+
ip=ip,
517+
).results
518+
return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
519+
520+
471521
AnyOpTypeT = NewType("AnyOpType", AnyOpType)
472522

473523

mlir/test/python/dialects/transform.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,3 +401,55 @@ def testApplyRegisteredPassOp(module: Module):
401401
options={"exclude": (symbol_a, symbol_b)},
402402
)
403403
transform.YieldOp()
404+
405+
406+
# CHECK-LABEL: TEST: testForeachOp
407+
@run
408+
def testForeachOp(module: Module):
409+
# CHECK: transform.sequence
410+
sequence = transform.SequenceOp(
411+
transform.FailurePropagationMode.Propagate,
412+
[transform.AnyOpType.get()],
413+
transform.AnyOpType.get(),
414+
)
415+
with InsertionPoint(sequence.body):
416+
# CHECK: {{.*}} = foreach %{{.*}} : !transform.any_op -> !transform.any_op
417+
foreach1 = transform.ForeachOp(
418+
(transform.AnyOpType.get(),), (sequence.bodyTarget,)
419+
)
420+
with InsertionPoint(foreach1.body):
421+
# CHECK: transform.yield {{.*}} : !transform.any_op
422+
transform.yield_(foreach1.bodyTargets)
423+
424+
a_val = transform.get_operand(
425+
transform.AnyValueType.get(), foreach1.result, [0]
426+
)
427+
a_param = transform.param_constant(
428+
transform.AnyParamType.get(), StringAttr.get("a_param")
429+
)
430+
431+
# CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_value, !transform.any_param
432+
foreach2 = transform.foreach(
433+
(transform.AnyValueType.get(), transform.AnyParamType.get()),
434+
(sequence.bodyTarget, a_val, a_param),
435+
)
436+
with InsertionPoint(foreach2.owner.body):
437+
# CHECK: transform.yield {{.*}} : !transform.any_value, !transform.any_param
438+
transform.yield_(foreach2.owner.bodyTargets[1:3])
439+
440+
another_param = transform.param_constant(
441+
transform.AnyParamType.get(), StringAttr.get("another_param")
442+
)
443+
params = transform.merge_handles([a_param, another_param])
444+
445+
# CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} with_zip_shortest : !transform.any_op, !transform.any_param, !transform.any_param -> !transform.any_op
446+
foreach3 = transform.foreach(
447+
(transform.AnyOpType.get(),),
448+
(foreach1.result, foreach2[1], params),
449+
with_zip_shortest=True,
450+
)
451+
with InsertionPoint(foreach3.owner.body):
452+
# CHECK: transform.yield {{.*}} : !transform.any_op
453+
transform.yield_((foreach3.owner.bodyTargets[0],))
454+
455+
transform.yield_((foreach3,))

mlir/test/python/integration/dialects/pdl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def add_fold(rewriter, results, values):
174174

175175
def is_zero(value):
176176
op = value.owner
177-
if isinstance(op, Operation):
177+
if isinstance(op, OpView):
178178
return op.name == "myint.constant" and op.attributes["value"].value == 0
179179
return False
180180

0 commit comments

Comments
 (0)