@@ -112,6 +112,8 @@ static FailureOr<int> getOperatorPrecedence(Operation *operation) {
112112 .Default ([](auto op) { return op->emitError (" unsupported operation" ); });
113113}
114114
115+ static bool shouldBeInlined (Operation *op);
116+
115117namespace {
116118// / Emitter that uses dialect specific emitters to emit C++ code.
117119struct CppEmitter {
@@ -255,24 +257,19 @@ struct CppEmitter {
255257 }
256258
257259 // / Is expression currently being emitted.
258- bool isEmittingExpression () { return emittedExpression ; }
260+ bool isEmittingExpression () { return !emittedExpressionPrecedence. empty () ; }
259261
260262 // / Determine whether given value is part of the expression potentially being
261263 // / emitted.
262264 bool isPartOfCurrentExpression (Value value) {
263- if (!emittedExpression)
264- return false ;
265265 Operation *def = value.getDefiningOp ();
266- if (!def)
267- return false ;
268- return isPartOfCurrentExpression (def);
266+ return def ? isPartOfCurrentExpression (def) : false ;
269267 }
270268
271269 // / Determine whether given operation is part of the expression potentially
272270 // / being emitted.
273271 bool isPartOfCurrentExpression (Operation *def) {
274- auto operandExpression = dyn_cast<ExpressionOp>(def->getParentOp ());
275- return operandExpression && operandExpression == emittedExpression;
272+ return isEmittingExpression () && shouldBeInlined (def);
276273 };
277274
278275 // Resets the value counter to 0.
@@ -319,7 +316,6 @@ struct CppEmitter {
319316 unsigned int valueCount{0 };
320317
321318 // / State of the current expression being emitted.
322- ExpressionOp emittedExpression;
323319 SmallVector<int > emittedExpressionPrecedence;
324320
325321 void pushExpressionPrecedence (int precedence) {
@@ -342,12 +338,22 @@ static bool hasDeferredEmission(Operation *op) {
342338 emitc::GetFieldOp>(op);
343339}
344340
345- // / Determine whether expression \p expressionOp should be emitted inline, i.e.
341+ // / Determine whether operation \p op should be emitted inline, i.e.
346342// / as part of its user. This function recommends inlining of any expressions
347343// / that can be inlined unless it is used by another expression, under the
348344// / assumption that any expression fusion/re-materialization was taken care of
349345// / by transformations run by the backend.
350- static bool shouldBeInlined (ExpressionOp expressionOp) {
346+ static bool shouldBeInlined (Operation *op) {
347+ // CExpression operations are inlined if and only if they reside within an
348+ // ExpressionOp.
349+ if (isa<CExpressionInterface>(op))
350+ return isa<ExpressionOp>(op->getParentOp ());
351+
352+ // Only other inlinable operation is ExpressionOp itself.
353+ ExpressionOp expressionOp = dyn_cast<ExpressionOp>(op);
354+ if (!expressionOp)
355+ return false ;
356+
351357 // Do not inline if expression is marked as such.
352358 if (expressionOp.getDoNotInline ())
353359 return false ;
@@ -1585,7 +1591,6 @@ LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) {
15851591 " Expected precedence stack to be empty" );
15861592 Operation *rootOp = expressionOp.getRootOp ();
15871593
1588- emittedExpression = expressionOp;
15891594 FailureOr<int > precedence = getOperatorPrecedence (rootOp);
15901595 if (failed (precedence))
15911596 return failure ();
@@ -1597,7 +1602,6 @@ LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) {
15971602 popExpressionPrecedence ();
15981603 assert (emittedExpressionPrecedence.empty () &&
15991604 " Expected precedence stack to be empty" );
1600- emittedExpression = nullptr ;
16011605
16021606 return success ();
16031607}
@@ -1638,14 +1642,8 @@ LogicalResult CppEmitter::emitOperand(Value value, bool isInBrackets) {
16381642 // If this operand is a block argument of an expression, emit instead the
16391643 // matching expression parameter.
16401644 Operation *argOp = arg.getParentBlock ()->getParentOp ();
1641- if (auto expressionOp = dyn_cast<ExpressionOp>(argOp)) {
1642- // This scenario is only expected when one of the operations within the
1643- // expression being emitted references one of the expression's block
1644- // arguments.
1645- assert (expressionOp == emittedExpression &&
1646- " Expected expression being emitted" );
1647- value = expressionOp->getOperand (arg.getArgNumber ());
1648- }
1645+ if (auto expressionOp = dyn_cast<ExpressionOp>(argOp))
1646+ return emitOperand (expressionOp->getOperand (arg.getArgNumber ()));
16491647 }
16501648
16511649 os << getOrCreateName (value);
0 commit comments