diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 2e0c2b895216f..da33945f42913 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1519,12 +1519,12 @@ class PyOpResult : public PyConcreteValue { static void bindDerived(ClassTy &c) { c.def_prop_ro( "owner", - [](PyOpResult &self) -> nb::typed { + [](PyOpResult &self) -> nb::typed { assert(mlirOperationEqual(self.getParentOperation()->get(), mlirOpResultGetOwner(self.get())) && "expected the owner of the value in Python to match that in " "the IR"); - return self.getParentOperation().getObject(); + return self.getParentOperation()->createOpView(); }, "Returns the operation that produces this result."); c.def_prop_ro( @@ -4638,7 +4638,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { kDumpDocstring) .def_prop_ro( "owner", - [](PyValue &self) -> nb::object { + [](PyValue &self) -> nb::typed { MlirValue v = self.get(); if (mlirValueIsAOpResult(v)) { assert(mlirOperationEqual(self.getParentOperation()->get(), @@ -4646,7 +4646,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { "expected the owner of the value in Python to match " "that in " "the IR"); - return self.getParentOperation().getObject(); + return self.getParentOperation()->createOpView(); } if (mlirValueIsABlockArgument(v)) { diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py index bc9a3a52728ad..c80a1b1a89358 100644 --- a/mlir/python/mlir/dialects/memref.py +++ b/mlir/python/mlir/dialects/memref.py @@ -14,8 +14,7 @@ def _is_constant_int_like(i): return ( isinstance(i, Value) - and isinstance(i.owner, Operation) - and isinstance(i.owner.opview, ConstantOp) + and isinstance(i.owner, ConstantOp) and _is_integer_like_type(i.type) ) diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index b3dd79c7dbd79..fbe4078782997 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -310,6 +310,8 @@ def __init__( sym_visibility=sym_visibility, arg_attrs=arg_attrs, res_attrs=res_attrs, + loc=loc, + ip=ip, ) self.regions[0].blocks.append(*input_types) @@ -468,6 +470,54 @@ def apply_registered_pass( ).result +@_ods_cext.register_operation(_Dialect, replace=True) +class ForeachOp(ForeachOp): + def __init__( + self, + results: Sequence[Type], + targets: Sequence[Union[Operation, Value, OpView]], + *, + with_zip_shortest: Optional[bool] = False, + loc=None, + ip=None, + ): + targets = [_get_op_result_or_value(target) for target in targets] + super().__init__( + results_=results, + targets=targets, + with_zip_shortest=with_zip_shortest, + loc=loc, + ip=ip, + ) + self.regions[0].blocks.append(*[target.type for target in targets]) + + @property + def body(self) -> Block: + return self.regions[0].blocks[0] + + @property + def bodyTargets(self) -> BlockArgumentList: + return self.regions[0].blocks[0].arguments + + +def foreach( + results: Sequence[Type], + targets: Sequence[Union[Operation, Value, OpView]], + *, + with_zip_shortest: Optional[bool] = False, + loc=None, + ip=None, +) -> Union[OpResult, OpResultList, ForeachOp]: + results = ForeachOp( + results=results, + targets=targets, + with_zip_shortest=with_zip_shortest, + loc=loc, + ip=ip, + ).results + return results if len(results) > 1 else (results[0] if len(results) == 1 else op) + + AnyOpTypeT = NewType("AnyOpType", AnyOpType) diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py index f58442d04fc66..dfcc890b83ffc 100644 --- a/mlir/test/python/dialects/transform.py +++ b/mlir/test/python/dialects/transform.py @@ -401,3 +401,55 @@ def testApplyRegisteredPassOp(module: Module): options={"exclude": (symbol_a, symbol_b)}, ) transform.YieldOp() + + +# CHECK-LABEL: TEST: testForeachOp +@run +def testForeachOp(module: Module): + # CHECK: transform.sequence + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [transform.AnyOpType.get()], + transform.AnyOpType.get(), + ) + with InsertionPoint(sequence.body): + # CHECK: {{.*}} = foreach %{{.*}} : !transform.any_op -> !transform.any_op + foreach1 = transform.ForeachOp( + (transform.AnyOpType.get(),), (sequence.bodyTarget,) + ) + with InsertionPoint(foreach1.body): + # CHECK: transform.yield {{.*}} : !transform.any_op + transform.yield_(foreach1.bodyTargets) + + a_val = transform.get_operand( + transform.AnyValueType.get(), foreach1.result, [0] + ) + a_param = transform.param_constant( + transform.AnyParamType.get(), StringAttr.get("a_param") + ) + + # CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_value, !transform.any_param + foreach2 = transform.foreach( + (transform.AnyValueType.get(), transform.AnyParamType.get()), + (sequence.bodyTarget, a_val, a_param), + ) + with InsertionPoint(foreach2.owner.body): + # CHECK: transform.yield {{.*}} : !transform.any_value, !transform.any_param + transform.yield_(foreach2.owner.bodyTargets[1:3]) + + another_param = transform.param_constant( + transform.AnyParamType.get(), StringAttr.get("another_param") + ) + params = transform.merge_handles([a_param, another_param]) + + # CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} with_zip_shortest : !transform.any_op, !transform.any_param, !transform.any_param -> !transform.any_op + foreach3 = transform.foreach( + (transform.AnyOpType.get(),), + (foreach1.result, foreach2[1], params), + with_zip_shortest=True, + ) + with InsertionPoint(foreach3.owner.body): + # CHECK: transform.yield {{.*}} : !transform.any_op + transform.yield_((foreach3.owner.bodyTargets[0],)) + + transform.yield_((foreach3,))