Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1519,12 +1519,12 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
static void bindDerived(ClassTy &c) {
c.def_prop_ro(
"owner",
[](PyOpResult &self) -> nb::typed<nb::object, PyOperation> {
[](PyOpResult &self) -> nb::typed<nb::object, PyOpView> {
assert(mlirOperationEqual(self.getParentOperation()->get(),
mlirOpResultGetOwner(self.get())) &&
"expected the owner of the value in Python to match that in "
"the IR");
return self.getParentOperation().getObject();
return self.getParentOperation()->createOpView();
},
"Returns the operation that produces this result.");
c.def_prop_ro(
Expand Down Expand Up @@ -4638,15 +4638,15 @@ void mlir::python::populateIRCore(nb::module_ &m) {
kDumpDocstring)
.def_prop_ro(
"owner",
[](PyValue &self) -> nb::object {
[](PyValue &self) -> nb::typed<nb::object, PyOpView> {
MlirValue v = self.get();
if (mlirValueIsAOpResult(v)) {
assert(mlirOperationEqual(self.getParentOperation()->get(),
mlirOpResultGetOwner(self.get())) &&
"expected the owner of the value in Python to match "
"that in "
"the IR");
return self.getParentOperation().getObject();
return self.getParentOperation()->createOpView();
}

if (mlirValueIsABlockArgument(v)) {
Expand Down
3 changes: 1 addition & 2 deletions mlir/python/mlir/dialects/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
def _is_constant_int_like(i):
return (
isinstance(i, Value)
and isinstance(i.owner, Operation)
and isinstance(i.owner.opview, ConstantOp)
and isinstance(i.owner, ConstantOp)
and _is_integer_like_type(i.type)
)

Expand Down
50 changes: 50 additions & 0 deletions mlir/python/mlir/dialects/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,8 @@ def __init__(
sym_visibility=sym_visibility,
arg_attrs=arg_attrs,
res_attrs=res_attrs,
loc=loc,
ip=ip,
)
self.regions[0].blocks.append(*input_types)

Expand Down Expand Up @@ -468,6 +470,54 @@ def apply_registered_pass(
).result


@_ods_cext.register_operation(_Dialect, replace=True)
class ForeachOp(ForeachOp):
def __init__(
self,
results: Sequence[Type],
targets: Sequence[Union[Operation, Value, OpView]],
*,
with_zip_shortest: Optional[bool] = False,
loc=None,
ip=None,
):
targets = [_get_op_result_or_value(target) for target in targets]
super().__init__(
results_=results,
targets=targets,
with_zip_shortest=with_zip_shortest,
loc=loc,
ip=ip,
)
self.regions[0].blocks.append(*[target.type for target in targets])

@property
def body(self) -> Block:
return self.regions[0].blocks[0]

@property
def bodyTargets(self) -> BlockArgumentList:
return self.regions[0].blocks[0].arguments


def foreach(
results: Sequence[Type],
targets: Sequence[Union[Operation, Value, OpView]],
*,
with_zip_shortest: Optional[bool] = False,
loc=None,
ip=None,
) -> Union[OpResult, OpResultList, ForeachOp]:
results = ForeachOp(
results=results,
targets=targets,
with_zip_shortest=with_zip_shortest,
loc=loc,
ip=ip,
).results
return results if len(results) > 1 else (results[0] if len(results) == 1 else op)


AnyOpTypeT = NewType("AnyOpType", AnyOpType)


Expand Down
52 changes: 52 additions & 0 deletions mlir/test/python/dialects/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,3 +401,55 @@ def testApplyRegisteredPassOp(module: Module):
options={"exclude": (symbol_a, symbol_b)},
)
transform.YieldOp()


# CHECK-LABEL: TEST: testForeachOp
@run
def testForeachOp(module: Module):
# CHECK: transform.sequence
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[transform.AnyOpType.get()],
transform.AnyOpType.get(),
)
with InsertionPoint(sequence.body):
# CHECK: {{.*}} = foreach %{{.*}} : !transform.any_op -> !transform.any_op
foreach1 = transform.ForeachOp(
(transform.AnyOpType.get(),), (sequence.bodyTarget,)
)
with InsertionPoint(foreach1.body):
# CHECK: transform.yield {{.*}} : !transform.any_op
transform.yield_(foreach1.bodyTargets)

a_val = transform.get_operand(
transform.AnyValueType.get(), foreach1.result, [0]
)
a_param = transform.param_constant(
transform.AnyParamType.get(), StringAttr.get("a_param")
)

# CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_value, !transform.any_param
foreach2 = transform.foreach(
(transform.AnyValueType.get(), transform.AnyParamType.get()),
(sequence.bodyTarget, a_val, a_param),
)
with InsertionPoint(foreach2.owner.body):
# CHECK: transform.yield {{.*}} : !transform.any_value, !transform.any_param
transform.yield_(foreach2.owner.bodyTargets[1:3])

another_param = transform.param_constant(
transform.AnyParamType.get(), StringAttr.get("another_param")
)
params = transform.merge_handles([a_param, another_param])

# CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} with_zip_shortest : !transform.any_op, !transform.any_param, !transform.any_param -> !transform.any_op
foreach3 = transform.foreach(
(transform.AnyOpType.get(),),
(foreach1.result, foreach2[1], params),
with_zip_shortest=True,
)
with InsertionPoint(foreach3.owner.body):
# CHECK: transform.yield {{.*}} : !transform.any_op
transform.yield_((foreach3.owner.bodyTargets[0],))

transform.yield_((foreach3,))