Skip to content

Commit 4cdec92

Browse files
authored
[MLIR][Transform][Python] transform.foreach wrapper and .owner OpViews (#171544)
Friendlier wrapper for `transform.foreach`. To facilitate that friendliness, makes it so that `OpResult.owner` returns the relevant `OpView` instead of `Operation`. For good measure, also changes `Value.owner` to return `OpView` instead of `Operation`, thereby ensuring consistency. That is, makes it is so that all op-returning `.owner` accessors return `OpView` (and thereby give access to all goodies available on registered `OpView`s.)
1 parent 53cf22f commit 4cdec92

File tree

4 files changed

+107
-6
lines changed

4 files changed

+107
-6
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1519,12 +1519,12 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
15191519
static void bindDerived(ClassTy &c) {
15201520
c.def_prop_ro(
15211521
"owner",
1522-
[](PyOpResult &self) -> nb::typed<nb::object, PyOperation> {
1522+
[](PyOpResult &self) -> nb::typed<nb::object, PyOpView> {
15231523
assert(mlirOperationEqual(self.getParentOperation()->get(),
15241524
mlirOpResultGetOwner(self.get())) &&
15251525
"expected the owner of the value in Python to match that in "
15261526
"the IR");
1527-
return self.getParentOperation().getObject();
1527+
return self.getParentOperation()->createOpView();
15281528
},
15291529
"Returns the operation that produces this result.");
15301530
c.def_prop_ro(
@@ -4646,15 +4646,15 @@ void mlir::python::populateIRCore(nb::module_ &m) {
46464646
kDumpDocstring)
46474647
.def_prop_ro(
46484648
"owner",
4649-
[](PyValue &self) -> nb::object {
4649+
[](PyValue &self) -> nb::typed<nb::object, PyOpView> {
46504650
MlirValue v = self.get();
46514651
if (mlirValueIsAOpResult(v)) {
46524652
assert(mlirOperationEqual(self.getParentOperation()->get(),
46534653
mlirOpResultGetOwner(self.get())) &&
46544654
"expected the owner of the value in Python to match "
46554655
"that in "
46564656
"the IR");
4657-
return self.getParentOperation().getObject();
4657+
return self.getParentOperation()->createOpView();
46584658
}
46594659

46604660
if (mlirValueIsABlockArgument(v)) {

mlir/python/mlir/dialects/memref.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
def _is_constant_int_like(i):
1515
return (
1616
isinstance(i, Value)
17-
and isinstance(i.owner, Operation)
18-
and isinstance(i.owner.opview, ConstantOp)
17+
and isinstance(i.owner, ConstantOp)
1918
and _is_integer_like_type(i.type)
2019
)
2120

mlir/python/mlir/dialects/transform/__init__.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,8 @@ def __init__(
310310
sym_visibility=sym_visibility,
311311
arg_attrs=arg_attrs,
312312
res_attrs=res_attrs,
313+
loc=loc,
314+
ip=ip,
313315
)
314316
self.regions[0].blocks.append(*input_types)
315317

@@ -468,6 +470,54 @@ def apply_registered_pass(
468470
).result
469471

470472

473+
@_ods_cext.register_operation(_Dialect, replace=True)
474+
class ForeachOp(ForeachOp):
475+
def __init__(
476+
self,
477+
results: Sequence[Type],
478+
targets: Sequence[Union[Operation, Value, OpView]],
479+
*,
480+
with_zip_shortest: Optional[bool] = False,
481+
loc=None,
482+
ip=None,
483+
):
484+
targets = [_get_op_result_or_value(target) for target in targets]
485+
super().__init__(
486+
results_=results,
487+
targets=targets,
488+
with_zip_shortest=with_zip_shortest,
489+
loc=loc,
490+
ip=ip,
491+
)
492+
self.regions[0].blocks.append(*[target.type for target in targets])
493+
494+
@property
495+
def body(self) -> Block:
496+
return self.regions[0].blocks[0]
497+
498+
@property
499+
def bodyTargets(self) -> BlockArgumentList:
500+
return self.regions[0].blocks[0].arguments
501+
502+
503+
def foreach(
504+
results: Sequence[Type],
505+
targets: Sequence[Union[Operation, Value, OpView]],
506+
*,
507+
with_zip_shortest: Optional[bool] = False,
508+
loc=None,
509+
ip=None,
510+
) -> Union[OpResult, OpResultList, ForeachOp]:
511+
results = ForeachOp(
512+
results=results,
513+
targets=targets,
514+
with_zip_shortest=with_zip_shortest,
515+
loc=loc,
516+
ip=ip,
517+
).results
518+
return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
519+
520+
471521
AnyOpTypeT = NewType("AnyOpType", AnyOpType)
472522

473523

mlir/test/python/dialects/transform.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,3 +401,55 @@ def testApplyRegisteredPassOp(module: Module):
401401
options={"exclude": (symbol_a, symbol_b)},
402402
)
403403
transform.YieldOp()
404+
405+
406+
# CHECK-LABEL: TEST: testForeachOp
407+
@run
408+
def testForeachOp(module: Module):
409+
# CHECK: transform.sequence
410+
sequence = transform.SequenceOp(
411+
transform.FailurePropagationMode.Propagate,
412+
[transform.AnyOpType.get()],
413+
transform.AnyOpType.get(),
414+
)
415+
with InsertionPoint(sequence.body):
416+
# CHECK: {{.*}} = foreach %{{.*}} : !transform.any_op -> !transform.any_op
417+
foreach1 = transform.ForeachOp(
418+
(transform.AnyOpType.get(),), (sequence.bodyTarget,)
419+
)
420+
with InsertionPoint(foreach1.body):
421+
# CHECK: transform.yield {{.*}} : !transform.any_op
422+
transform.yield_(foreach1.bodyTargets)
423+
424+
a_val = transform.get_operand(
425+
transform.AnyValueType.get(), foreach1.result, [0]
426+
)
427+
a_param = transform.param_constant(
428+
transform.AnyParamType.get(), StringAttr.get("a_param")
429+
)
430+
431+
# CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_value, !transform.any_param
432+
foreach2 = transform.foreach(
433+
(transform.AnyValueType.get(), transform.AnyParamType.get()),
434+
(sequence.bodyTarget, a_val, a_param),
435+
)
436+
with InsertionPoint(foreach2.owner.body):
437+
# CHECK: transform.yield {{.*}} : !transform.any_value, !transform.any_param
438+
transform.yield_(foreach2.owner.bodyTargets[1:3])
439+
440+
another_param = transform.param_constant(
441+
transform.AnyParamType.get(), StringAttr.get("another_param")
442+
)
443+
params = transform.merge_handles([a_param, another_param])
444+
445+
# CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} with_zip_shortest : !transform.any_op, !transform.any_param, !transform.any_param -> !transform.any_op
446+
foreach3 = transform.foreach(
447+
(transform.AnyOpType.get(),),
448+
(foreach1.result, foreach2[1], params),
449+
with_zip_shortest=True,
450+
)
451+
with InsertionPoint(foreach3.owner.body):
452+
# CHECK: transform.yield {{.*}} : !transform.any_op
453+
transform.yield_((foreach3.owner.bodyTargets[0],))
454+
455+
transform.yield_((foreach3,))

0 commit comments

Comments
 (0)