-
Notifications
You must be signed in to change notification settings - Fork 15.5k
Revert "[MLIR][Transform][Python] transform.foreach wrapper and .owner OpViews" #172225
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…r OpView…" This reverts commit 4cdec92.
|
@llvm/pr-subscribers-mlir Author: Mehdi Amini (joker-eph) ChangesReverts llvm/llvm-project#171544 ; bots are broken. Full diff: https://github.com/llvm/llvm-project/pull/172225.diff 4 Files Affected:
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 168c57955af07..b0de14719ab61 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1519,12 +1519,12 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
static void bindDerived(ClassTy &c) {
c.def_prop_ro(
"owner",
- [](PyOpResult &self) -> nb::typed<nb::object, PyOpView> {
+ [](PyOpResult &self) -> nb::typed<nb::object, PyOperation> {
assert(mlirOperationEqual(self.getParentOperation()->get(),
mlirOpResultGetOwner(self.get())) &&
"expected the owner of the value in Python to match that in "
"the IR");
- return self.getParentOperation()->createOpView();
+ return self.getParentOperation().getObject();
},
"Returns the operation that produces this result.");
c.def_prop_ro(
@@ -4646,7 +4646,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
kDumpDocstring)
.def_prop_ro(
"owner",
- [](PyValue &self) -> nb::typed<nb::object, PyOpView> {
+ [](PyValue &self) -> nb::object {
MlirValue v = self.get();
if (mlirValueIsAOpResult(v)) {
assert(mlirOperationEqual(self.getParentOperation()->get(),
@@ -4654,7 +4654,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"expected the owner of the value in Python to match "
"that in "
"the IR");
- return self.getParentOperation()->createOpView();
+ return self.getParentOperation().getObject();
}
if (mlirValueIsABlockArgument(v)) {
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index c80a1b1a89358..bc9a3a52728ad 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -14,7 +14,8 @@
def _is_constant_int_like(i):
return (
isinstance(i, Value)
- and isinstance(i.owner, ConstantOp)
+ and isinstance(i.owner, Operation)
+ and isinstance(i.owner.opview, ConstantOp)
and _is_integer_like_type(i.type)
)
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index fbe4078782997..b3dd79c7dbd79 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -310,8 +310,6 @@ def __init__(
sym_visibility=sym_visibility,
arg_attrs=arg_attrs,
res_attrs=res_attrs,
- loc=loc,
- ip=ip,
)
self.regions[0].blocks.append(*input_types)
@@ -470,54 +468,6 @@ def apply_registered_pass(
).result
-@_ods_cext.register_operation(_Dialect, replace=True)
-class ForeachOp(ForeachOp):
- def __init__(
- self,
- results: Sequence[Type],
- targets: Sequence[Union[Operation, Value, OpView]],
- *,
- with_zip_shortest: Optional[bool] = False,
- loc=None,
- ip=None,
- ):
- targets = [_get_op_result_or_value(target) for target in targets]
- super().__init__(
- results_=results,
- targets=targets,
- with_zip_shortest=with_zip_shortest,
- loc=loc,
- ip=ip,
- )
- self.regions[0].blocks.append(*[target.type for target in targets])
-
- @property
- def body(self) -> Block:
- return self.regions[0].blocks[0]
-
- @property
- def bodyTargets(self) -> BlockArgumentList:
- return self.regions[0].blocks[0].arguments
-
-
-def foreach(
- results: Sequence[Type],
- targets: Sequence[Union[Operation, Value, OpView]],
- *,
- with_zip_shortest: Optional[bool] = False,
- loc=None,
- ip=None,
-) -> Union[OpResult, OpResultList, ForeachOp]:
- results = ForeachOp(
- results=results,
- targets=targets,
- with_zip_shortest=with_zip_shortest,
- loc=loc,
- ip=ip,
- ).results
- return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
-
-
AnyOpTypeT = NewType("AnyOpType", AnyOpType)
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index dfcc890b83ffc..f58442d04fc66 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -401,55 +401,3 @@ def testApplyRegisteredPassOp(module: Module):
options={"exclude": (symbol_a, symbol_b)},
)
transform.YieldOp()
-
-
-# CHECK-LABEL: TEST: testForeachOp
-@run
-def testForeachOp(module: Module):
- # CHECK: transform.sequence
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate,
- [transform.AnyOpType.get()],
- transform.AnyOpType.get(),
- )
- with InsertionPoint(sequence.body):
- # CHECK: {{.*}} = foreach %{{.*}} : !transform.any_op -> !transform.any_op
- foreach1 = transform.ForeachOp(
- (transform.AnyOpType.get(),), (sequence.bodyTarget,)
- )
- with InsertionPoint(foreach1.body):
- # CHECK: transform.yield {{.*}} : !transform.any_op
- transform.yield_(foreach1.bodyTargets)
-
- a_val = transform.get_operand(
- transform.AnyValueType.get(), foreach1.result, [0]
- )
- a_param = transform.param_constant(
- transform.AnyParamType.get(), StringAttr.get("a_param")
- )
-
- # CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_value, !transform.any_param
- foreach2 = transform.foreach(
- (transform.AnyValueType.get(), transform.AnyParamType.get()),
- (sequence.bodyTarget, a_val, a_param),
- )
- with InsertionPoint(foreach2.owner.body):
- # CHECK: transform.yield {{.*}} : !transform.any_value, !transform.any_param
- transform.yield_(foreach2.owner.bodyTargets[1:3])
-
- another_param = transform.param_constant(
- transform.AnyParamType.get(), StringAttr.get("another_param")
- )
- params = transform.merge_handles([a_param, another_param])
-
- # CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} with_zip_shortest : !transform.any_op, !transform.any_param, !transform.any_param -> !transform.any_op
- foreach3 = transform.foreach(
- (transform.AnyOpType.get(),),
- (foreach1.result, foreach2[1], params),
- with_zip_shortest=True,
- )
- with InsertionPoint(foreach3.owner.body):
- # CHECK: transform.yield {{.*}} : !transform.any_op
- transform.yield_((foreach3.owner.bodyTargets[0],))
-
- transform.yield_((foreach3,))
|
|
👍 I working on a fix. |
…er OpViews" (llvm#172225) This reverts commit b9fe653.
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/169/builds/18025 Here is the relevant piece of the build log for the reference |
…r OpViews" (llvm#172225) Reverts llvm#171544 ; bots are broken.
Reverts #171544 ; bots are broken.